aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/test
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2017-03-19 14:07:49 -0700
committerTathagata Das <tathagata.das1565@gmail.com>2017-03-19 14:07:49 -0700
commit990af630d0d569880edd9c7ce9932e10037a28ab (patch)
tree3c25483808ca877693f42d3d10ebd49987e86645 /sql/core/src/test
parent0ee9fbf51ac863e015d57ae7824a39bd3b36141a (diff)
downloadspark-990af630d0d569880edd9c7ce9932e10037a28ab.tar.gz
spark-990af630d0d569880edd9c7ce9932e10037a28ab.tar.bz2
spark-990af630d0d569880edd9c7ce9932e10037a28ab.zip
[SPARK-19067][SS] Processing-time-based timeout in MapGroupsWithState
## What changes were proposed in this pull request? When a key does not get any new data in `mapGroupsWithState`, the mapping function is never called on it. So we need a timeout feature that calls the function again in such cases, so that the user can decide whether to continue waiting or clean up (remove state, save stuff externally, etc.). Timeouts can be either based on processing time or event time. This JIRA is for processing time, but defines the high level API design for both. The usage would look like this. ``` def stateFunction(key: K, value: Iterator[V], state: KeyedState[S]): U = { ... state.setTimeoutDuration(10000) ... } dataset // type is Dataset[T] .groupByKey[K](keyingFunc) // generates KeyValueGroupedDataset[K, T] .mapGroupsWithState[S, U]( func = stateFunction, timeout = KeyedStateTimeout.withProcessingTime) // returns Dataset[U] ``` Note the following design aspects. - The timeout type is provided as a param in mapGroupsWithState as a parameter global to all the keys. This is so that the planner knows this at planning time, and accordingly optimize the execution based on whether to saves extra info in state or not (e.g. timeout durations or timestamps). - The exact timeout duration is provided inside the function call so that it can be customized on a per key basis. - When the timeout occurs for a key, the function is called with no values, and KeyedState.isTimingOut() set to true. - The timeout is reset for key every time the function is called on the key, that is, when the key has new data, or the key has timed out. So the user has to set the timeout duration everytime the function is called, otherwise there will not be any timeout set. Guarantees provided on timeout of key, when timeout duration is D ms: - Timeout will never be called before real clock time has advanced by D ms - Timeout will be called eventually when there is a trigger with any data in it (i.e. after D ms). So there is a no strict upper bound on when the timeout would occur. For example, if there is no data in the stream (for any key) for a while, then the timeout will not be hit. Implementation details: - Added new param to `mapGroupsWithState` for timeout - Added new method to `StateStore` to filter data based on timeout timestamp - Changed the internal map type of `HDFSBackedStateStore` from Java's `HashMap` to `ConcurrentHashMap` as the latter allows weakly-consistent fail-safe iterators on the map data. See comments in code for more details. - Refactored logic of `MapGroupsWithStateExec` to - Save timeout info to state store for each key that has data. - Then, filter states that should be timed out based on the current batch processing timestamp. - Moved KeyedState for `o.a.s.sql` to `o.a.s.sql.streaming`. I remember that this was a feedback in the MapGroupsWithState PR that I had forgotten to address. ## How was this patch tested? New unit tests in - MapGroupsWithStateSuite for timeouts. - StateStoreSuite for new APIs in StateStore. Author: Tathagata Das <tathagata.das1565@gmail.com> Closes #17179 from tdas/mapgroupwithstate-timeout.
Diffstat (limited to 'sql/core/src/test')
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala24
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala546
3 files changed, 524 insertions, 50 deletions
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 439cac3dfb..ca9e5ad2ea 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,7 @@ import java.sql.Date;
import java.sql.Timestamp;
import java.util.*;
+import org.apache.spark.sql.streaming.KeyedStateTimeout;
import org.apache.spark.sql.streaming.OutputMode;
import scala.Tuple2;
import scala.Tuple3;
@@ -208,7 +209,8 @@ public class JavaDatasetSuite implements Serializable {
},
OutputMode.Append(),
Encoders.LONG(),
- Encoders.STRING());
+ Encoders.STRING(),
+ KeyedStateTimeout.NoTimeout());
Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList()));
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index e848f74e31..ebb7422765 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -123,6 +123,30 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth
assert(getDataFromFiles(provider, version = 2) === Set("b" -> 2, "c" -> 4))
}
+ test("filter and concurrent updates") {
+ val provider = newStoreProvider()
+
+ // Verify state before starting a new set of updates
+ assert(provider.latestIterator.isEmpty)
+ val store = provider.getStore(0)
+ put(store, "a", 1)
+ put(store, "b", 2)
+
+ // Updates should work while iterating of filtered entries
+ val filtered = store.filter { case (keyRow, _) => rowToString(keyRow) == "a" }
+ filtered.foreach { case (keyRow, valueRow) =>
+ store.put(keyRow, intToRow(rowToInt(valueRow) + 1))
+ }
+ assert(get(store, "a") === Some(2))
+
+ // Removes should work while iterating of filtered entries
+ val filtered2 = store.filter { case (keyRow, _) => rowToString(keyRow) == "b" }
+ filtered2.foreach { case (keyRow, _) =>
+ store.remove(keyRow)
+ }
+ assert(get(store, "b") === None)
+ }
+
test("updates iterator with all combos of updates and removes") {
val provider = newStoreProvider()
var currentVersion: Int = 0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 902b842e97..7daa5e6a0f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -17,20 +17,33 @@
package org.apache.spark.sql.streaming
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkException
-import org.apache.spark.sql.KeyedState
+import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
+import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream}
-import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.execution.RDDScanExec
+import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, KeyedStateImpl, MemoryStream}
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StoreUpdate}
+import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore
+import org.apache.spark.sql.types.{DataType, IntegerType}
/** Class to check custom state types */
case class RunningCount(count: Long)
+case class Result(key: Long, count: Int)
+
class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
import testImplicits._
+ import KeyedStateImpl._
override def afterAll(): Unit = {
super.afterAll()
@@ -54,8 +67,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
}
assert(state.getOption === expectedData)
- assert(state.isUpdated === shouldBeUpdated)
- assert(state.isRemoved === shouldBeRemoved)
+ assert(state.hasUpdated === shouldBeUpdated)
+ assert(state.hasRemoved === shouldBeRemoved)
}
// Updating empty state
@@ -83,6 +96,79 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
}
+ test("KeyedState - setTimeoutDuration, hasTimedOut") {
+ import KeyedStateImpl._
+ var state: KeyedStateImpl[Int] = null
+
+ // When isTimeoutEnabled = false, then setTimeoutDuration() is not allowed
+ for (initState <- Seq(None, Some(5))) {
+ // for different initial state
+ state = new KeyedStateImpl(initState, 1000, isTimeoutEnabled = false, hasTimedOut = false)
+ assert(state.hasTimedOut === false)
+ assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+ intercept[UnsupportedOperationException] {
+ state.setTimeoutDuration(1000)
+ }
+ intercept[UnsupportedOperationException] {
+ state.setTimeoutDuration("1 day")
+ }
+ assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+ }
+
+ def testTimeoutNotAllowed(): Unit = {
+ intercept[IllegalStateException] {
+ state.setTimeoutDuration(1000)
+ }
+ assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+ intercept[IllegalStateException] {
+ state.setTimeoutDuration("2 second")
+ }
+ assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+ }
+
+ // When isTimeoutEnabled = true, then setTimeoutDuration() is not allowed until the
+ // state is be defined
+ state = new KeyedStateImpl(None, 1000, isTimeoutEnabled = true, hasTimedOut = false)
+ assert(state.hasTimedOut === false)
+ assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+ testTimeoutNotAllowed()
+
+ // After state has been set, setTimeoutDuration() is allowed, and
+ // getTimeoutTimestamp returned correct timestamp
+ state.update(5)
+ assert(state.hasTimedOut === false)
+ assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+ state.setTimeoutDuration(1000)
+ assert(state.getTimeoutTimestamp === 2000)
+ state.setTimeoutDuration("2 second")
+ assert(state.getTimeoutTimestamp === 3000)
+ assert(state.hasTimedOut === false)
+
+ // setTimeoutDuration() with negative values or 0 is not allowed
+ def testIllegalTimeout(body: => Unit): Unit = {
+ intercept[IllegalArgumentException] { body }
+ assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+ }
+ state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false)
+ testIllegalTimeout { state.setTimeoutDuration(-1000) }
+ testIllegalTimeout { state.setTimeoutDuration(0) }
+ testIllegalTimeout { state.setTimeoutDuration("-2 second") }
+ testIllegalTimeout { state.setTimeoutDuration("-1 month") }
+ testIllegalTimeout { state.setTimeoutDuration("1 month -1 day") }
+
+ // Test remove() clear timeout timestamp, and setTimeoutDuration() is not allowed after that
+ state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = false)
+ state.remove()
+ assert(state.hasTimedOut === false)
+ assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+ testTimeoutNotAllowed()
+
+ // Test hasTimedOut = true
+ state = new KeyedStateImpl(Some(5), 1000, isTimeoutEnabled = true, hasTimedOut = true)
+ assert(state.hasTimedOut === true)
+ assert(state.getTimeoutTimestamp === TIMEOUT_TIMESTAMP_NOT_SET)
+ }
+
test("KeyedState - primitive type") {
var intState = new KeyedStateImpl[Int](None)
intercept[NoSuchElementException] {
@@ -100,6 +186,151 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
}
}
+ // Values used for testing StateStoreUpdater
+ val currentTimestamp = 1000
+ val beforeCurrentTimestamp = 999
+ val afterCurrentTimestamp = 1001
+
+ // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is disabled
+ for (priorState <- Seq(None, Some(0))) {
+ val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state"
+ val testName = s"timeout disabled - $priorStateStr - "
+
+ testStateUpdateWithData(
+ testName + "no update",
+ stateUpdates = state => { /* do nothing */ },
+ timeoutType = KeyedStateTimeout.NoTimeout,
+ priorState = priorState,
+ expectedState = priorState) // should not change
+
+ testStateUpdateWithData(
+ testName + "state updated",
+ stateUpdates = state => { state.update(5) },
+ timeoutType = KeyedStateTimeout.NoTimeout,
+ priorState = priorState,
+ expectedState = Some(5)) // should change
+
+ testStateUpdateWithData(
+ testName + "state removed",
+ stateUpdates = state => { state.remove() },
+ timeoutType = KeyedStateTimeout.NoTimeout,
+ priorState = priorState,
+ expectedState = None) // should be removed
+ }
+
+ // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout is enabled
+ for (priorState <- Seq(None, Some(0))) {
+ for (priorTimeoutTimestamp <- Seq(TIMEOUT_TIMESTAMP_NOT_SET, 1000)) {
+ var testName = s"timeout enabled - "
+ if (priorState.nonEmpty) {
+ testName += "prior state set, "
+ if (priorTimeoutTimestamp == 1000) {
+ testName += "prior timeout set - "
+ } else {
+ testName += "no prior timeout - "
+ }
+ } else {
+ testName += "no prior state - "
+ }
+
+ testStateUpdateWithData(
+ testName + "no update",
+ stateUpdates = state => { /* do nothing */ },
+ timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
+ priorState = priorState,
+ priorTimeoutTimestamp = priorTimeoutTimestamp,
+ expectedState = priorState, // state should not change
+ expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset
+
+ testStateUpdateWithData(
+ testName + "state updated",
+ stateUpdates = state => { state.update(5) },
+ timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
+ priorState = priorState,
+ priorTimeoutTimestamp = priorTimeoutTimestamp,
+ expectedState = Some(5), // state should change
+ expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset
+
+ testStateUpdateWithData(
+ testName + "state removed",
+ stateUpdates = state => { state.remove() },
+ timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
+ priorState = priorState,
+ priorTimeoutTimestamp = priorTimeoutTimestamp,
+ expectedState = None) // state should be removed
+
+ testStateUpdateWithData(
+ testName + "timeout and state updated",
+ stateUpdates = state => { state.update(5); state.setTimeoutDuration(5000) },
+ timeoutType = KeyedStateTimeout.ProcessingTimeTimeout,
+ priorState = priorState,
+ priorTimeoutTimestamp = priorTimeoutTimestamp,
+ expectedState = Some(5), // state should change
+ expectedTimeoutTimestamp = currentTimestamp + 5000) // timestamp should change
+ }
+ }
+
+ // Tests for StateStoreUpdater.updateStateForTimedOutKeys()
+ val preTimeoutState = Some(5)
+
+ testStateUpdateWithTimeout(
+ "should not timeout",
+ stateUpdates = state => { assert(false, "function called without timeout") },
+ priorTimeoutTimestamp = afterCurrentTimestamp,
+ expectedState = preTimeoutState, // state should not change
+ expectedTimeoutTimestamp = afterCurrentTimestamp) // timestamp should not change
+
+ testStateUpdateWithTimeout(
+ "should timeout - no update/remove",
+ stateUpdates = state => { /* do nothing */ },
+ priorTimeoutTimestamp = beforeCurrentTimestamp,
+ expectedState = preTimeoutState, // state should not change
+ expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset
+
+ testStateUpdateWithTimeout(
+ "should timeout - update state",
+ stateUpdates = state => { state.update(5) },
+ priorTimeoutTimestamp = beforeCurrentTimestamp,
+ expectedState = Some(5), // state should change
+ expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET) // timestamp should be reset
+
+ testStateUpdateWithTimeout(
+ "should timeout - remove state",
+ stateUpdates = state => { state.remove() },
+ priorTimeoutTimestamp = beforeCurrentTimestamp,
+ expectedState = None, // state should be removed
+ expectedTimeoutTimestamp = TIMEOUT_TIMESTAMP_NOT_SET)
+
+ testStateUpdateWithTimeout(
+ "should timeout - timeout updated",
+ stateUpdates = state => { state.setTimeoutDuration(2000) },
+ priorTimeoutTimestamp = beforeCurrentTimestamp,
+ expectedState = preTimeoutState, // state should not change
+ expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change
+
+ testStateUpdateWithTimeout(
+ "should timeout - timeout and state updated",
+ stateUpdates = state => { state.update(5); state.setTimeoutDuration(2000) },
+ priorTimeoutTimestamp = beforeCurrentTimestamp,
+ expectedState = Some(5), // state should change
+ expectedTimeoutTimestamp = currentTimestamp + 2000) // timestamp should change
+
+ test("StateStoreUpdater - rows are cloned before writing to StateStore") {
+ // function for running count
+ val func = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => {
+ state.update(state.getOption.getOrElse(0) + values.size)
+ Iterator.empty
+ }
+ val store = newStateStore()
+ val plan = newFlatMapGroupsWithStateExec(func)
+ val updater = new plan.StateStoreUpdater(store)
+ val data = Seq(1, 1, 2)
+ val returnIter = updater.updateStateForKeysWithData(data.iterator.map(intToRow))
+ returnIter.size // consume the iterator to force store updates
+ val storeData = store.iterator.map { case (k, v) => (rowToInt(k), rowToInt(v)) }.toSet
+ assert(storeData === Set((1, 2), (2, 1)))
+ }
+
test("flatMapGroupsWithState - streaming") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count if state is defined, otherwise does not return anything
@@ -119,7 +350,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val result =
inputData.toDS()
.groupByKey(x => x)
- .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str)
+ .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc)
testStream(result, Update)(
AddData(inputData, "a"),
@@ -162,8 +393,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val result =
inputData.toDS()
.groupByKey(x => x)
- .flatMapGroupsWithState(stateFunc, Update) // State: Int, Out: (Str, Str)
-
+ .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc)
testStream(result, Update)(
AddData(inputData, "a", "a", "b"),
CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
@@ -178,59 +408,118 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
)
}
+ test("flatMapGroupsWithState - streaming + aggregation") {
+ // Function to maintain running count up to 2, and then remove the count
+ // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
+ if (count == 3) {
+ state.remove()
+ Iterator(key -> "-1")
+ } else {
+ state.update(RunningCount(count))
+ Iterator(key -> count.toString)
+ }
+ }
+
+ val inputData = MemoryStream[String]
+ val result =
+ inputData.toDS()
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(Append, KeyedStateTimeout.NoTimeout)(stateFunc)
+ .groupByKey(_._1)
+ .count()
+
+ testStream(result, Complete)(
+ AddData(inputData, "a"),
+ CheckLastBatch(("a", 1)),
+ AddData(inputData, "a", "b"),
+ // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
+ CheckLastBatch(("a", 2), ("b", 1)),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "b"),
+ // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ;
+ // so increment a and b by 1
+ CheckLastBatch(("a", 3), ("b", 2)),
+ StopStream,
+ StartStream(),
+ AddData(inputData, "a", "c"),
+ // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ;
+ // so increment a and c by 1
+ CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
+ )
+ }
+
test("flatMapGroupsWithState - batch") {
// Function that returns running count only if its even, otherwise does not return
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
if (state.exists) throw new IllegalArgumentException("state.exists should be false")
Iterator((key, values.size))
}
- checkAnswer(
- Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc, Update).toDF,
- Seq(("a", 2), ("b", 1)).toDF)
+ val df = Seq("a", "a", "b").toDS
+ .groupByKey(x => x)
+ .flatMapGroupsWithState(Update, KeyedStateTimeout.NoTimeout)(stateFunc).toDF
+ checkAnswer(df, Seq(("a", 2), ("b", 1)).toDF)
}
- test("mapGroupsWithState - streaming") {
+ test("flatMapGroupsWithState - streaming with processing time timeout") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-
- val count = state.getOption.map(_.count).getOrElse(0L) + values.size
- if (count == 3) {
+ if (state.hasTimedOut) {
state.remove()
- (key, "-1")
+ Iterator((key, "-1"))
} else {
+ val count = state.getOption.map(_.count).getOrElse(0L) + values.size
state.update(RunningCount(count))
- (key, count.toString)
+ state.setTimeoutDuration("10 seconds")
+ Iterator((key, count.toString))
}
}
+ val clock = new StreamManualClock
val inputData = MemoryStream[String]
+ val timeout = KeyedStateTimeout.ProcessingTimeTimeout
val result =
inputData.toDS()
.groupByKey(x => x)
- .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
+ .flatMapGroupsWithState(Update, timeout)(stateFunc)
testStream(result, Update)(
+ StartStream(ProcessingTime("1 second"), triggerClock = clock),
AddData(inputData, "a"),
+ AdvanceManualClock(1 * 1000),
CheckLastBatch(("a", "1")),
assertNumStateRows(total = 1, updated = 1),
- AddData(inputData, "a", "b"),
- CheckLastBatch(("a", "2"), ("b", "1")),
- assertNumStateRows(total = 2, updated = 2),
- StopStream,
- StartStream(),
- AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1
+
+ AddData(inputData, "b"),
+ AdvanceManualClock(1 * 1000),
+ CheckLastBatch(("b", "1")),
+ assertNumStateRows(total = 2, updated = 1),
+
+ AddData(inputData, "b"),
+ AdvanceManualClock(10 * 1000),
CheckLastBatch(("a", "-1"), ("b", "2")),
assertNumStateRows(total = 1, updated = 2),
+
StopStream,
- StartStream(),
- AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
- CheckLastBatch(("a", "1"), ("c", "1")),
- assertNumStateRows(total = 3, updated = 2)
+ StartStream(ProcessingTime("1 second"), triggerClock = clock),
+
+ AddData(inputData, "c"),
+ AdvanceManualClock(20 * 1000),
+ CheckLastBatch(("b", "-1"), ("c", "1")),
+ assertNumStateRows(total = 1, updated = 2),
+
+ AddData(inputData, "c"),
+ AdvanceManualClock(20 * 1000),
+ CheckLastBatch(("c", "2")),
+ assertNumStateRows(total = 1, updated = 1)
)
}
- test("flatMapGroupsWithState - streaming + aggregation") {
+ test("mapGroupsWithState - streaming") {
// Function to maintain running count up to 2, and then remove the count
// Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
@@ -238,10 +527,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val count = state.getOption.map(_.count).getOrElse(0L) + values.size
if (count == 3) {
state.remove()
- Iterator(key -> "-1")
+ (key, "-1")
} else {
state.update(RunningCount(count))
- Iterator(key -> count.toString)
+ (key, count.toString)
}
}
@@ -249,28 +538,25 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
val result =
inputData.toDS()
.groupByKey(x => x)
- .flatMapGroupsWithState(stateFunc, Append) // Types = State: MyState, Out: (Str, Str)
- .groupByKey(_._1)
- .count()
+ .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str)
- testStream(result, Complete)(
+ testStream(result, Update)(
AddData(inputData, "a"),
- CheckLastBatch(("a", 1)),
+ CheckLastBatch(("a", "1")),
+ assertNumStateRows(total = 1, updated = 1),
AddData(inputData, "a", "b"),
- // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1
- CheckLastBatch(("a", 2), ("b", 1)),
+ CheckLastBatch(("a", "2"), ("b", "1")),
+ assertNumStateRows(total = 2, updated = 2),
StopStream,
StartStream(),
- AddData(inputData, "a", "b"),
- // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ;
- // so increment a and b by 1
- CheckLastBatch(("a", 3), ("b", 2)),
+ AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1
+ CheckLastBatch(("a", "-1"), ("b", "2")),
+ assertNumStateRows(total = 1, updated = 2),
StopStream,
StartStream(),
- AddData(inputData, "a", "c"),
- // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ;
- // so increment a and c by 1
- CheckLastBatch(("a", 4), ("b", 2), ("c", 1))
+ AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+ CheckLastBatch(("a", "1"), ("c", "1")),
+ assertNumStateRows(total = 3, updated = 2)
)
}
@@ -322,23 +608,185 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf
)
}
+ test("output partitioning is unknown") {
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => key
+ val inputData = MemoryStream[String]
+ val result = inputData.toDS.groupByKey(x => x).mapGroupsWithState(stateFunc)
+ result
+ testStream(result, Update)(
+ AddData(inputData, "a"),
+ CheckLastBatch("a"),
+ AssertOnQuery(_.lastExecution.executedPlan.outputPartitioning === UnknownPartitioning(0))
+ )
+ }
+
test("disallow complete mode") {
- val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+ val stateFunc = (key: String, values: Iterator[String], state: KeyedState[Int]) => {
Iterator[String]()
}
var e = intercept[IllegalArgumentException] {
- MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, Complete)
+ MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(
+ OutputMode.Complete, KeyedStateTimeout.NoTimeout)(stateFunc)
}
assert(e.getMessage === "The output mode of function should be append or update")
+ val javaStateFunc = new FlatMapGroupsWithStateFunction[String, String, Int, String] {
+ import java.util.{Iterator => JIterator}
+ override def call(
+ key: String,
+ values: JIterator[String],
+ state: KeyedState[Int]): JIterator[String] = { null }
+ }
e = intercept[IllegalArgumentException] {
- MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(stateFunc, "complete")
+ MemoryStream[String].toDS().groupByKey(x => x).flatMapGroupsWithState(
+ javaStateFunc, OutputMode.Complete,
+ implicitly[Encoder[Int]], implicitly[Encoder[String]], KeyedStateTimeout.NoTimeout)
}
assert(e.getMessage === "The output mode of function should be append or update")
}
+
+ def testStateUpdateWithData(
+ testName: String,
+ stateUpdates: KeyedState[Int] => Unit,
+ timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout,
+ priorState: Option[Int],
+ priorTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET,
+ expectedState: Option[Int] = None,
+ expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = {
+
+ if (priorState.isEmpty && priorTimeoutTimestamp != TIMEOUT_TIMESTAMP_NOT_SET) {
+ return // there can be no prior timestamp, when there is no prior state
+ }
+ test(s"StateStoreUpdater - updates with data - $testName") {
+ val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => {
+ assert(state.hasTimedOut === false, "hasTimedOut not false")
+ assert(values.nonEmpty, "Some value is expected")
+ stateUpdates(state)
+ Iterator.empty
+ }
+ testStateUpdate(
+ testTimeoutUpdates = false, mapGroupsFunc, timeoutType,
+ priorState, priorTimeoutTimestamp, expectedState, expectedTimeoutTimestamp)
+ }
+ }
+
+ def testStateUpdateWithTimeout(
+ testName: String,
+ stateUpdates: KeyedState[Int] => Unit,
+ priorTimeoutTimestamp: Long,
+ expectedState: Option[Int],
+ expectedTimeoutTimestamp: Long = TIMEOUT_TIMESTAMP_NOT_SET): Unit = {
+
+ test(s"StateStoreUpdater - updates for timeout - $testName") {
+ val mapGroupsFunc = (key: Int, values: Iterator[Int], state: KeyedState[Int]) => {
+ assert(state.hasTimedOut === true, "hasTimedOut not true")
+ assert(values.isEmpty, "values not empty")
+ stateUpdates(state)
+ Iterator.empty
+ }
+ testStateUpdate(
+ testTimeoutUpdates = true, mapGroupsFunc, KeyedStateTimeout.ProcessingTimeTimeout,
+ preTimeoutState, priorTimeoutTimestamp,
+ expectedState, expectedTimeoutTimestamp)
+ }
+ }
+
+ def testStateUpdate(
+ testTimeoutUpdates: Boolean,
+ mapGroupsFunc: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int],
+ timeoutType: KeyedStateTimeout,
+ priorState: Option[Int],
+ priorTimeoutTimestamp: Long,
+ expectedState: Option[Int],
+ expectedTimeoutTimestamp: Long): Unit = {
+
+ val store = newStateStore()
+ val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec(
+ mapGroupsFunc, timeoutType, currentTimestamp)
+ val updater = new mapGroupsSparkPlan.StateStoreUpdater(store)
+ val key = intToRow(0)
+ // Prepare store with prior state configs
+ if (priorState.nonEmpty) {
+ val row = updater.getStateRow(priorState.get)
+ updater.setTimeoutTimestamp(row, priorTimeoutTimestamp)
+ store.put(key.copy(), row.copy())
+ }
+
+ // Call updating function to update state store
+ val returnedIter = if (testTimeoutUpdates) {
+ updater.updateStateForTimedOutKeys()
+ } else {
+ updater.updateStateForKeysWithData(Iterator(key))
+ }
+ returnedIter.size // consumer the iterator to force state updates
+
+ // Verify updated state in store
+ val updatedStateRow = store.get(key)
+ assert(
+ updater.getStateObj(updatedStateRow).map(_.toString.toInt) === expectedState,
+ "final state not as expected")
+ if (updatedStateRow.nonEmpty) {
+ assert(
+ updater.getTimeoutTimestamp(updatedStateRow.get) === expectedTimeoutTimestamp,
+ "final timeout timestamp not as expected")
+ }
+ }
+
+ def newFlatMapGroupsWithStateExec(
+ func: (Int, Iterator[Int], KeyedState[Int]) => Iterator[Int],
+ timeoutType: KeyedStateTimeout = KeyedStateTimeout.NoTimeout,
+ batchTimestampMs: Long = NO_BATCH_PROCESSING_TIMESTAMP): FlatMapGroupsWithStateExec = {
+ MemoryStream[Int]
+ .toDS
+ .groupByKey(x => x)
+ .flatMapGroupsWithState[Int, Int](Append, timeoutConf = timeoutType)(func)
+ .logicalPlan.collectFirst {
+ case FlatMapGroupsWithState(f, k, v, g, d, o, s, m, _, t, _) =>
+ FlatMapGroupsWithStateExec(
+ f, k, v, g, d, o, None, s, m, t, currentTimestamp,
+ RDDScanExec(g, null, "rdd"))
+ }.get
+ }
+
+ def newStateStore(): StateStore = new MemoryStateStore()
+
+ val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
+ def intToRow(i: Int): UnsafeRow = {
+ intProj.apply(new GenericInternalRow(Array[Any](i))).copy()
+ }
+
+ def rowToInt(row: UnsafeRow): Int = row.getInt(0)
}
object FlatMapGroupsWithStateSuite {
+
var failInTask = true
+
+ class MemoryStateStore extends StateStore() {
+ import scala.collection.JavaConverters._
+ private val map = new ConcurrentHashMap[UnsafeRow, UnsafeRow]
+
+ override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = {
+ map.entrySet.iterator.asScala.map { case e => (e.getKey, e.getValue) }
+ }
+
+ override def filter(c: (UnsafeRow, UnsafeRow) => Boolean): Iterator[(UnsafeRow, UnsafeRow)] = {
+ iterator.filter { case (k, v) => c(k, v) }
+ }
+
+ override def get(key: UnsafeRow): Option[UnsafeRow] = Option(map.get(key))
+ override def put(key: UnsafeRow, newValue: UnsafeRow): Unit = map.put(key, newValue)
+ override def remove(key: UnsafeRow): Unit = { map.remove(key) }
+ override def remove(condition: (UnsafeRow) => Boolean): Unit = {
+ iterator.map(_._1).filter(condition).foreach(map.remove)
+ }
+ override def commit(): Long = version + 1
+ override def abort(): Unit = { }
+ override def id: StateStoreId = null
+ override def version: Long = 0
+ override def updates(): Iterator[StoreUpdate] = { throw new UnsupportedOperationException }
+ override def numKeys(): Long = map.size
+ override def hasCommitted: Boolean = true
+ }
}