author     Tathagata Das <tathagata.das1565@gmail.com>    2015-11-12 19:02:49 -0800
committer  Tathagata Das <tathagata.das1565@gmail.com>    2015-11-12 19:02:49 -0800
commit     e4e46b20f6475f8e148d5326f7c88c57850d46a1 (patch)
tree       a9d6c48bf1090d5e25cf32b8dd67edfb66bc158a /streaming
parent     7786f9cc0790d27854a1e184f66a9b4df4d040a2 (diff)
[SPARK-11681][STREAMING] Correctly update state timestamp even when state is not updated
Bug: The state timestamp was not updated when data arrived for a key but the corresponding state was not updated. This is wrong, because timeout is defined as "no data for a while", not "no state update for a while".

Fix: Update the timestamp whenever a timeout is specified; otherwise there is no need to.

Also refactored the code for better testability and added unit tests.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #9648 from tdas/SPARK-11681.
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala      | 105
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala | 136
2 files changed, 192 insertions, 49 deletions
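To make the fixed semantics concrete: once a timeout is configured, a key's timestamp must be refreshed whenever data arrives for it, even if the tracking function leaves the state value untouched. The following is a minimal standalone sketch of that rule in plain Scala; `SimpleStateMap`, `Entry`, and `recordDataSeen` are hypothetical simplifications for illustration, not Spark's private `StateMap`/`StateImpl` API.

import scala.collection.mutable

object TimeoutSemanticsSketch {

  // Hypothetical, simplified stand-in for Spark's private StateMap: a value
  // plus the time at which it was last written.
  case class Entry[S](state: S, updateTime: Long)

  class SimpleStateMap[K, S] {
    private val entries = mutable.Map.empty[K, Entry[S]]
    def get(key: K): Option[S] = entries.get(key).map(_.state)
    def put(key: K, state: S, time: Long): Unit = entries(key) = Entry(state, time)
    def remove(key: K): Unit = entries -= key
    // Keys whose last-write time is strictly older than the timeout threshold.
    def olderThan(threshold: Long): Iterable[(K, S)] =
      entries.collect { case (k, e) if e.updateTime < threshold => (k, e.state) }
  }

  // The essence of SPARK-11681: when a timeout is configured, merely seeing
  // data for a key refreshes its timestamp, even if the tracking function left
  // the state value unchanged. Before the fix, only an explicit update did so.
  def recordDataSeen[K, S](
      map: SimpleStateMap[K, S],
      key: K,
      updatedState: Option[S], // Some(s) iff the tracking function updated the state
      batchTime: Long,
      timeoutConfigured: Boolean): Unit = updatedState match {
    case Some(s) =>
      map.put(key, s, batchTime) // state updated: write new value and timestamp
    case None if timeoutConfigured =>
      map.get(key).foreach(s => map.put(key, s, batchTime)) // the fix: refresh timestamp only
    case None => () // no timeout configured: nothing to record
  }

  def main(args: Array[String]): Unit = {
    val map = new SimpleStateMap[String, Int]
    map.put("k1", 1, 1000L)
    // Data arrives at t = 2000 but the tracking function leaves the state alone:
    recordDataSeen(map, "k1", None, 2000L, timeoutConfigured = true)
    // With threshold 1500, "k1" must NOT time out, because its timestamp moved.
    assert(map.olderThan(1500L).isEmpty)
  }
}

In TrackStateRDD itself, the same effect is achieved by the `|| timeoutThresholdTime.isDefined` branch visible in the diff below.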
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
index fc51496be4..7050378d0f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -32,8 +32,51 @@ import org.apache.spark._
 * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
 * sequence of records returned by the tracking function of `trackStateByKey`.
 */
-private[streaming] case class TrackStateRDDRecord[K, S, T](
-    var stateMap: StateMap[K, S], var emittedRecords: Seq[T])
+private[streaming] case class TrackStateRDDRecord[K, S, E](
+    var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
+
+private[streaming] object TrackStateRDDRecord {
+  def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    prevRecord: Option[TrackStateRDDRecord[K, S, E]],
+    dataIterator: Iterator[(K, V)],
+    updateFunction: (Time, K, Option[V], State[S]) => Option[E],
+    batchTime: Time,
+    timeoutThresholdTime: Option[Long],
+    removeTimedoutData: Boolean
+  ): TrackStateRDDRecord[K, S, E] = {
+    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
+    val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }
+
+    val emittedRecords = new ArrayBuffer[E]
+    val wrappedState = new StateImpl[S]()
+
+    // Call the tracking function on each record in the data iterator, and accordingly
+    // update the states touched, and collect the data returned by the tracking function
+    dataIterator.foreach { case (key, value) =>
+      wrappedState.wrap(newStateMap.get(key))
+      val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
+      if (wrappedState.isRemoved) {
+        newStateMap.remove(key)
+      } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
+        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
+      }
+      emittedRecords ++= emittedRecord
+    }
+
+    // Get the timed out state records, call the tracking function on each and collect the
+    // data returned
+    if (removeTimedoutData && timeoutThresholdTime.isDefined) {
+      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
+        wrappedState.wrapTiminoutState(state)
+        val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
+        emittedRecords ++= emittedRecord
+        newStateMap.remove(key)
+      }
+    }
+
+    TrackStateRDDRecord(newStateMap, emittedRecords)
+  }
+}
/**
 * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
@@ -72,16 +115,16 @@ private[streaming] class TrackStateRDDPartition(
 * @param batchTime The time of the batch to which this RDD belongs to. Use to update
 * @param timeoutThresholdTime The time threshold to indicate which keys have timed out
 */
-private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
-    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
+private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+    private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
    private var partitionedDataRDD: RDD[(K, V)],
-    trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
+    trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
    batchTime: Time,
    timeoutThresholdTime: Option[Long]
-  ) extends RDD[TrackStateRDDRecord[K, S, T]](
+  ) extends RDD[TrackStateRDDRecord[K, S, E]](
    partitionedDataRDD.sparkContext,
    List(
-      new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
+      new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
      new OneToOneDependency(partitionedDataRDD))
  ) {
@@ -98,7 +141,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
  }

  override def compute(
-      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = {
+      partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {
    val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
    val prevStateRDDIterator = prevStateRDD.iterator(
@@ -106,42 +149,16 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
    val dataIterator = partitionedDataRDD.iterator(
      stateRDDPartition.partitionedDataRDDPartition, context)

-    // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
-    val newStateMap = if (prevStateRDDIterator.hasNext) {
-      prevStateRDDIterator.next().stateMap.copy()
-    } else {
-      new EmptyStateMap[K, S]()
-    }
-
-    val emittedRecords = new ArrayBuffer[T]
-    val wrappedState = new StateImpl[S]()
-
-    // Call the tracking function on each record in the data RDD partition, and accordingly
-    // update the states touched, and the data returned by the tracking function.
-    dataIterator.foreach { case (key, value) =>
-      wrappedState.wrap(newStateMap.get(key))
-      val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState)
-      if (wrappedState.isRemoved) {
-        newStateMap.remove(key)
-      } else if (wrappedState.isUpdated) {
-        newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
-      }
-      emittedRecords ++= emittedRecord
-    }
-
-    // If the RDD is expected to be doing a full scan of all the data in the StateMap,
-    // then use this opportunity to filter out those keys that have timed out.
-    // For each of them call the tracking function.
-    if (doFullScan && timeoutThresholdTime.isDefined) {
-      newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
-        wrappedState.wrapTiminoutState(state)
-        val emittedRecord = trackingFunction(batchTime, key, None, wrappedState)
-        emittedRecords ++= emittedRecord
-        newStateMap.remove(key)
-      }
-    }
-
-    Iterator(TrackStateRDDRecord(newStateMap, emittedRecords))
+    val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
+    val newRecord = TrackStateRDDRecord.updateRecordWithData(
+      prevRecord,
+      dataIterator,
+      trackingFunction,
+      batchTime,
+      timeoutThresholdTime,
+      removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
+    )
+    Iterator(newRecord)
  }
  override protected def getPartitions: Array[Partition] = {
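The main payoff of extracting `updateRecordWithData` into the companion object is that the per-partition update logic can now be exercised without constructing RDDs, which the new suite below does. As a rough illustration, a direct invocation might look like the following sketch; it assumes the calling code is compiled inside Spark's org.apache.spark.streaming packages, since TrackStateRDDRecord and the types it uses are private[streaming].

package org.apache.spark.streaming.rdd

import org.apache.spark.streaming.{State, Time}
import org.apache.spark.streaming.util.OpenHashMapBasedStateMap

// Sketch: drive the extracted helper directly; no SparkContext or RDDs needed.
object UpdateRecordSketch {
  def main(args: Array[String]): Unit = {
    val initialMap = new OpenHashMapBasedStateMap[String, Int]()
    initialMap.put("key", 41, 1000L) // state 41, last updated at t = 1000
    val prev = TrackStateRDDRecord[String, Int, Int](initialMap, Seq.empty)

    // Tracking function that increments the state and emits the new value.
    def track(time: Time, key: String, value: Option[String], state: State[Int]): Option[Int] = {
      val next = state.getOption().getOrElse(0) + 1
      state.update(next)
      Some(next)
    }

    val updated = TrackStateRDDRecord.updateRecordWithData(
      Some(prev),
      Iterator(("key", "data")),
      track _,
      Time(2000L),
      timeoutThresholdTime = None,
      removeTimedoutData = false)

    // Both the state value and its timestamp now reflect batch time 2000.
    assert(updated.stateMap.get("key") == Some(42))
    assert(updated.emittedRecords.toList == List(42))
  }
}

This mirrors what the new TrackStateRDDSuite test below does with its `assertRecordUpdate` helper.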
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
index f396b76e8d..19ef5a14f8 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
import org.apache.spark.streaming.{Time, State}
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}
@@ -52,6 +53,131 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
    assert(rdd.partitioner === Some(partitioner))
  }
+ test("updating state and generating emitted data in TrackStateRecord") {
+
+ val initialTime = 1000L
+ val updatedTime = 2000L
+ val thresholdTime = 1500L
+ @volatile var functionCalled = false
+
+ /**
+ * Assert that applying given data on a prior record generates correct updated record, with
+ * correct state map and emitted data
+ */
+ def assertRecordUpdate(
+ initStates: Iterable[Int],
+ data: Iterable[String],
+ expectedStates: Iterable[(Int, Long)],
+ timeoutThreshold: Option[Long] = None,
+ removeTimedoutData: Boolean = false,
+ expectedOutput: Iterable[Int] = None,
+ expectedTimingOutStates: Iterable[Int] = None,
+ expectedRemovedStates: Iterable[Int] = None
+ ): Unit = {
+ val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
+ initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
+ functionCalled = false
+ val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
+ val dataIterator = data.map { v => ("key", v) }.iterator
+ val removedStates = new ArrayBuffer[Int]
+ val timingOutStates = new ArrayBuffer[Int]
+ /**
+ * Tracking function that updates/removes state based on instructions in the data, and
+ * return state (when instructed or when state is timing out).
+ */
+ def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
+ functionCalled = true
+
+ assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
+
+ data match {
+ case Some("noop") =>
+ None
+ case Some("get-state") =>
+ Some(state.getOption().getOrElse(-1))
+ case Some("update-state") =>
+ if (state.exists) state.update(state.get + 1) else state.update(0)
+ None
+ case Some("remove-state") =>
+ removedStates += state.get()
+ state.remove()
+ None
+ case None =>
+ assert(state.isTimingOut() === true, "State is not timing out when data = None")
+ timingOutStates += state.get()
+ None
+ case _ =>
+ fail("Unexpected test data")
+ }
+ }
+
+ val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
+ Some(record), dataIterator, testFunc,
+ Time(updatedTime), timeoutThreshold, removeTimedoutData)
+
+ val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
+ assert(updatedStateData.toSet === expectedStates.toSet,
+ "states do not match after updating the TrackStateRecord")
+
+ assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
+ "emitted data do not match after updating the TrackStateRecord")
+
+ assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
+ "match those that were expected to do so while updating the TrackStateRecord")
+
+ assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
+ "match those that were expected to do so while updating the TrackStateRecord")
+
+ }
+
+ // No data, no state should be changed, function should not be called,
+ assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
+ assert(functionCalled === false)
+ assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
+ assert(functionCalled === false)
+
+ // Data present, function should be called irrespective of whether state exists
+ assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
+ expectedStates = Seq((0, initialTime)))
+ assert(functionCalled === true)
+ assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
+ assert(functionCalled === true)
+
+ // Function called with right state data
+ assertRecordUpdate(initStates = None, data = Seq("get-state"),
+ expectedStates = None, expectedOutput = Seq(-1))
+ assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
+ expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))
+
+ // Update state and timestamp, when timeout not present
+ assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
+ expectedStates = Seq((0, updatedTime)))
+ assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
+ expectedStates = Seq((1, updatedTime)))
+
+ // Remove state
+ assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
+ expectedStates = Nil, expectedRemovedStates = Seq(345))
+
+ // State strictly older than timeout threshold should be timed out
+ assertRecordUpdate(initStates = Seq(123), data = Nil,
+ timeoutThreshold = Some(initialTime), removeTimedoutData = true,
+ expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)
+
+ assertRecordUpdate(initStates = Seq(123), data = Nil,
+ timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+ expectedStates = Nil, expectedTimingOutStates = Seq(123))
+
+ // State should not be timed out after it has received data
+ assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
+ timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+ expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
+ assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
+ timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
+ expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))
+
+ }
+
test("states generated by TrackStateRDD") {
val initStates = Seq(("k1", 0), ("k2", 0))
val initTime = 123
@@ -148,9 +274,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
    val rdd7 = testStateUpdates( // should remove k2's state
      rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))

-    val rdd8 = testStateUpdates(
-      rdd7, Seq(("k3", 2)), Set() //
-    )
+    val rdd8 = testStateUpdates( // should remove k3's state
+      rdd7, Seq(("k3", 2)), Set())
  }
  /** Assert whether the `trackStateByKey` operation generates expected results */
@@ -176,7 +301,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
    // Persist to make sure that it gets computed only once and we can track precisely how many
    // state keys the computing touched
-    newStateRDD.persist()
+    newStateRDD.persist().count()
    assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
    newStateRDD
  }
@@ -188,7 +313,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
      expectedEmittedRecords: Set[T]): Unit = {
    val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
    val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
-    assert(states === expectedStates, "states after track state operation were not as expected")
+    assert(states === expectedStates,
+      "states after track state operation were not as expected")
    assert(emittedRecords === expectedEmittedRecords,
      "emitted records after track state operation were not as expected")
  }