From 6a6c1fc5c807ba4e8aba3e260537aa527ff5d46a Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Thu, 10 Dec 2015 14:21:15 -0800
Subject: [SPARK-11713] [PYSPARK] [STREAMING] Initial RDD updateStateByKey for PySpark

Adding the ability to define an initial state RDD for use with updateStateByKey in PySpark. Added a unit test and changed the stateful_network_wordcount example to use an initial RDD.

Author: Bryan Cutler

Closes #10082 from BryanCutler/initial-rdd-updateStateByKey-SPARK-11713.
---
 .../python/streaming/stateful_network_wordcount.py |  5 ++++-
 python/pyspark/streaming/dstream.py                | 13 +++++++++++--
 python/pyspark/streaming/tests.py                  | 20 ++++++++++++++++++++
 .../spark/streaming/api/python/PythonDStream.scala | 14 ++++++++++++--
 4 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py
index 16ef646b7c..f8bbc659c2 100644
--- a/examples/src/main/python/streaming/stateful_network_wordcount.py
+++ b/examples/src/main/python/streaming/stateful_network_wordcount.py
@@ -44,13 +44,16 @@ if __name__ == "__main__":
     ssc = StreamingContext(sc, 1)
     ssc.checkpoint("checkpoint")

+    # RDD with initial state (key, value) pairs
+    initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)])
+
     def updateFunc(new_values, last_sum):
         return sum(new_values) + (last_sum or 0)

     lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
     running_counts = lines.flatMap(lambda line: line.split(" "))\
                           .map(lambda word: (word, 1))\
-                          .updateStateByKey(updateFunc)
+                          .updateStateByKey(updateFunc, initialRDD=initialStateRDD)

     running_counts.pprint()

diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index acec850f02..f61137cb88 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -568,7 +568,7 @@ class DStream(object):
                                              self._ssc._jduration(slideDuration))
         return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)

-    def updateStateByKey(self, updateFunc, numPartitions=None):
+    def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None):
         """
         Return a new "state" DStream where the state for each key is updated by applying
         the given function on the previous state of the key and the new values of the key.
@@ -579,6 +579,9 @@ class DStream(object):
         if numPartitions is None:
             numPartitions = self._sc.defaultParallelism

+        if initialRDD and not isinstance(initialRDD, RDD):
+            initialRDD = self._sc.parallelize(initialRDD)
+
         def reduceFunc(t, a, b):
             if a is None:
                 g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
@@ -590,7 +593,13 @@ class DStream(object):

         jreduceFunc = TransformFunction(self._sc, reduceFunc,
                                         self._sc.serializer, self._jrdd_deserializer)
-        dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+        if initialRDD:
+            initialRDD = initialRDD._reserialize(self._jrdd_deserializer)
+            dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc,
+                                                       initialRDD._jrdd)
+        else:
+            dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+
         return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)


diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index a2bfd79e1a..4949cd68e3 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -403,6 +403,26 @@ class BasicOperationTests(PySparkStreamingTestCase):
         expected = [[('k', v)] for v in expected]
         self._test_func(input, func, expected)

+    def test_update_state_by_key_initial_rdd(self):
+
+        def updater(vs, s):
+            if not s:
+                s = []
+            s.extend(vs)
+            return s
+
+        initial = [('k', [0, 1])]
+        initial = self.sc.parallelize(initial, 1)
+
+        input = [[('k', i)] for i in range(2, 5)]
+
+        def func(dstream):
+            return dstream.updateStateByKey(updater, initialRDD=initial)
+
+        expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+        expected = [[('k', v)] for v in expected]
+        self._test_func(input, func, expected)
+
     def test_failed_func(self):
         # Test failure in
         # TransformFunction.apply(rdd: Option[RDD[_]], time: Time)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 994309ddd0..056248ccc7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -264,9 +264,19 @@ private[python] class PythonTransformed2DStream(
  */
 private[python] class PythonStateDStream(
     parent: DStream[Array[Byte]],
-    reduceFunc: PythonTransformFunction)
+    reduceFunc: PythonTransformFunction,
+    initialRDD: Option[RDD[Array[Byte]]])
   extends PythonDStream(parent, reduceFunc) {

+  def this(
+      parent: DStream[Array[Byte]],
+      reduceFunc: PythonTransformFunction) = this(parent, reduceFunc, None)
+
+  def this(
+      parent: DStream[Array[Byte]],
+      reduceFunc: PythonTransformFunction,
+      initialRDD: JavaRDD[Array[Byte]]) = this(parent, reduceFunc, Some(initialRDD.rdd))
+
   super.persist(StorageLevel.MEMORY_ONLY)
   override val mustCheckpoint = true

@@ -274,7 +284,7 @@ private[python] class PythonStateDStream(
     val lastState = getOrCompute(validTime - slideDuration)
     val rdd = parent.getOrCompute(validTime)
     if (rdd.isDefined) {
-      func(lastState, rdd, validTime)
+      func(lastState.orElse(initialRDD), rdd, validTime)
     } else {
       lastState
     }
--
cgit v1.2.3
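
For reference, a minimal self-contained sketch of how the new initialRDD
parameter is used, condensed from the updated stateful_network_wordcount.py
example in this patch. The application name, socket host/port, and one-second
batch interval are illustrative stand-ins, not part of the patch:

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext(appName="InitialRDDUpdateStateByKey")  # name is illustrative
    ssc = StreamingContext(sc, 1)   # 1-second batches (illustrative)
    ssc.checkpoint("checkpoint")    # updateStateByKey requires checkpointing

    # Pre-seed the state: these (key, value) pairs exist before any data arrives.
    initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)])

    def updateFunc(new_values, last_sum):
        # Combine this batch's values with the running sum; last_sum is None
        # the first time a key is seen, unless the key was seeded above.
        return sum(new_values) + (last_sum or 0)

    lines = ssc.socketTextStream("localhost", 9999)  # illustrative source
    running_counts = lines.flatMap(lambda line: line.split(" "))\
                          .map(lambda word: (word, 1))\
                          .updateStateByKey(updateFunc, initialRDD=initialStateRDD)

    running_counts.pprint()
    ssc.start()
    ssc.awaitTermination()

As the dstream.py change shows, initialRDD may also be given as a plain list
of (key, value) pairs: anything that is not already an RDD is passed through
sc.parallelize() before use.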