Example stateful applications

This article contains code examples for custom stateful applications. Databricks recommends using built-in stateful methods for common operations such as aggregations and joins.

The patterns in this article use the transformWithState operator and associated classes available in Public Preview in Databricks Runtime 16.2 and above. See Build a custom stateful application.

Note

Python uses the transformWithStateInPandas operator to provide the same functionality. The examples below provide code in Python and Scala.

Requirements

The transformWithState operator and the related APIs and classes have the following requirements:

  • Available in Databricks Runtime 16.2 and above.
  • Compute must use dedicated or no-isolation access mode.
  • You must use the RocksDB state store provider. Databricks recommends enabling RocksDB as part of the compute configuration.

Note

To enable the RocksDB state store provider for the current session, run the following:

spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

Slowly changing dimension (SCD) type 1

The following code is an example of implementing SCD type 1 using transformWithState. SCD type 1 only tracks the most recent value for a given field.

Note

You can use streaming tables and APPLY CHANGES INTO to implement SCD type 1 or type 2 with Delta Lake-backed tables. This example implements SCD type 1 in the state store, which provides lower latency for near real-time applications.

Python

import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator
spark.conf.set("spark.sql.streaming.stateStore.providerClass","org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

output_schema = StructType([
    StructField("user", StringType(), True),
    StructField("time", LongType(), True),
    StructField("location", StringType(), True)
])

class SCDType1StatefulProcessor(StatefulProcessor):
  def init(self, handle: StatefulProcessorHandle) -> None:
    value_state_schema = StructType([
        StructField("user", StringType(), True),
        StructField("time", LongType(), True),
        StructField("location", StringType(), True)
    ])
    # Value state holds the most recent record seen for each user key
    self.latest_location = handle.getValueState("latestLocation", value_state_schema)

  def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
    # Find the most recent record in this microbatch for the key
    max_row = None
    max_time = float('-inf')
    for pdf in rows:
      for _, pd_row in pdf.iterrows():
        time_value = pd_row["time"]
        if time_value > max_time:
            max_time = time_value
            max_row = tuple(pd_row)
    # Update state and emit a record only if the new value is more recent
    # than the value currently in state
    exists = self.latest_location.exists()
    if not exists or max_row[1] > self.latest_location.get()[1]:
      self.latest_location.update(max_row)
      yield pd.DataFrame(
              {"user": (max_row[0],), "time": (max_row[1],), "location": (max_row[2],)}
          )
    yield pd.DataFrame()

  def close(self) -> None:
    pass

(df.groupBy("user")
  .transformWithStateInPandas(
      statefulProcessor=SCDType1StatefulProcessor(),
      outputStructType=output_schema,
      outputMode="Update",
      timeMode="None",
  )
  .writeStream...
)

Scala

import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

case class UserLocation(
    user: String,
    time: Long,
    location: String)

class SCDType1StatefulProcessor extends StatefulProcessor[String, UserLocation, UserLocation] {

  @transient private var _latestLocation: ValueState[UserLocation] = _

  override def init(
      outputMode: OutputMode,
      timeMode: TimeMode): Unit = {
    _latestLocation = getHandle.getValueState[UserLocation]("locationState",
      Encoders.product[UserLocation], TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[UserLocation],
      timerValues: TimerValues): Iterator[UserLocation] = {
    val maxNewLocation = inputRows.maxBy(_.time)
    if (_latestLocation.getOption().isEmpty || maxNewLocation.time > _latestLocation.get().time) {
      _latestLocation.update(maxNewLocation)
      Iterator.single(maxNewLocation)
    } else {
      Iterator.empty
    }
  }
}
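
The Scala example defines only the processor class. The following is a minimal sketch of how you might wire it into a streaming query. It assumes a streaming DataFrame named df with user, time, and location columns and a Spark session with its implicits in scope; adjust the names and sink configuration for your workload.

import spark.implicits._
import org.apache.spark.sql.streaming.{OutputMode, TimeMode}

// Sketch only: df is an assumed streaming DataFrame whose rows match UserLocation.
df.as[UserLocation]
  .groupByKey(_.user)
  .transformWithState(
    new SCDType1StatefulProcessor(),
    TimeMode.None(),
    OutputMode.Update())
  .writeStream
  // Configure the sink and checkpoint location, then start the query.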

Downtime detector

transformWithState implements timers to allow you to take action based on elapsed time, even if no records for a given key are processed in a microbatch.

The following example implements a pattern for a downtime detector. Each time a new value is seen for a given key, it updates the lastSeen state value, clears any existing timers, and registers a new timer.

When a timer expires, the application emits the elapsed time since the last observed event for the key. It then sets a new timer to emit an update 10 seconds later.

Python

import pandas as pd
from datetime import datetime
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, TimestampType
from typing import Iterator

class DownTimeDetectorStatefulProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        state_schema = StructType([StructField("value", TimestampType(), True)])
        self.handle = handle
        # Track the timestamp of the last observed event for each key
        self.last_seen = handle.getValueState("last_seen", state_schema)

    def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
        # State is stored as a single-element tuple containing the timestamp
        latest_from_existing = self.last_seen.get()[0]
        # Elapsed time since the last observed event for this key
        downtime_duration = timerValues.getCurrentProcessingTimeInMs() - int(
            latest_from_existing.timestamp() * 1000
        )
        # Register another timer to emit an update 10 seconds from now
        self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)
        yield pd.DataFrame(
            {
                "id": key,
                "timeValues": str(downtime_duration),
            }
        )

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        # Find the row with the most recent timestamp in this microbatch
        max_row = max(
            (tuple(row) for pdf in rows for row in pdf.itertuples(index=False)),
            key=lambda row: row[1],
        )
        if self.last_seen.exists():
            latest_from_existing = self.last_seen.get()[0]
        else:
            latest_from_existing = datetime.fromtimestamp(0)

        if latest_from_existing < max_row[1]:
            # A newer record arrived: clear existing timers and update state
            for timer in self.handle.listTimers():
                self.handle.deleteTimer(timer)
            self.last_seen.update((max_row[1],))

        # Register a timer that fires if no newer records arrive for this key
        self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)

        timestamp_in_millis = str(timerValues.getCurrentProcessingTimeInMs())

        yield pd.DataFrame({"id": key, "timeValues": timestamp_in_millis})

    def close(self) -> None:
        pass

Scala

import java.sql.Timestamp
import java.time.Duration
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.streaming._

// The (String, Timestamp) schema represents an (id, time). We want to do downtime
// detection on every single unique sensor, where each sensor has a sensor ID.
class DowntimeDetector(duration: Duration) extends
  StatefulProcessor[String, (String, Timestamp), (String, Duration)] {

  @transient private var _lastSeen: ValueState[Timestamp] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    _lastSeen = getHandle.getValueState[Timestamp]("lastSeen", Encoders.TIMESTAMP, TTLConfig.NONE)
  }

  // The logic here is as follows: find the largest timestamp seen so far. Set a timer for
  // the duration later.
  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, Timestamp)],
      timerValues: TimerValues): Iterator[(String, Duration)] = {
    val latestRecordFromNewRows = inputRows.maxBy(_._2.getTime)

    // Use getOrElse to initiate state variable if it doesn't exist
    val latestTimestampFromExistingRows = _lastSeen.getOption().getOrElse(new Timestamp(0))
    val latestTimestampFromNewRows = latestRecordFromNewRows._2

    if (latestTimestampFromNewRows.after(latestTimestampFromExistingRows)) {
      // Cancel the one existing timer, since we have a new latest timestamp.
      // We call "listTimers()" just because we don't know ahead of time what
      // the timestamp of the existing timer is.
      getHandle.listTimers().foreach(timer => getHandle.deleteTimer(timer))

      _lastSeen.update(latestTimestampFromNewRows)
      // Use timerValues to schedule a timer using processing time.
      getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + duration.toMillis)
    } else {
      // No new latest timestamp, so no need to update state or set a timer.
    }

    Iterator.empty
  }

  override def handleExpiredTimer(
    key: String,
    timerValues: TimerValues,
    expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Duration)] = {
      val latestTimestamp = _lastSeen.get()
      val downtimeDuration = Duration.ofMillis(
        timerValues.getCurrentProcessingTimeInMs() - latestTimestamp.getTime)

      // Register another timer that will fire in 10 seconds.
      // Timers can be registered anywhere but init()
      getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 10000)

      Iterator((key, downtimeDuration))
  }
}
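
As with the previous example, the Scala code defines only the processor. The following is a minimal sketch of the wiring. It assumes a streaming DataFrame named df with an id column and a time column, the Spark session implicits in scope, and an arbitrary 30-second downtime threshold; adjust these assumptions for your workload.

import java.time.Duration
import spark.implicits._
import org.apache.spark.sql.streaming.{OutputMode, TimeMode}

// Sketch only: df is an assumed streaming DataFrame with (id, time) rows.
df.as[(String, java.sql.Timestamp)]
  .groupByKey(_._1)
  .transformWithState(
    new DowntimeDetector(Duration.ofSeconds(30)),  // example threshold
    TimeMode.ProcessingTime(),
    OutputMode.Update())
  .writeStream
  // Configure the sink and checkpoint location, then start the query.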

Migrate existing state information

The following example demonstrates how to implement a stateful application that accepts an initial state. You can add initial state handling to any stateful application, but the initial state can only be set when first initializing the application.

This example uses the statestore reader to load existing state information from a checkpoint path. An example use case for this pattern is migrating from legacy stateful applications to transformWithState.

Python

import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from typing import Iterator
spark.conf.set("spark.sql.streaming.stateStore.providerClass","org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

"""
Input schema is as below

input_schema = StructType(
    [StructField("id", StringType(), True)],
    [StructField("value", StringType(), True)]
)
"""

output_schema = StructType([
    StructField("id", StringType(), True),
    StructField("accumulated", StringType(), True)
])

class AccumulatedCounterStatefulProcessorWithInitialState(StatefulProcessor):

    def init(self, handle: StatefulProcessorHandle) -> None:
        state_schema = StructType([StructField("value", IntegerType(), True)])
        self.counter_state = handle.getValueState("counter_state", state_schema)
        self.handle = handle

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        exists = self.counter_state.exists()
        if exists:
            value_row = self.counter_state.get()
            existing_value = value_row[0]
        else:
            existing_value = 0

        accumulated_value = existing_value

        for pdf in rows:
            value = pdf["value"].astype(int).sum()
            accumulated_value += value

        self.counter_state.update((accumulated_value,))

        yield pd.DataFrame({"id": key, "accumulated": str(accumulated_value)})

    def handleInitialState(self, key, initialState, timerValues) -> None:
        init_val = initialState.at[0, "initVal"]
        self.counter_state.update((init_val,))

    def close(self) -> None:
        pass

initial_state = (
    spark.read.format("statestore")
    .option("path", "$checkpointsDir")
    .load()
)

df.groupBy("id")
  .transformWithStateInPandas(
      statefulProcessor=AccumulatedCounterStatefulProcessorWithInitialState(),
      outputStructType=output_schema,
      outputMode="Update",
      timeMode="None",
      initialState=initial_state,
  )
  .writeStream...

Scala

import org.apache.spark.sql.streaming._
import org.apache.spark.sql.{Dataset, Encoder, Encoders, DataFrame}
import org.apache.spark.sql.types._

class InitialStateStatefulProcessor extends StatefulProcessorWithInitialState[String, (String, String, String), (String, String), (String, Int)] {
  @transient protected var valueState: ValueState[Int] = _

  override def init(
      outputMode: OutputMode,
      timeMode: TimeMode): Unit = {
    valueState = getHandle.getValueState[Int]("valueState",
      Encoders.scalaInt, TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, String, String)],
      timerValues: TimerValues): Iterator[(String, String)] = {
    var existingValue = 0
    if (valueState.exists()) {
      existingValue += valueState.get()
    }
    var accumulatedValue = existingValue
    for (row <- inputRows) {
      accumulatedValue += row._2.toInt
    }
    valueState.update(accumulatedValue)
    Iterator((key, accumulatedValue.toString))
  }

  override def handleInitialState(
      key: String, initialState: (String, Int), timerValues: TimerValues): Unit = {
    valueState.update(initialState._2)
  }
}
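
The Scala example defines only the processor. The following is a minimal sketch of how you might pass an initial state when starting the query. The names initialStateDs (a Dataset[(String, Int)] of key and initial value pairs derived from the state store reader output) and df (the streaming input) are assumptions for illustration, not part of the original example.

import spark.implicits._
import org.apache.spark.sql.streaming.{OutputMode, TimeMode}

// Sketch only: group the prepared initial state by key before passing it in.
val initialState = initialStateDs.groupByKey(_._1)

df.as[(String, String, String)]
  .groupByKey(_._1)
  .transformWithState(
    new InitialStateStatefulProcessor(),
    TimeMode.None(),
    OutputMode.Update(),
    initialState)
  .writeStream
  // Configure the sink and checkpoint location, then start the query.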