From e4e46b20f6475f8e148d5326f7c88c57850d46a1 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 12 Nov 2015 19:02:49 -0800 Subject: [SPARK-11681][STREAMING] Correctly update state timestamp even when state is not updated Bug: Timestamp is not updated if there is data but the corresponding state is not updated. This is wrong, because timeout is defined as "no data for a while", not "no state update for a while". Fix: Update timestamp when timeout is specified, otherwise no need. Also refactored the code for better testability and added unit tests. Author: Tathagata Das Closes #9648 from tdas/SPARK-11681. --- .../apache/spark/streaming/rdd/TrackStateRDD.scala | 105 +++++++++------- .../spark/streaming/rdd/TrackStateRDDSuite.scala | 136 ++++++++++++++++++++- 2 files changed, 192 insertions(+), 49 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index fc51496be4..7050378d0f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -32,8 +32,51 @@ import org.apache.spark._ * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a * sequence of records returned by the tracking function of `trackStateByKey`. 
*/ -private[streaming] case class TrackStateRDDRecord[K, S, T]( - var stateMap: StateMap[K, S], var emittedRecords: Seq[T]) +private[streaming] case class TrackStateRDDRecord[K, S, E]( + var stateMap: StateMap[K, S], var emittedRecords: Seq[E]) + +private[streaming] object TrackStateRDDRecord { + def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + prevRecord: Option[TrackStateRDDRecord[K, S, E]], + dataIterator: Iterator[(K, V)], + updateFunction: (Time, K, Option[V], State[S]) => Option[E], + batchTime: Time, + timeoutThresholdTime: Option[Long], + removeTimedoutData: Boolean + ): TrackStateRDDRecord[K, S, E] = { + // Create a new state map by cloning the previous one (if it exists) or by creating an empty one + val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() } + + val emittedRecords = new ArrayBuffer[E] + val wrappedState = new StateImpl[S]() + + // Call the tracking function on each record in the data iterator, and accordingly + // update the states touched, and collect the data returned by the tracking function + dataIterator.foreach { case (key, value) => + wrappedState.wrap(newStateMap.get(key)) + val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState) + if (wrappedState.isRemoved) { + newStateMap.remove(key) + } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) { + newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) + } + emittedRecords ++= emittedRecord + } + + // Get the timed out state records, call the tracking function on each and collect the + // data returned + if (removeTimedoutData && timeoutThresholdTime.isDefined) { + newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => + wrappedState.wrapTiminoutState(state) + val emittedRecord = updateFunction(batchTime, key, None, wrappedState) + emittedRecords ++= emittedRecord + newStateMap.remove(key) + } + } + + TrackStateRDDRecord(newStateMap, 
emittedRecords) + } +} /** * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state @@ -72,16 +115,16 @@ private[streaming] class TrackStateRDDPartition( * @param batchTime The time of the batch to which this RDD belongs to. Use to update * @param timeoutThresholdTime The time to indicate which keys are timeout */ -private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]], +private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]], private var partitionedDataRDD: RDD[(K, V)], - trackingFunction: (Time, K, Option[V], State[S]) => Option[T], + trackingFunction: (Time, K, Option[V], State[S]) => Option[E], batchTime: Time, timeoutThresholdTime: Option[Long] - ) extends RDD[TrackStateRDDRecord[K, S, T]]( + ) extends RDD[TrackStateRDDRecord[K, S, E]]( partitionedDataRDD.sparkContext, List( - new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD), + new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD), new OneToOneDependency(partitionedDataRDD)) ) { @@ -98,7 +141,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: } override def compute( - partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = { + partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = { val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition] val prevStateRDDIterator = prevStateRDD.iterator( @@ -106,42 +149,16 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: val dataIterator = partitionedDataRDD.iterator( stateRDDPartition.partitionedDataRDDPartition, context) - // Create a new state map by cloning the previous one (if it exists) or by creating an empty one - val newStateMap = if (prevStateRDDIterator.hasNext) { 
- prevStateRDDIterator.next().stateMap.copy() - } else { - new EmptyStateMap[K, S]() - } - - val emittedRecords = new ArrayBuffer[T] - val wrappedState = new StateImpl[S]() - - // Call the tracking function on each record in the data RDD partition, and accordingly - // update the states touched, and the data returned by the tracking function. - dataIterator.foreach { case (key, value) => - wrappedState.wrap(newStateMap.get(key)) - val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState) - if (wrappedState.isRemoved) { - newStateMap.remove(key) - } else if (wrappedState.isUpdated) { - newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) - } - emittedRecords ++= emittedRecord - } - - // If the RDD is expected to be doing a full scan of all the data in the StateMap, - // then use this opportunity to filter out those keys that have timed out. - // For each of them call the tracking function. - if (doFullScan && timeoutThresholdTime.isDefined) { - newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => - wrappedState.wrapTiminoutState(state) - val emittedRecord = trackingFunction(batchTime, key, None, wrappedState) - emittedRecords ++= emittedRecord - newStateMap.remove(key) - } - } - - Iterator(TrackStateRDDRecord(newStateMap, emittedRecords)) + val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None + val newRecord = TrackStateRDDRecord.updateRecordWithData( + prevRecord, + dataIterator, + trackingFunction, + batchTime, + timeoutThresholdTime, + removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled + ) + Iterator(newRecord) } override protected def getPartitions: Array[Partition] = { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala index f396b76e8d..19ef5a14f8 100644 --- 
a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap import org.apache.spark.streaming.{Time, State} import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite} @@ -52,6 +53,131 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { assert(rdd.partitioner === Some(partitioner)) } + test("updating state and generating emitted data in TrackStateRecord") { + + val initialTime = 1000L + val updatedTime = 2000L + val thresholdTime = 1500L + @volatile var functionCalled = false + + /** + * Assert that applying given data on a prior record generates correct updated record, with + * correct state map and emitted data + */ + def assertRecordUpdate( + initStates: Iterable[Int], + data: Iterable[String], + expectedStates: Iterable[(Int, Long)], + timeoutThreshold: Option[Long] = None, + removeTimedoutData: Boolean = false, + expectedOutput: Iterable[Int] = None, + expectedTimingOutStates: Iterable[Int] = None, + expectedRemovedStates: Iterable[Int] = None + ): Unit = { + val initialStateMap = new OpenHashMapBasedStateMap[String, Int]() + initStates.foreach { s => initialStateMap.put("key", s, initialTime) } + functionCalled = false + val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty) + val dataIterator = data.map { v => ("key", v) }.iterator + val removedStates = new ArrayBuffer[Int] + val timingOutStates = new ArrayBuffer[Int] + /** + * Tracking function that updates/removes state based on instructions in the data, and + * return state (when instructed or when state is timing out). 
+ */ + def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = { + functionCalled = true + + assert(t.milliseconds === updatedTime, "tracking func called with wrong time") + + data match { + case Some("noop") => + None + case Some("get-state") => + Some(state.getOption().getOrElse(-1)) + case Some("update-state") => + if (state.exists) state.update(state.get + 1) else state.update(0) + None + case Some("remove-state") => + removedStates += state.get() + state.remove() + None + case None => + assert(state.isTimingOut() === true, "State is not timing out when data = None") + timingOutStates += state.get() + None + case _ => + fail("Unexpected test data") + } + } + + val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int]( + Some(record), dataIterator, testFunc, + Time(updatedTime), timeoutThreshold, removeTimedoutData) + + val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) } + assert(updatedStateData.toSet === expectedStates.toSet, + "states do not match after updating the TrackStateRecord") + + assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet, + "emitted data do not match after updating the TrackStateRecord") + + assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " + + "match those that were expected to do so while updating the TrackStateRecord") + + assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " + + "match those that were expected to do so while updating the TrackStateRecord") + + } + + // No data, no state should be changed, function should not be called, + assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil) + assert(functionCalled === false) + assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime))) + assert(functionCalled === false) + + // Data present, function should be called irrespective of whether state exists + 
assertRecordUpdate(initStates = Seq(0), data = Seq("noop"), + expectedStates = Seq((0, initialTime))) + assert(functionCalled === true) + assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None) + assert(functionCalled === true) + + // Function called with right state data + assertRecordUpdate(initStates = None, data = Seq("get-state"), + expectedStates = None, expectedOutput = Seq(-1)) + assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"), + expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123)) + + // Update state and timestamp, when timeout not present + assertRecordUpdate(initStates = Nil, data = Seq("update-state"), + expectedStates = Seq((0, updatedTime))) + assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"), + expectedStates = Seq((1, updatedTime))) + + // Remove state + assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"), + expectedStates = Nil, expectedRemovedStates = Seq(345)) + + // State strictly older than timeout threshold should be timed out + assertRecordUpdate(initStates = Seq(123), data = Nil, + timeoutThreshold = Some(initialTime), removeTimedoutData = true, + expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil) + + assertRecordUpdate(initStates = Seq(123), data = Nil, + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Seq(123)) + + // State should not be timed out after it has received data + assertRecordUpdate(initStates = Seq(123), data = Seq("noop"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil) + assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"), + timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true, + expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123)) + + } + test("states generated by TrackStateRDD") 
{ val initStates = Seq(("k1", 0), ("k2", 0)) val initTime = 123 @@ -148,9 +274,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { val rdd7 = testStateUpdates( // should remove k2's state rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime))) - val rdd8 = testStateUpdates( - rdd7, Seq(("k3", 2)), Set() // - ) + val rdd8 = testStateUpdates( // should remove k3's state + rdd7, Seq(("k3", 2)), Set()) } /** Assert whether the `trackStateByKey` operation generates expected results */ @@ -176,7 +301,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { // Persist to make sure that it gets computed only once and we can track precisely how many // state keys the computing touched - newStateRDD.persist() + newStateRDD.persist().count() assertRDD(newStateRDD, expectedStates, expectedEmittedRecords) newStateRDD } @@ -188,7 +313,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll { expectedEmittedRecords: Set[T]): Unit = { val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet - assert(states === expectedStates, "states after track state operation were not as expected") + assert(states === expectedStates, + "states after track state operation were not as expected") assert(emittedRecords === expectedEmittedRecords, "emitted records after track state operation were not as expected") } -- cgit v1.2.3