 examples/src/main/python/streaming/stateful_network_wordcount.py                   |  5
 python/pyspark/streaming/dstream.py                                                | 13
 python/pyspark/streaming/tests.py                                                  | 20
 streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala | 14
 4 files changed, 47 insertions(+), 5 deletions(-)
diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py
index 16ef646b7c..f8bbc659c2 100644
--- a/examples/src/main/python/streaming/stateful_network_wordcount.py
+++ b/examples/src/main/python/streaming/stateful_network_wordcount.py
@@ -44,13 +44,16 @@ if __name__ == "__main__":
     ssc = StreamingContext(sc, 1)
     ssc.checkpoint("checkpoint")
 
+    # RDD with initial state (key, value) pairs
+    initialStateRDD = sc.parallelize([(u'hello', 1), (u'world', 1)])
+
     def updateFunc(new_values, last_sum):
         return sum(new_values) + (last_sum or 0)
 
     lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
     running_counts = lines.flatMap(lambda line: line.split(" "))\
                           .map(lambda word: (word, 1))\
-                          .updateStateByKey(updateFunc)
+                          .updateStateByKey(updateFunc, initialRDD=initialStateRDD)
 
     running_counts.pprint()
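
With the seeded state, counts for 'hello' and 'world' start at 1 instead of 0 on the very first batch. A quick sanity check of the arithmetic the example's updateFunc performs, using a hypothetical first batch "hello world hello" (illustrative only, not part of the commit):

    def updateFunc(new_values, last_sum):
        return sum(new_values) + (last_sum or 0)

    # 'hello' arrives twice and already has an initial count of 1
    assert updateFunc([1, 1], 1) == 3
    # 'world' arrives once on top of its initial count of 1
    assert updateFunc([1], 1) == 2
    # a word absent from initialStateRDD has no prior state
    assert updateFunc([1], None) == 1
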
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index acec850f02..f61137cb88 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -568,7 +568,7 @@ class DStream(object):
                                         self._ssc._jduration(slideDuration))
         return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
 
-    def updateStateByKey(self, updateFunc, numPartitions=None):
+    def updateStateByKey(self, updateFunc, numPartitions=None, initialRDD=None):
         """
         Return a new "state" DStream where the state for each key is updated by applying
         the given function on the previous state of the key and the new values of the key.
@@ -579,6 +579,9 @@
         if numPartitions is None:
             numPartitions = self._sc.defaultParallelism
 
+        if initialRDD and not isinstance(initialRDD, RDD):
+            initialRDD = self._sc.parallelize(initialRDD)
+
         def reduceFunc(t, a, b):
             if a is None:
                 g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
@@ -590,7 +593,13 @@
         jreduceFunc = TransformFunction(self._sc, reduceFunc,
                                         self._sc.serializer, self._jrdd_deserializer)
-        dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+        if initialRDD:
+            initialRDD = initialRDD._reserialize(self._jrdd_deserializer)
+            dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc,
+                                                       initialRDD._jrdd)
+        else:
+            dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+
         return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
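
On the Python side, initialRDD may be an RDD of (key, value) pairs or a plain iterable of pairs; a non-RDD value is parallelized before being reserialized and handed to the JVM constructor. A minimal usage sketch under assumed host/port and update logic (these names are placeholders, not part of the commit):

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext(appName="InitialStateSketch")
    ssc = StreamingContext(sc, 1)
    ssc.checkpoint("checkpoint")

    def update(new_values, last_sum):
        return sum(new_values) + (last_sum or 0)

    pairs = ssc.socketTextStream("localhost", 9999) \
        .flatMap(lambda line: line.split(" ")) \
        .map(lambda word: (word, 1))

    # A plain list works too: it is parallelized internally before use.
    counts = pairs.updateStateByKey(update, initialRDD=[(u'hello', 1), (u'world', 1)])
    counts.pprint()

Start the context with ssc.start() once a socket source is listening; the first batch's state is then seeded from the list above.
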
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index a2bfd79e1a..4949cd68e3 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -403,6 +403,26 @@ class BasicOperationTests(PySparkStreamingTestCase):
         expected = [[('k', v)] for v in expected]
         self._test_func(input, func, expected)
 
+    def test_update_state_by_key_initial_rdd(self):
+
+        def updater(vs, s):
+            if not s:
+                s = []
+            s.extend(vs)
+            return s
+
+        initial = [('k', [0, 1])]
+        initial = self.sc.parallelize(initial, 1)
+
+        input = [[('k', i)] for i in range(2, 5)]
+
+        def func(dstream):
+            return dstream.updateStateByKey(updater, initialRDD=initial)
+
+        expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+        expected = [[('k', v)] for v in expected]
+        self._test_func(input, func, expected)
+
     def test_failed_func(self):
         # Test failure in
         # TransformFunction.apply(rdd: Option[RDD[_]], time: Time)
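
The expected output can be checked by replaying the test's updater by hand over the three one-element batches, starting from the initial state [0, 1] for key 'k'. A standalone walkthrough in plain Python (no Spark; it only mirrors the test's reasoning):

    def updater(vs, s):
        if not s:
            s = []
        s.extend(vs)
        return s

    state = [0, 1]                    # seeded by the initial RDD
    outputs = []
    for batch in ([2], [3], [4]):     # values for 'k' in the three batches
        state = updater(batch, state)
        outputs.append(list(state))

    assert outputs == [[0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
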
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 994309ddd0..056248ccc7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -264,9 +264,19 @@ private[python] class PythonTransformed2DStream(
  */
 private[python] class PythonStateDStream(
     parent: DStream[Array[Byte]],
-    reduceFunc: PythonTransformFunction)
+    reduceFunc: PythonTransformFunction,
+    initialRDD: Option[RDD[Array[Byte]]])
   extends PythonDStream(parent, reduceFunc) {
 
+  def this(
+      parent: DStream[Array[Byte]],
+      reduceFunc: PythonTransformFunction) = this(parent, reduceFunc, None)
+
+  def this(
+      parent: DStream[Array[Byte]],
+      reduceFunc: PythonTransformFunction,
+      initialRDD: JavaRDD[Array[Byte]]) = this(parent, reduceFunc, Some(initialRDD.rdd))
+
   super.persist(StorageLevel.MEMORY_ONLY)
   override val mustCheckpoint = true
@@ -274,7 +284,7 @@ private[python] class PythonStateDStream(
     val lastState = getOrCompute(validTime - slideDuration)
     val rdd = parent.getOrCompute(validTime)
     if (rdd.isDefined) {
-      func(lastState, rdd, validTime)
+      func(lastState.orElse(initialRDD), rdd, validTime)
     } else {
       lastState
     }
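
The Scala side keeps the seed optional: an auxiliary constructor preserves the old two-argument signature, and compute() passes lastState.orElse(initialRDD) to the Python reduce function, so the initial RDD is only consulted until the first state RDD exists. A toy Python model of that precedence rule (illustrative only, not Spark code):

    def pick_state(last_state, initial_rdd):
        # mirrors Option.orElse: prefer accumulated state, fall back to the seed
        return last_state if last_state is not None else initial_rdd

    assert pick_state(None, "seed") == "seed"              # first batch: use the seed
    assert pick_state("state@t1", "seed") == "state@t1"    # later batches ignore the seed
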