aboutsummaryrefslogtreecommitdiff
path: root/streaming
diff options
context:
space:
mode:
Diffstat (limited to 'streaming')
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala14
1 files changed, 12 insertions, 2 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 994309ddd0..056248ccc7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -264,9 +264,19 @@ private[python] class PythonTransformed2DStream(
*/
private[python] class PythonStateDStream(
parent: DStream[Array[Byte]],
- reduceFunc: PythonTransformFunction)
+ reduceFunc: PythonTransformFunction,
+ initialRDD: Option[RDD[Array[Byte]]])
extends PythonDStream(parent, reduceFunc) {
+ def this(
+ parent: DStream[Array[Byte]],
+ reduceFunc: PythonTransformFunction) = this(parent, reduceFunc, None)
+
+ def this(
+ parent: DStream[Array[Byte]],
+ reduceFunc: PythonTransformFunction,
+ initialRDD: JavaRDD[Array[Byte]]) = this(parent, reduceFunc, Some(initialRDD.rdd))
+
super.persist(StorageLevel.MEMORY_ONLY)
override val mustCheckpoint = true
@@ -274,7 +284,7 @@ private[python] class PythonStateDStream(
val lastState = getOrCompute(validTime - slideDuration)
val rdd = parent.getOrCompute(validTime)
if (rdd.isDefined) {
- func(lastState, rdd, validTime)
+ func(lastState.orElse(initialRDD), rdd, validTime)
} else {
lastState
}