From 36ddeb7bf83ac5a1af9d3db07ad4c380777e4d1a Mon Sep 17 00:00:00 2001
From: Soumitra Kumar
Date: Wed, 12 Nov 2014 12:25:31 -0800
Subject: [SPARK-3660][STREAMING] Initial RDD for updateStateByKey transformation

SPARK-3660 : Initial RDD for updateStateByKey transformation

I have added a sample StatefulNetworkWordCountWithInitial inspired by
StatefulNetworkWordCount. Please let me know if any changes are required.

Author: Soumitra Kumar

Closes #2665 from soumitrak/master and squashes the following commits:

ee8980b [Soumitra Kumar] Fixed copy/paste issue.
304f636 [Soumitra Kumar] Added simpler version of updateStateByKey API with initialRDD and test.
9781135 [Soumitra Kumar] Fixed test, and renamed variable.
3da51a2 [Soumitra Kumar] Adding updateStateByKey with initialRDD API to JavaPairDStream.
2f78f7e [Soumitra Kumar] Merge remote-tracking branch 'upstream/master'
d4fdd18 [Soumitra Kumar] Renamed variable and moved method.
d0ce2cd [Soumitra Kumar] Merge remote-tracking branch 'upstream/master'
31399a4 [Soumitra Kumar] Merge remote-tracking branch 'upstream/master'
4efa58b [Soumitra Kumar] [SPARK-3660][STREAMING] Initial RDD for updateStateByKey transformation
8f40ca0 [Soumitra Kumar] Merge remote-tracking branch 'upstream/master'
dde4271 [Soumitra Kumar] Merge remote-tracking branch 'upstream/master'
fdd7db3 [Soumitra Kumar] Adding support of initial value for state update.
SPARK-3660 : Initial RDD for updateStateByKey transformation
---
 .../spark/streaming/api/java/JavaPairDStream.scala | 19 ++++++
 .../streaming/dstream/PairDStreamFunctions.scala   | 49 +++++++++++++-
 .../spark/streaming/dstream/StateDStream.scala     | 70 +++++++++++---------
 .../org/apache/spark/streaming/JavaAPISuite.java   | 53 ++++++++++++++--
 .../spark/streaming/BasicOperationsSuite.scala     | 74 ++++++++++++++++++++++
 5 files changed, 229 insertions(+), 36 deletions(-)
(limited to 'streaming')

diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 59d4423086..bb44b906d7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -492,6 +492,25 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
     dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner)
   }
 
+  /**
+   * Return a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of the key.
+   * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+   * @param updateFunc State update function. If this function returns None, then the
+   *                   corresponding state key-value pair will be eliminated.
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+   *                    DStream.
+   * @param initialRDD Initial state value of each key.
+   * @tparam S State type
+   */
+  def updateStateByKey[S](
+      updateFunc: JFunction2[JList[V], Optional[S], Optional[S]],
+      partitioner: Partitioner,
+      initialRDD: JavaPairRDD[K, S]
+    ): JavaPairDStream[K, S] = {
+    implicit val cm: ClassTag[S] = fakeClassTag
+    dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner, initialRDD)
+  }
 
   /**
    * Return a new DStream by applying a map function to the value of each key-value pairs in
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 9467595d30..b39f47f04a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -413,7 +413,54 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
       partitioner: Partitioner,
       rememberPartitioner: Boolean
     ): DStream[(K, S)] = {
-     new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
+     new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
+  }
+
+  /**
+   * Return a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of the key.
+   * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+   * @param updateFunc State update function. If this function returns None, then the
+   *                   corresponding state key-value pair will be eliminated.
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+   *                    DStream.
+   * @param initialRDD Initial state value of each key.
+   * @tparam S State type
+   */
+  def updateStateByKey[S: ClassTag](
+      updateFunc: (Seq[V], Option[S]) => Option[S],
+      partitioner: Partitioner,
+      initialRDD: RDD[(K, S)]
+    ): DStream[(K, S)] = {
+    val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
+      iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+    }
+    updateStateByKey(newUpdateFunc, partitioner, true, initialRDD)
+  }
+
+  /**
+   * Return a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of each key.
+   * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+   * @param updateFunc State update function. If this function returns None, then the
+   *                   corresponding state key-value pair will be eliminated. Note that
+   *                   this function may generate a tuple with a different key than the
+   *                   input key. It is up to the developer to decide whether to
+   *                   remember the partitioner despite the key being changed.
+   * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+   *                    DStream
+   * @param rememberPartitioner Whether to remember the partitioner object in the generated RDDs.
+   * @param initialRDD Initial state value of each key.
+   * @tparam S State type
+   */
+  def updateStateByKey[S: ClassTag](
+      updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+      partitioner: Partitioner,
+      rememberPartitioner: Boolean,
+      initialRDD: RDD[(K, S)]
+    ): DStream[(K, S)] = {
+     new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
+       rememberPartitioner, Some(initialRDD))
   }
 
   /**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index 7e22268767..ebb04dd35b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -30,7 +30,8 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
     parent: DStream[(K, V)],
     updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
     partitioner: Partitioner,
-    preservePartitioning: Boolean
+    preservePartitioning: Boolean,
+    initialRDD: Option[RDD[(K, S)]]
   ) extends DStream[(K, S)](parent.ssc) {
 
   super.persist(StorageLevel.MEMORY_ONLY_SER)
@@ -41,6 +42,25 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
 
   override val mustCheckpoint = true
 
+  private[this] def computeUsingPreviousRDD(
+      parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
+    // Define the function for the mapPartition operation on cogrouped RDD;
+    // first map the cogrouped tuple to tuples of required type,
+    // and then apply the update function
+    val updateFuncLocal = updateFunc
+    val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
+      val i = iterator.map(t => {
+        val itr = t._2._2.iterator
+        val headOption = if (itr.hasNext) Some(itr.next) else None
+        (t._1, t._2._1.toSeq, headOption)
+      })
+      updateFuncLocal(i)
+    }
+    val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
+    val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
+    Some(stateRDD)
+  }
+
   override def compute(validTime: Time): Option[RDD[(K, S)]] = {
 
     // Try to get the previous state RDD
@@ -51,25 +71,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
         // Try to get the parent RDD
         parent.getOrCompute(validTime) match {
           case Some(parentRDD) => {   // If parent RDD exists, then compute as usual
-
-            // Define the function for the mapPartition operation on cogrouped RDD;
-            // first map the cogrouped tuple to tuples of required type,
-            // and then apply the update function
-            val updateFuncLocal = updateFunc
-            val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
-              val i = iterator.map(t => {
-                val itr = t._2._2.iterator
-                val headOption = itr.hasNext match {
-                  case true => Some(itr.next())
-                  case false => None
-                }
-                (t._1, t._2._1.toSeq, headOption)
-              })
-              updateFuncLocal(i)
-            }
-            val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
-            val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
-            Some(stateRDD)
+            computeUsingPreviousRDD(parentRDD, prevStateRDD)
           }
           case None => {    // If parent RDD does not exist
@@ -90,19 +92,25 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
         // Try to get the parent RDD
         parent.getOrCompute(validTime) match {
           case Some(parentRDD) => {   // If parent RDD exists, then compute as usual
+            initialRDD match {
+              case None => {
+                // Define the function for the mapPartition operation on grouped RDD;
+                // first map the grouped tuple to tuples of required type,
+                // and then apply the update function
+                val updateFuncLocal = updateFunc
+                val finalFunc =
+                  (iterator: Iterator[(K, Iterable[V])]) => {
+                    updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
+                  }
-            // Define the function for the mapPartition operation on grouped RDD;
-            // first map the grouped tuple to tuples of required type,
-            // and then apply the update function
-            val updateFuncLocal = updateFunc
-            val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
-              updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
+                val groupedRDD = parentRDD.groupByKey(partitioner)
+                val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
+                // logDebug("Generating state RDD for time " + validTime + " (first)")
+                Some(sessionRDD)
+              }
+              case Some(initialStateRDD) => {
+                computeUsingPreviousRDD(parentRDD, initialStateRDD)
+              }
             }
-
-            val groupedRDD = parentRDD.groupByKey(partitioner)
-            val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
-            // logDebug("Generating state RDD for time " + validTime + " (first)")
-            Some(sessionRDD)
           }
           case None => {    // If parent RDD does not exist, then nothing to do!
             // logDebug("Not generating state RDD (no previous state, no parent)")
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 4efeb8dfbe..ce645fccba 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -806,15 +806,17 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
    * Performs an order-invariant comparison of lists representing two RDD streams. This allows
    * us to account for ordering variation within individual RDD's which occurs during windowing.
   */
-  public static <T extends Comparable<T>> void assertOrderInvariantEquals(
+  public static <T> void assertOrderInvariantEquals(
       List<List<T>> expected, List<List<T>> actual) {
+    List<Set<T>> expectedSets = new ArrayList<Set<T>>();
     for (List<T> list: expected) {
-      Collections.sort(list);
+      expectedSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
     }
+    List<Set<T>> actualSets = new ArrayList<Set<T>>();
     for (List<T> list: actual) {
-      Collections.sort(list);
+      actualSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
     }
-    Assert.assertEquals(expected, actual);
+    Assert.assertEquals(expectedSets, actualSets);
   }
 
@@ -1239,6 +1241,49 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
     Assert.assertEquals(expected, result);
   }
 
+  @SuppressWarnings("unchecked")
+  @Test
+  public void testUpdateStateByKeyWithInitial() {
+    List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+    List<Tuple2<String, Integer>> initial = Arrays.asList(
+        new Tuple2<String, Integer>("california", 1),
+        new Tuple2<String, Integer>("new york", 2));
+
+    JavaRDD<Tuple2<String, Integer>> tmpRDD = ssc.sparkContext().parallelize(initial);
+    JavaPairRDD<String, Integer> initialRDD = JavaPairRDD.fromJavaRDD(tmpRDD);
+
+    List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
+        Arrays.asList(new Tuple2<String, Integer>("california", 5),
+                      new Tuple2<String, Integer>("new york", 7)),
+        Arrays.asList(new Tuple2<String, Integer>("california", 15),
+                      new Tuple2<String, Integer>("new york", 11)),
+        Arrays.asList(new Tuple2<String, Integer>("california", 15),
+                      new Tuple2<String, Integer>("new york", 11)));
+
+    JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+    JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+    JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
+        new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
+          @Override
+          public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+            int out = 0;
+            if (state.isPresent()) {
+              out = out + state.get();
+            }
+            for (Integer v: values) {
+              out = out + v;
+            }
+            return Optional.of(out);
+          }
+        }, new HashPartitioner(1), initialRDD);
+    JavaTestUtils.attachTestOutputStream(updated);
+    List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+    assertOrderInvariantEquals(expected, result);
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void testReduceByKeyAndWindowWithInverse() {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index dbab685dc3..30a359677c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.{BlockRDD, RDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.StreamingContext._
 import org.apache.spark.streaming.dstream.{DStream, WindowedDStream}
+import org.apache.spark.HashPartitioner
 
 class BasicOperationsSuite extends TestSuiteBase {
   test("map") {
@@ -350,6 +351,79 @@ class BasicOperationsSuite extends TestSuiteBase {
     testOperation(inputData, updateStateOperation, outputData, true)
   }
 
+  test("updateStateByKey - simple with initial value RDD") {
+    val initial = Seq(("a", 1), ("c", 2))
+
+    val inputData =
+      Seq(
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(("a", 2), ("c", 2)),
+        Seq(("a", 3), ("b", 1), ("c", 2)),
+        Seq(("a", 4), ("b", 2), ("c", 3)),
+        Seq(("a", 5), ("b", 3), ("c", 3)),
+        Seq(("a", 6), ("b", 3), ("c", 3)),
+        Seq(("a", 6), ("b", 3), ("c", 3))
+      )
+
+    val updateStateOperation = (s: DStream[String]) => {
+      val initialRDD = s.context.sparkContext.makeRDD(initial)
+      val updateFunc = (values: Seq[Int],
+          state: Option[Int]) => {
+        Some(values.sum + state.getOrElse(0))
+      }
+      s.map(x => (x, 1)).updateStateByKey[Int](updateFunc,
+        new HashPartitioner(numInputPartitions), initialRDD)
+    }
+
+    testOperation(inputData, updateStateOperation, outputData, true)
+  }
+
+  test("updateStateByKey - with initial value RDD") {
+    val initial = Seq(("a", 1), ("c", 2))
+
+    val inputData =
+      Seq(
+        Seq("a"),
+        Seq("a", "b"),
+        Seq("a", "b", "c"),
+        Seq("a", "b"),
+        Seq("a"),
+        Seq()
+      )
+
+    val outputData =
+      Seq(
+        Seq(("a", 2), ("c", 2)),
+        Seq(("a", 3), ("b", 1), ("c", 2)),
+        Seq(("a", 4), ("b", 2), ("c", 3)),
+        Seq(("a", 5), ("b", 3), ("c", 3)),
+        Seq(("a", 6), ("b", 3), ("c", 3)),
+        Seq(("a", 6), ("b", 3), ("c", 3))
+      )
+
+    val updateStateOperation = (s: DStream[String]) => {
+      val initialRDD = s.context.sparkContext.makeRDD(initial)
+      val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+        Some(values.sum + state.getOrElse(0))
+      }
+      val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
+        iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+      }
+      s.map(x => (x, 1)).updateStateByKey[Int](newUpdateFunc,
+        new HashPartitioner(numInputPartitions), true, initialRDD)
+    }
+
+    testOperation(inputData, updateStateOperation, outputData, true)
+  }
+
   test("updateStateByKey - object lifecycle") {
     val inputData =
       Seq(
-- 
cgit v1.2.3
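
For context, a minimal usage sketch of the simpler overload added by this patch, in the spirit of the StatefulNetworkWordCountWithInitial sample mentioned in the commit message. This is illustrative only and not part of the patch: the application name, socket source, checkpoint directory, seed counts, and partitioner choice are assumed placeholders, and the code targets the Spark Streaming API of this era.

    // Sketch: seeding updateStateByKey with state carried over from a previous run.
    import org.apache.spark.HashPartitioner
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.StreamingContext._

    object InitialStateExample {
      def main(args: Array[String]): Unit = {
        val ssc = new StreamingContext("local[2]", "InitialStateExample", Seconds(1))
        ssc.checkpoint("checkpoint")  // updateStateByKey requires a checkpoint directory

        // Seed state, e.g. counts recovered from an earlier run (values are illustrative).
        val initialRDD = ssc.sparkContext.parallelize(Seq(("hello", 10), ("world", 1)))

        val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" "))
        val wordDstream = words.map(w => (w, 1))

        // Same shape of update function as in the tests above: fold new counts into old state.
        val updateFunc = (values: Seq[Int], state: Option[Int]) => {
          Some(values.sum + state.getOrElse(0))
        }

        // The overload added by this patch: the initial state RDD is the third argument.
        val stateDstream = wordDstream.updateStateByKey[Int](
          updateFunc, new HashPartitioner(ssc.sparkContext.defaultParallelism), initialRDD)

        stateDstream.print()
        ssc.start()
        ssc.awaitTermination()
      }
    }

With an empty initial RDD this behaves like the existing updateStateByKey; with a non-empty one, the first batch is cogrouped against the seed state exactly as later batches are cogrouped against the previous state RDD.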