From 6f9e598ccf92f6272bbfb56ac56d3101387131b9 Mon Sep 17 00:00:00 2001
From: Aaditya Ramesh
Date: Tue, 15 Nov 2016 13:01:01 -0800
Subject: [SPARK-13027][STREAMING] Added batch time as a parameter to updateStateByKey

Added RDD batch time as an input parameter to the update function in updateStateByKey.

Author: Aaditya Ramesh

Closes #11122 from aramesh117/SPARK-13027.
---
 .../streaming/dstream/PairDStreamFunctions.scala | 40 ++++++++++---
 .../spark/streaming/dstream/StateDStream.scala   | 28 ++++-----
 .../spark/streaming/BasicOperationsSuite.scala   | 66 ++++++++++++++++++++++
 .../spark/streaming/DStreamClosureSuite.scala    | 12 ++++
 4 files changed, 126 insertions(+), 20 deletions(-)

(limited to 'streaming/src')

diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 2f2a6d13dd..ac739411fd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -453,9 +453,12 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
   def updateStateByKey[S: ClassTag](
       updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
       partitioner: Partitioner,
-      rememberPartitioner: Boolean
-    ): DStream[(K, S)] = ssc.withScope {
-    new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
+      rememberPartitioner: Boolean): DStream[(K, S)] = ssc.withScope {
+    val cleanedFunc = ssc.sc.clean(updateFunc)
+    val newUpdateFunc = (_: Time, it: Iterator[(K, Seq[V], Option[S])]) => {
+      cleanedFunc(it)
+    }
+    new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, None)
   }
 
   /**
@@ -499,10 +502,33 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
       updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
       partitioner: Partitioner,
       rememberPartitioner: Boolean,
-      initialRDD: RDD[(K, S)]
-    ): DStream[(K, S)] = ssc.withScope {
-    new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
-      rememberPartitioner, Some(initialRDD))
+      initialRDD: RDD[(K, S)]): DStream[(K, S)] = ssc.withScope {
+    val cleanedFunc = ssc.sc.clean(updateFunc)
+    val newUpdateFunc = (_: Time, it: Iterator[(K, Seq[V], Option[S])]) => {
+      cleanedFunc(it)
+    }
+    new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, Some(initialRDD))
+  }
+
+  /**
+   * Return a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of the key.
+   * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+   * @param updateFunc State update function. If `this` function returns None, then
+   *                   corresponding state key-value pair will be eliminated.
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+   *                    DStream.
+   * @tparam S State type
+   */
+  def updateStateByKey[S: ClassTag](updateFunc: (Time, K, Seq[V], Option[S]) => Option[S],
+      partitioner: Partitioner,
+      rememberPartitioner: Boolean,
+      initialRDD: Option[RDD[(K, S)]] = None): DStream[(K, S)] = ssc.withScope {
+    val cleanedFunc = ssc.sc.clean(updateFunc)
+    val newUpdateFunc = (time: Time, iterator: Iterator[(K, Seq[V], Option[S])]) => {
+      iterator.flatMap(t => cleanedFunc(time, t._1, t._2, t._3).map(s => (t._1, s)))
+    }
+    new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, initialRDD)
   }
 
   /**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index 8efb09a8ce..5bf1dabf08 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -27,7 +27,7 @@ import org.apache.spark.streaming.{Duration, Time}
 private[streaming]
 class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
     parent: DStream[(K, V)],
-    updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+    updateFunc: (Time, Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
     partitioner: Partitioner,
     preservePartitioning: Boolean,
     initialRDD: Option[RDD[(K, S)]]
@@ -41,8 +41,10 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
 
   override val mustCheckpoint = true
 
-  private [this] def computeUsingPreviousRDD (
-    parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
+  private [this] def computeUsingPreviousRDD(
+      batchTime: Time,
+      parentRDD: RDD[(K, V)],
+      prevStateRDD: RDD[(K, S)]) = {
     // Define the function for the mapPartition operation on cogrouped RDD;
     // first map the cogrouped tuple to tuples of required type,
     // and then apply the update function
@@ -53,7 +55,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
         val headOption = if (itr.hasNext) Some(itr.next()) else None
         (t._1, t._2._1.toSeq, headOption)
       }
-      updateFuncLocal(i)
+      updateFuncLocal(batchTime, i)
     }
     val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
     val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
@@ -68,15 +70,14 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
       case Some(prevStateRDD) =>    // If previous state RDD exists
         // Try to get the parent RDD
         parent.getOrCompute(validTime) match {
-          case Some(parentRDD) =>    // If parent RDD exists, then compute as usual
-            computeUsingPreviousRDD(parentRDD, prevStateRDD)
-          case None =>    // If parent RDD does not exist
-
+          case Some(parentRDD) =>    // If parent RDD exists, then compute as usual
+            computeUsingPreviousRDD (validTime, parentRDD, prevStateRDD)
+          case None =>    // If parent RDD does not exist
             // Re-apply the update function to the old state RDD
             val updateFuncLocal = updateFunc
             val finalFunc = (iterator: Iterator[(K, S)]) => {
               val i = iterator.map(t => (t._1, Seq[V](), Option(t._2)))
-              updateFuncLocal(i)
+              updateFuncLocal(validTime, i)
             }
             val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning)
             Some(stateRDD)
@@ -93,15 +94,16 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
                 // and then apply the update function
                 val updateFuncLocal = updateFunc
                 val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
-                  updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
+                  updateFuncLocal (validTime,
+                    iterator.map (tuple => (tuple._1, tuple._2.toSeq, None)))
                 }
                 val groupedRDD = parentRDD.groupByKey(partitioner)
                 val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
                 // logDebug("Generating state RDD for time " + validTime + " (first)")
-                Some(sessionRDD)
-              case Some(initialStateRDD) =>
-                computeUsingPreviousRDD(parentRDD, initialStateRDD)
+                Some (sessionRDD)
+              case Some (initialStateRDD) =>
+                computeUsingPreviousRDD(validTime, parentRDD, initialStateRDD)
             }
           case None => // If parent RDD does not exist, then nothing to do!
             // logDebug("Not generating state RDD (no previous state, no parent)")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index cfcbdc7c38..4e702bbb92 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -471,6 +471,72 @@ class BasicOperationsSuite extends TestSuiteBase {
     testOperation(inputData, updateStateOperation, outputData, true)
   }
 
+  test("updateStateByKey - testing time stamps as input") {
+    type StreamingState = Long
+    val initial: Seq[(String, StreamingState)] = Seq(("a", 0L), ("c", 0L))
+
+    val inputData =
+      Seq(
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    // a -> 1000, 3000, 6000, 10000, 15000, 15000
+    // b -> 0, 2000, 5000, 9000, 9000, 9000
+    // c -> 1000, 1000, 3000, 3000, 3000, 3000
+
+    val outputData: Seq[Seq[(String, StreamingState)]] = Seq(
+      Seq(
+        ("a", 1000L),
+        ("c", 0L)), // t = 1000
+      Seq(
+        ("a", 3000L),
+        ("b", 2000L),
+        ("c", 0L)), // t = 2000
+      Seq(
+        ("a", 6000L),
+        ("b", 5000L),
+        ("c", 3000L)), // t = 3000
+      Seq(
+        ("a", 10000L),
+        ("b", 9000L),
+        ("c", 3000L)), // t = 4000
+      Seq(
+        ("a", 15000L),
+        ("b", 9000L),
+        ("c", 3000L)), // t = 5000
+      Seq(
+        ("a", 15000L),
+        ("b", 9000L),
+        ("c", 3000L)) // t = 6000
+    )
+
+    val updateStateOperation = (s: DStream[String]) => {
+      val initialRDD = s.context.sparkContext.makeRDD(initial)
+      val updateFunc = (time: Time,
+                        key: String,
+                        values: Seq[Int],
+                        state: Option[StreamingState]) => {
+        // Update only if we receive values for this key during the batch.
+        if (values.nonEmpty) {
+          Option(time.milliseconds + state.getOrElse(0L))
+        } else {
+          Option(state.getOrElse(0L))
+        }
+      }
+      s.map(x => (x, 1)).updateStateByKey[StreamingState](updateFunc = updateFunc,
+        partitioner = new HashPartitioner (numInputPartitions), rememberPartitioner = false,
+        initialRDD = Option(initialRDD))
+    }
+
+    testOperation(input = inputData, operation = updateStateOperation,
+      expectedOutput = outputData, useSet = true)
+  }
+
   test("updateStateByKey - with initial value RDD") {
     val initial = Seq(("a", 1), ("c", 2))
 
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
index 1fc34f569f..2ab600ab81 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
@@ -164,6 +164,10 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
   private def testUpdateStateByKey(ds: DStream[(Int, Int)]): Unit = {
     val updateF1 = (_: Seq[Int], _: Option[Int]) => { return; Some(1) }
     val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; Seq((1, 1)).toIterator }
+    val updateF3 = (_: Time, _: Int, _: Seq[Int], _: Option[Int]) => {
+      return
+      Option(1)
+    }
     val initialRDD = ds.ssc.sparkContext.emptyRDD[Int].map { i => (i, i) }
     expectCorrectException { ds.updateStateByKey(updateF1) }
     expectCorrectException { ds.updateStateByKey(updateF1, 5) }
@@ -177,6 +181,14 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
     expectCorrectException {
       ds.updateStateByKey(updateF2, new HashPartitioner(5), true, initialRDD)
     }
+    expectCorrectException {
+      ds.updateStateByKey(
+        updateFunc = updateF3,
+        partitioner = new HashPartitioner(5),
+        rememberPartitioner = true,
+        initialRDD = Option(initialRDD)
+      )
+    }
   }
 
   private def testMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException {
     ds.mapValues { _ => return; 1 }
-- 
cgit v1.2.3
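
Usage sketch (not part of the patch): the overload added by this commit passes the batch Time as the first argument of the per-key update function, so state can record the time of the batch that last touched it. The sketch below shows one way an application might call it; the StreamingContext setup, the socket source on localhost:9999, the WordState class, and the checkpoint directory are illustrative assumptions, not taken from the commit.

import org.apache.spark.HashPartitioner
import org.apache.spark.streaming.{Seconds, StreamingContext, Time}

object BatchTimeStateExample {

  // Hypothetical state per word: running count plus the time of the batch that last changed it.
  case class WordState(count: Long, lastUpdated: Time)

  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext("local[2]", "BatchTimeStateExample", Seconds(1))
    ssc.checkpoint("/tmp/batch-time-state")  // updateStateByKey requires checkpointing

    // The batch time arrives as the first argument of the update function.
    val updateFunc = (batchTime: Time, word: String, counts: Seq[Int], state: Option[WordState]) => {
      if (counts.isEmpty) {
        state  // no new data for this key in this batch: keep the old state
      } else {
        val previous = state.map(_.count).getOrElse(0L)
        Some(WordState(previous + counts.sum, batchTime))
      }
    }

    val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" "))
    val stateStream = words
      .map(word => (word, 1))
      .updateStateByKey[WordState](
        updateFunc = updateFunc,
        partitioner = new HashPartitioner(ssc.sparkContext.defaultParallelism),
        rememberPartitioner = false,
        initialRDD = None)

    stateStream.print()
    ssc.start()
    ssc.awaitTermination()
  }
}

Returning None from the update function drops the key from the state, matching the behaviour described in the scaladoc of the new overload above.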