Example stateful applications
This article contains code examples for custom stateful applications. Databricks recommends using built-in stateful methods for common operations such as aggregations and joins.
The patterns in this article use the `transformWithState` operator and associated classes, which are available in Public Preview in Databricks Runtime 16.2 and above. See Build a custom stateful application.
Note
Python uses the `transformWithStateInPandas` operator to provide the same functionality. The examples below provide code in both Python and Scala.
Requirements
The `transformWithState` operator and the related APIs and classes have the following requirements:
- Available in Databricks Runtime 16.2 and above.
- Compute must use dedicated or no-isolation access mode.
- You must use the RocksDB state store provider. Databricks recommends enabling RocksDB as part of the compute configuration.
Note
To enable the RocksDB state store provider for the current session, run the following:
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
Slowly changing dimension (SCD) type 1
The following code is an example of implementing SCD type 1 using `transformWithState`. SCD type 1 only tracks the most recent value for a given field.
Note
You can use streaming tables and `APPLY CHANGES INTO` to implement SCD type 1 or type 2 using Delta Lake-backed tables. This example implements SCD type 1 in the state store, which provides lower latency for near real-time applications.
Python
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator
# transformWithState requires the RocksDB state store provider.
spark.conf.set(
    "spark.sql.streaming.stateStore.providerClass",
    "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider",
)

# Rows emitted by the processor: a user together with the most recent
# (time, location) pair observed for that user.
output_schema = StructType(
    [
        StructField(field_name, field_type, True)
        for field_name, field_type in (
            ("user", StringType()),
            ("time", LongType()),
            ("location", StringType()),
        )
    ]
)
class SCDType1StatefulProcessor(StatefulProcessor):
    """SCD type 1 processor: per user key, keep only the record with the newest `time`.

    The stored record is replaced, and the new record emitted, only when an incoming
    row is strictly newer than what is already in state.
    """

    def init(self, handle: StatefulProcessorHandle) -> None:
        # The persisted state has the same (user, time, location) layout as the output.
        location_schema = StructType(
            [
                StructField("user", StringType(), True),
                StructField("time", LongType(), True),
                StructField("location", StringType(), True),
            ]
        )
        self.latest_location = handle.getValueState("latestLocation", location_schema)

    def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
        # Scan every batch; a later row wins only on a strictly larger time,
        # so the first row holding the maximum time is kept.
        newest, newest_time = None, float('-inf')
        for batch in rows:
            for _, record in batch.iterrows():
                candidate_time = record["time"]
                if candidate_time > newest_time:
                    newest, newest_time = tuple(record), candidate_time

        # `or` short-circuits, so state is only read when it exists.
        supersedes = (
            not self.latest_location.exists()
            or newest[1] > self.latest_location.get()[1]
        )
        if supersedes:
            self.latest_location.update(newest)
            yield pd.DataFrame(
                {"user": (newest[0],), "time": (newest[1],), "location": (newest[2],)}
            )
        yield pd.DataFrame()

    def close(self) -> None:
        pass
# Wire the processor into a streaming query. Grouping by "user" makes the user the
# state key. NOTE(review): assumes `df` is a streaming DataFrame with user/time/location
# columns matching output_schema — confirm against the actual source.
(df.groupBy("user")
.transformWithStateInPandas(
statefulProcessor=SCDType1StatefulProcessor(),
outputStructType=output_schema,
outputMode="Update",
# The processor registers no timers, so no time mode is needed.
timeMode="None",
)
# `.writeStream...` is a placeholder: configure the sink and checkpoint, then start the query.
.writeStream...
)
Scala
/** One location observation: `user` was at `location` at `time` (units as supplied by the input). */
case class UserLocation(user: String, time: Long, location: String)
/**
 * SCD type 1 processor keyed by user: keeps only the newest [[UserLocation]] per key
 * and emits it whenever it replaces the stored record.
 */
class SCDType1StatefulProcessor extends StatefulProcessor[String, UserLocation, UserLocation] {
  import org.apache.spark.sql.Encoders

  // Most recent location stored for the current user key.
  @transient private var _latestLocation: ValueState[UserLocation] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    val locationEncoder = Encoders.product[UserLocation]
    _latestLocation =
      getHandle.getValueState[UserLocation]("locationState", locationEncoder, TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[UserLocation],
      timerValues: TimerValues): Iterator[UserLocation] = {
    val newestIncoming = inputRows.maxBy(_.time)
    // Replace state only when nothing is stored yet or the incoming record is strictly newer.
    val supersedes = _latestLocation.getOption().forall(stored => newestIncoming.time > stored.time)
    if (!supersedes) {
      Iterator.empty
    } else {
      _latestLocation.update(newestIncoming)
      Iterator.single(newestIncoming)
    }
  }
}
Downtime detector
`transformWithState` implements timers that let you take action based on elapsed time, even if no records for a given key are processed in a microbatch.
The following example implements a pattern for a downtime detector. Each time a new value is seen for a given key, it updates the `lastSeen` state value, clears any existing timers, and sets a new timer for the future.
When a timer expires, the application emits the elapsed time since the last observed event for the key. It then sets a new timer to emit an update 10 seconds later.
Python
import datetime
import time
class DownTimeDetectorStatefulProcessor(StatefulProcessor):
    """Tracks the newest event timestamp per key and reports downtime via processing-time timers.

    When a batch advances a key's newest timestamp, any pending timers for the key are
    deleted and a new one is registered 5 seconds (processing time) later. When a timer
    fires, the processor emits the milliseconds elapsed since the newest observed event
    and re-arms a timer 10 seconds later.
    """

    def init(self, handle: StatefulProcessorHandle) -> None:
        # TimestampType is not among this example's imports, so import it where it is used.
        from pyspark.sql.types import StructField, StructType, TimestampType

        state_schema = StructType([StructField("value", TimestampType(), True)])
        self.handle = handle
        self.last_seen = handle.getValueState("last_seen", state_schema)

    def handleExpiredTimer(self, key, timerValues, expiredTimerInfo) -> Iterator[pd.DataFrame]:
        """Emit the time elapsed since the last observed event and schedule the next report.

        Yields a one-row frame with columns `id` (the key) and `timeValues` (downtime in
        milliseconds, as a string).
        """
        # The state row has a single field: the newest event timestamp.
        last_seen_ts = self.last_seen.get()[0]
        now_ms = timerValues.getCurrentProcessingTimeInMs()
        # Downtime is measured from the last observed event, not from wall-clock "now".
        downtime_duration = now_ms - int(last_seen_ts.timestamp() * 1000)
        self.handle.registerTimer(now_ms + 10000)
        yield pd.DataFrame({"id": key, "timeValues": str(downtime_duration)})

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        """Record the newest timestamp (second column) and reset the timer when it advances.

        Yields a one-row frame with columns `id` (the key) and `timeValues` (the current
        processing time in milliseconds, as a string).
        """
        # Consider every row of every batch, not just the first row of each batch.
        max_row = max(
            (tuple(row) for pdf in rows for row in pdf.itertuples(index=False)),
            key=lambda row: row[1],
        )

        if self.last_seen.exists():
            latest_from_existing = self.last_seen.get()[0]
        else:
            # Epoch 0 stands in for "never seen".
            latest_from_existing = datetime.datetime.fromtimestamp(0)

        if latest_from_existing < max_row[1]:
            # Expiry times of pending timers aren't known here, so enumerate and drop them all.
            for timer in self.handle.listTimers():
                self.handle.deleteTimer(timer)
            self.last_seen.update((max_row[1],))
            self.handle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + 5000)

        timestamp_in_millis = str(timerValues.getCurrentProcessingTimeInMs())
        yield pd.DataFrame({"id": key, "timeValues": timestamp_in_millis})

    def close(self) -> None:
        pass
Scala
import java.sql.Timestamp
import org.apache.spark.sql.Encoders
// Input rows are (id, time) pairs keyed by sensor ID, so downtime is tracked
// independently for every unique sensor.
class DowntimeDetector(duration: Duration) extends
    StatefulProcessor[String, (String, Timestamp), (String, Duration)] {

  // Newest event timestamp recorded for the current sensor.
  @transient private var _lastSeen: ValueState[Timestamp] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit =
    _lastSeen = getHandle.getValueState[Timestamp]("lastSeen", Encoders.TIMESTAMP, TTLConfig.NONE)

  /**
   * Records the newest timestamp in `inputRows`. When it is later than the stored one,
   * replaces any pending timer with one that fires `duration` from now (processing time).
   * Never emits rows itself; output only comes from expired timers.
   */
  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, Timestamp)],
      timerValues: TimerValues): Iterator[(String, Duration)] = {
    val (_, newestIncoming) = inputRows.maxBy { case (_, ts) => ts.getTime }
    // Epoch 0 stands in for "never seen", so any later timestamp advances the state.
    val newestStored = _lastSeen.getOption().getOrElse(new Timestamp(0))

    if (newestIncoming.after(newestStored)) {
      // The expiry time of the pending timer isn't known here, so enumerate and drop all of them.
      getHandle.listTimers().foreach(expiry => getHandle.deleteTimer(expiry))
      _lastSeen.update(newestIncoming)
      getHandle.registerTimer(timerValues.getCurrentProcessingTimeInMs() + duration.toMillis)
    }
    Iterator.empty
  }

  /** Emits (key, time since the last observed event) and re-arms a timer 10 seconds out. */
  override def handleExpiredTimer(
      key: String,
      timerValues: TimerValues,
      expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, Duration)] = {
    val nowMs = timerValues.getCurrentProcessingTimeInMs()
    val elapsedMs = nowMs - _lastSeen.get().getTime
    // Keep reporting while the sensor stays silent. Timers can be registered anywhere but init().
    getHandle.registerTimer(nowMs + 10000)
    Iterator((key, new Duration(elapsedMs)))
  }
}
Migrate existing state information
The following example demonstrates how to implement a stateful application that accepts an initial state. You can add initial state handling to any stateful application, but the initial state can only be set when first initializing the application.
This example uses the `statestore` reader to load existing state information from a checkpoint path. An example use case for this pattern is migrating from legacy stateful applications to `transformWithState`.
Python
import pandas as pd
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator
# transformWithState requires the RocksDB state store provider.
spark.conf.set("spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")

# Expected input schema: a string `id` key and a string `value` holding an integer
# (the processor parses it with `.astype(int)`):
#
# input_schema = StructType([
#     StructField("id", StringType(), True),
#     StructField("value", StringType(), True),
# ])

# Output: the key and its running total, rendered as a string.
output_schema = StructType([
    StructField("id", StringType(), True),
    StructField("accumulated", StringType(), True),
])
class AccumulatedCounterStatefulProcessorWithInitialState(StatefulProcessor):
    """Keeps a running integer total of the `value` column per key.

    The total can be seeded per key from the initial state (column `initVal`). Each
    batch of input adds to the total and emits it as a string.
    """

    def init(self, handle: StatefulProcessorHandle) -> None:
        # IntegerType is not among this example's module-level imports.
        from pyspark.sql.types import IntegerType

        state_schema = StructType([StructField("value", IntegerType(), True)])
        self.counter_state = handle.getValueState("counter_state", state_schema)
        self.handle = handle

    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
        """Add every row's `value` to the stored total, persist it, and emit (id, accumulated)."""
        if self.counter_state.exists():
            existing_value = self.counter_state.get()[0]
        else:
            existing_value = 0

        accumulated_value = existing_value
        for pdf in rows:
            # `value` arrives as strings; convert the numpy sum to a plain int for state storage.
            accumulated_value += int(pdf["value"].astype(int).sum())

        self.counter_state.update((accumulated_value,))
        yield pd.DataFrame({"id": key, "accumulated": str(accumulated_value)})

    def handleInitialState(self, key, initialState, timerValues) -> None:
        """Seed the total for `key` from the first row's `initVal` column."""
        init_val = int(initialState.at[0, "initVal"])
        self.counter_state.update((init_val,))

    def close(self) -> None:
        pass
initial_state = spark.read.format("statestore")
.option("path", "$checkpointsDir")
.load()
df.groupBy("id")
.transformWithStateInPandas(
statefulProcessor=AccumulatedCounterStatefulProcessorWithInitialState(),
outputStructType=output_schema,
outputMode="Update",
timeMode="None",
initialState=initial_state,
)
.writeStream...
Scala
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.{Dataset, Encoder, Encoders , DataFrame}
import org.apache.spark.sql.types._
/**
 * Keeps a running integer total per key, summed from the second field of each input row.
 * The total can be seeded from an initial state of (key, value) pairs.
 */
class InitialStateStatefulProcessor extends StatefulProcessorWithInitialState[String, (String, String, String), (String, String), (String, Int)] {
  // Running total for the current key.
  @transient protected var valueState: ValueState[Int] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    valueState = getHandle.getValueState[Int]("valueState", Encoders.scalaInt, TTLConfig.NONE)
  }

  /**
   * Adds each row's second field (parsed as Int) to the stored total, persists it, and
   * emits (key, total as a string).
   */
  override def handleInputRows(
      key: String,
      inputRows: Iterator[(String, String, String)],
      timerValues: TimerValues): Iterator[(String, String)] = {
    val startingTotal = if (valueState.exists()) valueState.get() else 0
    val total = inputRows.foldLeft(startingTotal) { case (acc, (_, amount, _)) => acc + amount.toInt }
    valueState.update(total)
    Iterator(key -> total.toString)
  }

  /** Overwrites the stored total for `key` with the Int carried in the initial-state tuple. */
  override def handleInitialState(
      key: String, initialState: (String, Int), timerValues: TimerValues): Unit =
    valueState.update(initialState._2)
}