aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala10
-rw-r--r--examples/src/main/python/streaming/hdfs_wordcount.py49
-rw-r--r--examples/src/main/python/streaming/network_wordcount.py48
-rw-r--r--examples/src/main/python/streaming/stateful_network_wordcount.py57
-rw-r--r--python/docs/epytext.py2
-rw-r--r--python/docs/index.rst1
-rw-r--r--python/docs/pyspark.rst3
-rw-r--r--python/pyspark/context.py8
-rw-r--r--python/pyspark/serializers.py3
-rw-r--r--python/pyspark/streaming/__init__.py21
-rw-r--r--python/pyspark/streaming/context.py325
-rw-r--r--python/pyspark/streaming/dstream.py621
-rw-r--r--python/pyspark/streaming/tests.py545
-rw-r--r--python/pyspark/streaming/util.py128
-rwxr-xr-xpython/run-tests7
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala316
17 files changed, 2133 insertions, 13 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index c74f86548e..4acbdf9d5e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -25,8 +25,6 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
-import scala.reflect.ClassTag
-import scala.util.{Try, Success, Failure}
import net.razorvine.pickle.{Pickler, Unpickler}
@@ -42,7 +40,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
private[spark] class PythonRDD(
- parent: RDD[_],
+ @transient parent: RDD[_],
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
@@ -55,9 +53,9 @@ private[spark] class PythonRDD(
val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
- override def getPartitions = parent.partitions
+ override def getPartitions = firstParent.partitions
- override val partitioner = if (preservePartitoning) parent.partitioner else None
+ override val partitioner = if (preservePartitoning) firstParent.partitioner else None
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
@@ -234,7 +232,7 @@ private[spark] class PythonRDD(
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
- PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+ PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.flush()
} catch {
diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py
new file mode 100644
index 0000000000..40faff0ccc
--- /dev/null
+++ b/examples/src/main/python/streaming/hdfs_wordcount.py
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in new text files created in the given directory
+ Usage: hdfs_wordcount.py <directory>
+ <directory> is the directory that Spark Streaming will use to find and read new text files.
+
+ To run this on your local machine on directory `localdir`, run this example
+ $ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localdir
+
+ Then create a text file in `localdir` and the words in the file will get counted.
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 2:
+ print >> sys.stderr, "Usage: hdfs_wordcount.py <directory>"
+ exit(-1)
+
+ sc = SparkContext(appName="PythonStreamingHDFSWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ lines = ssc.textFileStream(sys.argv[1])
+ counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda x: (x, 1))\
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py
new file mode 100644
index 0000000000..cfa9c1ff5b
--- /dev/null
+++ b/examples/src/main/python/streaming/network_wordcount.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ Usage: network_wordcount.py <hostname> <port>
+ <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, "Usage: network_wordcount.py <hostname> <port>"
+ exit(-1)
+ sc = SparkContext(appName="PythonStreamingNetworkWordCount")
+ ssc = StreamingContext(sc, 1)
+
+ lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
+ counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda word: (word, 1))\
+ .reduceByKey(lambda a, b: a+b)
+ counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py
new file mode 100644
index 0000000000..18a9a5a452
--- /dev/null
+++ b/examples/src/main/python/streaming/stateful_network_wordcount.py
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the
+ network every second.
+
+ Usage: stateful_network_wordcount.py <hostname> <port>
+ <hostname> and <port> describe the TCP server that Spark Streaming
+ would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \
+ localhost 9999`
+"""
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, "Usage: stateful_network_wordcount.py <hostname> <port>"
+ exit(-1)
+ sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount")
+ ssc = StreamingContext(sc, 1)
+ ssc.checkpoint("checkpoint")
+
+ def updateFunc(new_values, last_sum):
+ return sum(new_values) + (last_sum or 0)
+
+ lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
+ running_counts = lines.flatMap(lambda line: line.split(" "))\
+ .map(lambda word: (word, 1))\
+ .updateStateByKey(updateFunc)
+
+ running_counts.pprint()
+
+ ssc.start()
+ ssc.awaitTermination()
diff --git a/python/docs/epytext.py b/python/docs/epytext.py
index 61d731bff5..19fefbfc05 100644
--- a/python/docs/epytext.py
+++ b/python/docs/epytext.py
@@ -5,7 +5,7 @@ RULES = (
(r"L{([\w.()]+)}", r":class:`\1`"),
(r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"),
(r"C{([\w.()]+)}", r":class:`\1`"),
- (r"[IBCM]{(.+)}", r"`\1`"),
+ (r"[IBCM]{([^}]+)}", r"`\1`"),
('pyspark.rdd.RDD', 'RDD'),
)
diff --git a/python/docs/index.rst b/python/docs/index.rst
index d66e051b15..703bef644d 100644
--- a/python/docs/index.rst
+++ b/python/docs/index.rst
@@ -13,6 +13,7 @@ Contents:
pyspark
pyspark.sql
+ pyspark.streaming
pyspark.mllib
diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst
index a68bd62433..e81be3b6cb 100644
--- a/python/docs/pyspark.rst
+++ b/python/docs/pyspark.rst
@@ -7,8 +7,9 @@ Subpackages
.. toctree::
:maxdepth: 1
- pyspark.mllib
pyspark.sql
+ pyspark.streaming
+ pyspark.mllib
Contents
--------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 85c04624da..89d2e2e5b4 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -68,7 +68,7 @@ class SparkContext(object):
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
- gateway=None):
+ gateway=None, jsc=None):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@@ -104,14 +104,14 @@ class SparkContext(object):
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf)
+ conf, jsc)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise
def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf):
+ conf, jsc):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
@@ -154,7 +154,7 @@ class SparkContext(object):
self.environment[varName] = v
# Create the Java SparkContext through Py4J
- self._jsc = self._initialize_context(self._conf._jconf)
+ self._jsc = jsc or self._initialize_context(self._conf._jconf)
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 3d1a34b281..08a0f0d8ff 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -114,6 +114,9 @@ class Serializer(object):
def __repr__(self):
return "<%s object>" % self.__class__.__name__
+ def __hash__(self):
+ return hash(str(self))
+
class FramedSerializer(Serializer):
diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py
new file mode 100644
index 0000000000..d2644a1d4f
--- /dev/null
+++ b/python/pyspark/streaming/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.streaming.context import StreamingContext
+from pyspark.streaming.dstream import DStream
+
+__all__ = ['StreamingContext', 'DStream']
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
new file mode 100644
index 0000000000..dc9dc41121
--- /dev/null
+++ b/python/pyspark/streaming/context.py
@@ -0,0 +1,325 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import sys
+
+from py4j.java_collections import ListConverter
+from py4j.java_gateway import java_import, JavaObject
+
+from pyspark import RDD, SparkConf
+from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
+from pyspark.context import SparkContext
+from pyspark.storagelevel import StorageLevel
+from pyspark.streaming.dstream import DStream
+from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer
+
+__all__ = ["StreamingContext"]
+
+
+def _daemonize_callback_server():
+ """
+ Hack Py4J to daemonize callback server
+
+ The thread of callback server has daemon=False, it will block the driver
+ from exiting if it's not shutdown. The following code replace `start()`
+ of CallbackServer with a new version, which set daemon=True for this
+ thread.
+
+ Also, it will update the port number (0) with real port
+ """
+ # TODO: create a patch for Py4J
+ import socket
+ import py4j.java_gateway
+ logger = py4j.java_gateway.logger
+ from py4j.java_gateway import Py4JNetworkError
+ from threading import Thread
+
+ def start(self):
+ """Starts the CallbackServer. This method should be called by the
+ client instead of run()."""
+ self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
+ 1)
+ try:
+ self.server_socket.bind((self.address, self.port))
+ if not self.port:
+ # update port with real port
+ self.port = self.server_socket.getsockname()[1]
+ except Exception as e:
+ msg = 'An error occurred while trying to start the callback server: %s' % e
+ logger.exception(msg)
+ raise Py4JNetworkError(msg)
+
+ # Maybe thread needs to be cleanup up?
+ self.thread = Thread(target=self.run)
+ self.thread.daemon = True
+ self.thread.start()
+
+ py4j.java_gateway.CallbackServer.start = start
+
+
+class StreamingContext(object):
+ """
+ Main entry point for Spark Streaming functionality. A StreamingContext
+ represents the connection to a Spark cluster, and can be used to create
+ L{DStream} various input sources. It can be from an existing L{SparkContext}.
+ After creating and transforming DStreams, the streaming computation can
+ be started and stopped using `context.start()` and `context.stop()`,
+ respectively. `context.awaitTransformation()` allows the current thread
+ to wait for the termination of the context by `stop()` or by an exception.
+ """
+ _transformerSerializer = None
+
+ def __init__(self, sparkContext, batchDuration=None, jssc=None):
+ """
+ Create a new StreamingContext.
+
+ @param sparkContext: L{SparkContext} object.
+ @param batchDuration: the time interval (in seconds) at which streaming
+ data will be divided into batches
+ """
+
+ self._sc = sparkContext
+ self._jvm = self._sc._jvm
+ self._jssc = jssc or self._initialize_context(self._sc, batchDuration)
+
+ def _initialize_context(self, sc, duration):
+ self._ensure_initialized()
+ return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
+
+ def _jduration(self, seconds):
+ """
+ Create Duration object given number of seconds
+ """
+ return self._jvm.Duration(int(seconds * 1000))
+
+ @classmethod
+ def _ensure_initialized(cls):
+ SparkContext._ensure_initialized()
+ gw = SparkContext._gateway
+
+ java_import(gw.jvm, "org.apache.spark.streaming.*")
+ java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
+ java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
+
+ # start callback server
+ # getattr will fallback to JVM, so we cannot test by hasattr()
+ if "_callback_server" not in gw.__dict__:
+ _daemonize_callback_server()
+ # use random port
+ gw._start_callback_server(0)
+ # gateway with real port
+ gw._python_proxy_port = gw._callback_server.port
+ # get the GatewayServer object in JVM by ID
+ jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
+ # update the port of CallbackClient with real port
+ gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
+
+ # register serializer for TransformFunction
+ # it happens before creating SparkContext when loading from checkpointing
+ cls._transformerSerializer = TransformFunctionSerializer(
+ SparkContext._active_spark_context, CloudPickleSerializer(), gw)
+
+ @classmethod
+ def getOrCreate(cls, checkpointPath, setupFunc):
+ """
+ Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
+ If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
+ recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
+ will be used to create a JavaStreamingContext.
+
+ @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
+ @param setupFunc Function to create a new JavaStreamingContext and setup DStreams
+ """
+ # TODO: support checkpoint in HDFS
+ if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
+ ssc = setupFunc()
+ ssc.checkpoint(checkpointPath)
+ return ssc
+
+ cls._ensure_initialized()
+ gw = SparkContext._gateway
+
+ try:
+ jssc = gw.jvm.JavaStreamingContext(checkpointPath)
+ except Exception:
+ print >>sys.stderr, "failed to load StreamingContext from checkpoint"
+ raise
+
+ jsc = jssc.sparkContext()
+ conf = SparkConf(_jconf=jsc.getConf())
+ sc = SparkContext(conf=conf, gateway=gw, jsc=jsc)
+ # update ctx in serializer
+ SparkContext._active_spark_context = sc
+ cls._transformerSerializer.ctx = sc
+ return StreamingContext(sc, None, jssc)
+
+ @property
+ def sparkContext(self):
+ """
+ Return SparkContext which is associated with this StreamingContext.
+ """
+ return self._sc
+
+ def start(self):
+ """
+ Start the execution of the streams.
+ """
+ self._jssc.start()
+
+ def awaitTermination(self, timeout=None):
+ """
+ Wait for the execution to stop.
+ @param timeout: time to wait in seconds
+ """
+ if timeout is None:
+ self._jssc.awaitTermination()
+ else:
+ self._jssc.awaitTermination(int(timeout * 1000))
+
+ def stop(self, stopSparkContext=True, stopGraceFully=False):
+ """
+ Stop the execution of the streams, with option of ensuring all
+ received data has been processed.
+
+ @param stopSparkContext: Stop the associated SparkContext or not
+ @param stopGracefully: Stop gracefully by waiting for the processing
+ of all received data to be completed
+ """
+ self._jssc.stop(stopSparkContext, stopGraceFully)
+ if stopSparkContext:
+ self._sc.stop()
+
+ def remember(self, duration):
+ """
+ Set each DStreams in this context to remember RDDs it generated
+ in the last given duration. DStreams remember RDDs only for a
+ limited duration of time and releases them for garbage collection.
+ This method allows the developer to specify how to long to remember
+ the RDDs (if the developer wishes to query old data outside the
+ DStream computation).
+
+ @param duration: Minimum duration (in seconds) that each DStream
+ should remember its RDDs
+ """
+ self._jssc.remember(self._jduration(duration))
+
+ def checkpoint(self, directory):
+ """
+ Sets the context to periodically checkpoint the DStream operations for master
+ fault-tolerance. The graph will be checkpointed every batch interval.
+
+ @param directory: HDFS-compatible directory where the checkpoint data
+ will be reliably stored
+ """
+ self._jssc.checkpoint(directory)
+
+ def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
+ """
+ Create an input from TCP source hostname:port. Data is received using
+ a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited
+ lines.
+
+ @param hostname: Hostname to connect to for receiving data
+ @param port: Port to connect to for receiving data
+ @param storageLevel: Storage level to use for storing the received objects
+ """
+ jlevel = self._sc._getJavaStorageLevel(storageLevel)
+ return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
+ UTF8Deserializer())
+
+ def textFileStream(self, directory):
+ """
+ Create an input stream that monitors a Hadoop-compatible file system
+ for new files and reads them as text files. Files must be wrriten to the
+ monitored directory by "moving" them from another location within the same
+ file system. File names starting with . are ignored.
+ """
+ return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
+
+ def _check_serializers(self, rdds):
+ # make sure they have same serializer
+ if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
+ for i in range(len(rdds)):
+ # reset them to sc.serializer
+ rdds[i] = rdds[i]._reserialize()
+
+ def queueStream(self, rdds, oneAtATime=True, default=None):
+ """
+ Create an input stream from an queue of RDDs or list. In each batch,
+ it will process either one or all of the RDDs returned by the queue.
+
+ NOTE: changes to the queue after the stream is created will not be recognized.
+
+ @param rdds: Queue of RDDs
+ @param oneAtATime: pick one rdd each time or pick all of them once.
+ @param default: The default rdd if no more in rdds
+ """
+ if default and not isinstance(default, RDD):
+ default = self._sc.parallelize(default)
+
+ if not rdds and default:
+ rdds = [rdds]
+
+ if rdds and not isinstance(rdds[0], RDD):
+ rdds = [self._sc.parallelize(input) for input in rdds]
+ self._check_serializers(rdds)
+
+ jrdds = ListConverter().convert([r._jrdd for r in rdds],
+ SparkContext._gateway._gateway_client)
+ queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
+ if default:
+ default = default._reserialize(rdds[0]._jrdd_deserializer)
+ jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
+ else:
+ jdstream = self._jssc.queueStream(queue, oneAtATime)
+ return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
+
+ def transform(self, dstreams, transformFunc):
+ """
+ Create a new DStream in which each RDD is generated by applying
+ a function on RDDs of the DStreams. The order of the JavaRDDs in
+ the transform function parameter will be the same as the order
+ of corresponding DStreams in the list.
+ """
+ jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
+ SparkContext._gateway._gateway_client)
+ # change the final serializer to sc.serializer
+ func = TransformFunction(self._sc,
+ lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
+ *[d._jrdd_deserializer for d in dstreams])
+ jfunc = self._jvm.TransformFunction(func)
+ jdstream = self._jssc.transform(jdstreams, jfunc)
+ return DStream(jdstream, self, self._sc.serializer)
+
+ def union(self, *dstreams):
+ """
+ Create a unified DStream from multiple DStreams of the same
+ type and same slide duration.
+ """
+ if not dstreams:
+ raise ValueError("should have at least one DStream to union")
+ if len(dstreams) == 1:
+ return dstreams[0]
+ if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
+ raise ValueError("All DStreams should have same serializer")
+ if len(set(s._slideDuration for s in dstreams)) > 1:
+ raise ValueError("All DStreams should have same slide duration")
+ first = dstreams[0]
+ jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
+ SparkContext._gateway._gateway_client)
+ return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
new file mode 100644
index 0000000000..5ae5cf07f0
--- /dev/null
+++ b/python/pyspark/streaming/dstream.py
@@ -0,0 +1,621 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from itertools import chain, ifilter, imap
+import operator
+import time
+from datetime import datetime
+
+from py4j.protocol import Py4JJavaError
+
+from pyspark import RDD
+from pyspark.storagelevel import StorageLevel
+from pyspark.streaming.util import rddToFileName, TransformFunction
+from pyspark.rdd import portable_hash
+from pyspark.resultiterable import ResultIterable
+
+__all__ = ["DStream"]
+
+
+class DStream(object):
+ """
+ A Discretized Stream (DStream), the basic abstraction in Spark Streaming,
+ is a continuous sequence of RDDs (of the same type) representing a
+ continuous stream of data (see L{RDD} in the Spark core documentation
+ for more details on RDDs).
+
+ DStreams can either be created from live data (such as, data from TCP
+ sockets, Kafka, Flume, etc.) using a L{StreamingContext} or it can be
+ generated by transforming existing DStreams using operations such as
+ `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming
+ program is running, each DStream periodically generates a RDD, either
+ from live data or by transforming the RDD generated by a parent DStream.
+
+ DStreams internally is characterized by a few basic properties:
+ - A list of other DStreams that the DStream depends on
+ - A time interval at which the DStream generates an RDD
+ - A function that is used to generate an RDD after each time interval
+ """
+ def __init__(self, jdstream, ssc, jrdd_deserializer):
+ self._jdstream = jdstream
+ self._ssc = ssc
+ self._sc = ssc._sc
+ self._jrdd_deserializer = jrdd_deserializer
+ self.is_cached = False
+ self.is_checkpointed = False
+
+ def context(self):
+ """
+ Return the StreamingContext associated with this DStream
+ """
+ return self._ssc
+
+ def count(self):
+ """
+ Return a new DStream in which each RDD has a single element
+ generated by counting each RDD of this DStream.
+ """
+ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add)
+
+ def filter(self, f):
+ """
+ Return a new DStream containing only the elements that satisfy predicate.
+ """
+ def func(iterator):
+ return ifilter(f, iterator)
+ return self.mapPartitions(func, True)
+
+ def flatMap(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream by applying a function to all elements of
+ this DStream, and then flattening the results
+ """
+ def func(s, iterator):
+ return chain.from_iterable(imap(f, iterator))
+ return self.mapPartitionsWithIndex(func, preservesPartitioning)
+
+ def map(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream by applying a function to each element of DStream.
+ """
+ def func(iterator):
+ return imap(f, iterator)
+ return self.mapPartitions(func, preservesPartitioning)
+
+ def mapPartitions(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream in which each RDD is generated by applying
+ mapPartitions() to each RDDs of this DStream.
+ """
+ def func(s, iterator):
+ return f(iterator)
+ return self.mapPartitionsWithIndex(func, preservesPartitioning)
+
+ def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
+ """
+ Return a new DStream in which each RDD is generated by applying
+ mapPartitionsWithIndex() to each RDDs of this DStream.
+ """
+ return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning))
+
+ def reduce(self, func):
+ """
+ Return a new DStream in which each RDD has a single element
+ generated by reducing each RDD of this DStream.
+ """
+ return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1])
+
+ def reduceByKey(self, func, numPartitions=None):
+ """
+ Return a new DStream by applying reduceByKey to each RDD.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.combineByKey(lambda x: x, func, func, numPartitions)
+
+ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
+ numPartitions=None):
+ """
+ Return a new DStream by applying combineByKey to each RDD.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+
+ def func(rdd):
+ return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions)
+ return self.transform(func)
+
+ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
+ """
+ Return a copy of the DStream in which each RDD are partitioned
+ using the specified partitioner.
+ """
+ return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc))
+
+ def foreachRDD(self, func):
+ """
+ Apply a function to each RDD in this DStream.
+ """
+ if func.func_code.co_argcount == 1:
+ old_func = func
+ func = lambda t, rdd: old_func(rdd)
+ jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
+ api = self._ssc._jvm.PythonDStream
+ api.callForeachRDD(self._jdstream, jfunc)
+
+ def pprint(self):
+ """
+ Print the first ten elements of each RDD generated in this DStream.
+ """
+ def takeAndPrint(time, rdd):
+ taken = rdd.take(11)
+ print "-------------------------------------------"
+ print "Time: %s" % time
+ print "-------------------------------------------"
+ for record in taken[:10]:
+ print record
+ if len(taken) > 10:
+ print "..."
+ print
+
+ self.foreachRDD(takeAndPrint)
+
+ def mapValues(self, f):
+ """
+ Return a new DStream by applying a map function to the value of
+ each key-value pairs in this DStream without changing the key.
+ """
+ map_values_fn = lambda (k, v): (k, f(v))
+ return self.map(map_values_fn, preservesPartitioning=True)
+
+ def flatMapValues(self, f):
+ """
+ Return a new DStream by applying a flatmap function to the value
+ of each key-value pairs in this DStream without changing the key.
+ """
+ flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
+ return self.flatMap(flat_map_fn, preservesPartitioning=True)
+
+ def glom(self):
+ """
+ Return a new DStream in which RDD is generated by applying glom()
+ to RDD of this DStream.
+ """
+ def func(iterator):
+ yield list(iterator)
+ return self.mapPartitions(func)
+
+ def cache(self):
+ """
+ Persist the RDDs of this DStream with the default storage level
+ (C{MEMORY_ONLY_SER}).
+ """
+ self.is_cached = True
+ self.persist(StorageLevel.MEMORY_ONLY_SER)
+ return self
+
+ def persist(self, storageLevel):
+ """
+ Persist the RDDs of this DStream with the given storage level
+ """
+ self.is_cached = True
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ self._jdstream.persist(javaStorageLevel)
+ return self
+
+ def checkpoint(self, interval):
+ """
+ Enable periodic checkpointing of RDDs of this DStream
+
+ @param interval: time in seconds, after each period of that, generated
+ RDD will be checkpointed
+ """
+ self.is_checkpointed = True
+ self._jdstream.checkpoint(self._ssc._jduration(interval))
+ return self
+
+ def groupByKey(self, numPartitions=None):
+ """
+ Return a new DStream by applying groupByKey on each RDD.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transform(lambda rdd: rdd.groupByKey(numPartitions))
+
+ def countByValue(self):
+ """
+ Return a new DStream in which each RDD contains the counts of each
+ distinct value in each RDD of this DStream.
+ """
+ return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count()
+
+ def saveAsTextFiles(self, prefix, suffix=None):
+ """
+ Save each RDD in this DStream as at text file, using string
+ representation of elements.
+ """
+ def saveAsTextFile(t, rdd):
+ path = rddToFileName(prefix, suffix, t)
+ try:
+ rdd.saveAsTextFile(path)
+ except Py4JJavaError as e:
+ # after recovered from checkpointing, the foreachRDD may
+ # be called twice
+ if 'FileAlreadyExistsException' not in str(e):
+ raise
+ return self.foreachRDD(saveAsTextFile)
+
+ # TODO: uncomment this until we have ssc.pickleFileStream()
+ # def saveAsPickleFiles(self, prefix, suffix=None):
+ # """
+ # Save each RDD in this DStream as at binary file, the elements are
+ # serialized by pickle.
+ # """
+ # def saveAsPickleFile(t, rdd):
+ # path = rddToFileName(prefix, suffix, t)
+ # try:
+ # rdd.saveAsPickleFile(path)
+ # except Py4JJavaError as e:
+ # # after recovered from checkpointing, the foreachRDD may
+ # # be called twice
+ # if 'FileAlreadyExistsException' not in str(e):
+ # raise
+ # return self.foreachRDD(saveAsPickleFile)
+
+ def transform(self, func):
+ """
+ Return a new DStream in which each RDD is generated by applying a function
+ on each RDD of this DStream.
+
+ `func` can have one argument of `rdd`, or have two arguments of
+ (`time`, `rdd`)
+ """
+ if func.func_code.co_argcount == 1:
+ oldfunc = func
+ func = lambda t, rdd: oldfunc(rdd)
+ assert func.func_code.co_argcount == 2, "func should take one or two arguments"
+ return TransformedDStream(self, func)
+
+ def transformWith(self, func, other, keepSerializer=False):
+ """
+ Return a new DStream in which each RDD is generated by applying a function
+ on each RDD of this DStream and 'other' DStream.
+
+ `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three
+ arguments of (`time`, `rdd_a`, `rdd_b`)
+ """
+ if func.func_code.co_argcount == 2:
+ oldfunc = func
+ func = lambda t, a, b: oldfunc(a, b)
+ assert func.func_code.co_argcount == 3, "func should take two or three arguments"
+ jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer)
+ dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
+ other._jdstream.dstream(), jfunc)
+ jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer
+ return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer)
+
+ def repartition(self, numPartitions):
+ """
+ Return a new DStream with an increased or decreased level of parallelism.
+ """
+ return self.transform(lambda rdd: rdd.repartition(numPartitions))
+
+ @property
+ def _slideDuration(self):
+ """
+ Return the slideDuration in seconds of this DStream
+ """
+ return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0
+
+ def union(self, other):
+ """
+ Return a new DStream by unifying data of another DStream with this DStream.
+
+ @param other: Another DStream having the same interval (i.e., slideDuration)
+ as this DStream.
+ """
+ if self._slideDuration != other._slideDuration:
+ raise ValueError("the two DStream should have same slide duration")
+ return self.transformWith(lambda a, b: a.union(b), other, True)
+
+ def cogroup(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'cogroup' between RDDs of this
+ DStream and `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other)
+
+ def join(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.join(b, numPartitions), other)
+
+ def leftOuterJoin(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'left outer join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other)
+
+ def rightOuterJoin(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'right outer join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other)
+
+ def fullOuterJoin(self, other, numPartitions=None):
+ """
+ Return a new DStream by applying 'full outer join' between RDDs of this DStream and
+ `other` DStream.
+
+ Hash partitioning is used to generate the RDDs with `numPartitions`
+ partitions.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+ return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other)
+
+ def _jtime(self, timestamp):
+ """ Convert datetime or unix_timestamp into Time
+ """
+ if isinstance(timestamp, datetime):
+ timestamp = time.mktime(timestamp.timetuple())
+ return self._sc._jvm.Time(long(timestamp * 1000))
+
+ def slice(self, begin, end):
+ """
+ Return all the RDDs between 'begin' to 'end' (both included)
+
+ `begin`, `end` could be datetime.datetime() or unix_timestamp
+ """
+ jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
+ return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds]
+
+ def _validate_window_param(self, window, slide):
+ duration = self._jdstream.dstream().slideDuration().milliseconds()
+ if int(window * 1000) % duration != 0:
+ raise ValueError("windowDuration must be multiple of the slide duration (%d ms)"
+ % duration)
+ if slide and int(slide * 1000) % duration != 0:
+ raise ValueError("slideDuration must be multiple of the slide duration (%d ms)"
+ % duration)
+
+ def window(self, windowDuration, slideDuration=None):
+ """
+ Return a new DStream in which each RDD contains all the elements in seen in a
+ sliding window of time over this DStream.
+
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ """
+ self._validate_window_param(windowDuration, slideDuration)
+ d = self._ssc._jduration(windowDuration)
+ if slideDuration is None:
+ return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer)
+ s = self._ssc._jduration(slideDuration)
+ return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
+
+ def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration):
+ """
+ Return a new DStream in which each RDD has a single element generated by reducing all
+ elements in a sliding window over this DStream.
+
+ if `invReduceFunc` is not None, the reduction is done incrementally
+ using the old window's reduced value :
+ 1. reduce the new values that entered the window (e.g., adding new counts)
+ 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+ This is more efficient than `invReduceFunc` is None.
+
+ @param reduceFunc: associative reduce function
+ @param invReduceFunc: inverse reduce function of `reduceFunc`
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ """
+ keyed = self.map(lambda x: (1, x))
+ reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc,
+ windowDuration, slideDuration, 1)
+ return reduced.map(lambda (k, v): v)
+
+ def countByWindow(self, windowDuration, slideDuration):
+ """
+ Return a new DStream in which each RDD has a single element generated
+ by counting the number of elements in a window over this DStream.
+ windowDuration and slideDuration are as defined in the window() operation.
+
+ This is equivalent to window(windowDuration, slideDuration).count(),
+ but will be more efficient if window is large.
+ """
+ return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub,
+ windowDuration, slideDuration)
+
+ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+ """
+ Return a new DStream in which each RDD contains the count of distinct elements in
+ RDDs in a sliding window over this DStream.
+
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ @param numPartitions: number of partitions of each RDD in the new DStream.
+ """
+ keyed = self.map(lambda x: (x, 1))
+ counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub,
+ windowDuration, slideDuration, numPartitions)
+ return counted.filter(lambda (k, v): v > 0).count()
+
+ def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None):
+ """
+ Return a new DStream by applying `groupByKey` over a sliding window.
+ Similar to `DStream.groupByKey()`, but applies it over a sliding window.
+
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ @param numPartitions: Number of partitions of each RDD in the new DStream.
+ """
+ ls = self.mapValues(lambda x: [x])
+ grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):],
+ windowDuration, slideDuration, numPartitions)
+ return grouped.mapValues(ResultIterable)
+
+ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None,
+ numPartitions=None, filterFunc=None):
+ """
+ Return a new DStream by applying incremental `reduceByKey` over a sliding window.
+
+ The reduced value of over a new window is calculated using the old window's reduce value :
+ 1. reduce the new values that entered the window (e.g., adding new counts)
+ 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+
+ `invFunc` can be None, then it will reduce all the RDDs in window, could be slower
+ than having `invFunc`.
+
+ @param reduceFunc: associative reduce function
+ @param invReduceFunc: inverse function of `reduceFunc`
+ @param windowDuration: width of the window; must be a multiple of this DStream's
+ batching interval
+ @param slideDuration: sliding interval of the window (i.e., the interval after which
+ the new DStream will generate RDDs); must be a multiple of this
+ DStream's batching interval
+ @param numPartitions: number of partitions of each RDD in the new DStream.
+ @param filterFunc: function to filter expired key-value pairs;
+ only pairs that satisfy the function are retained
+ set this to null if you do not want to filter
+ """
+ self._validate_window_param(windowDuration, slideDuration)
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+
+ reduced = self.reduceByKey(func, numPartitions)
+
+ def reduceFunc(t, a, b):
+ b = b.reduceByKey(func, numPartitions)
+ r = a.union(b).reduceByKey(func, numPartitions) if a else b
+ if filterFunc:
+ r = r.filter(filterFunc)
+ return r
+
+ def invReduceFunc(t, a, b):
+ b = b.reduceByKey(func, numPartitions)
+ joined = a.leftOuterJoin(b, numPartitions)
+ return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
+
+ jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer)
+ if invReduceFunc:
+ jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer)
+ else:
+ jinvReduceFunc = None
+ if slideDuration is None:
+ slideDuration = self._slideDuration
+ dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
+ jreduceFunc, jinvReduceFunc,
+ self._ssc._jduration(windowDuration),
+ self._ssc._jduration(slideDuration))
+ return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
+
+ def updateStateByKey(self, updateFunc, numPartitions=None):
+ """
+ Return a new "state" DStream where the state for each key is updated by applying
+ the given function on the previous state of the key and the new values of the key.
+
+ @param updateFunc: State update function. If this function returns None, then
+ corresponding state key-value pair will be eliminated.
+ """
+ if numPartitions is None:
+ numPartitions = self._sc.defaultParallelism
+
+ def reduceFunc(t, a, b):
+ if a is None:
+ g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
+ else:
+ g = a.cogroup(b, numPartitions)
+ g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
+ state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
+ return state.filter(lambda (k, v): v is not None)
+
+ jreduceFunc = TransformFunction(self._sc, reduceFunc,
+ self._sc.serializer, self._jrdd_deserializer)
+ dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+ return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
+
+
+class TransformedDStream(DStream):
+ """
+ TransformedDStream is an DStream generated by an Python function
+ transforming each RDD of an DStream to another RDDs.
+
+ Multiple continuous transformations of DStream can be combined into
+ one transformation.
+ """
+ def __init__(self, prev, func):
+ self._ssc = prev._ssc
+ self._sc = self._ssc._sc
+ self._jrdd_deserializer = self._sc.serializer
+ self.is_cached = False
+ self.is_checkpointed = False
+ self._jdstream_val = None
+
+ if (isinstance(prev, TransformedDStream) and
+ not prev.is_cached and not prev.is_checkpointed):
+ prev_func = prev.func
+ self.func = lambda t, rdd: func(t, prev_func(t, rdd))
+ self.prev = prev.prev
+ else:
+ self.prev = prev
+ self.func = func
+
+ @property
+ def _jdstream(self):
+ if self._jdstream_val is not None:
+ return self._jdstream_val
+
+ jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer)
+ dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
+ self._jdstream_val = dstream.asJavaDStream()
+ return self._jdstream_val
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
new file mode 100644
index 0000000000..a8d876d0fa
--- /dev/null
+++ b/python/pyspark/streaming/tests.py
@@ -0,0 +1,545 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from itertools import chain
+import time
+import operator
+import unittest
+import tempfile
+
+from pyspark.context import SparkConf, SparkContext, RDD
+from pyspark.streaming.context import StreamingContext
+
+
+class PySparkStreamingTestCase(unittest.TestCase):
+
+ timeout = 10 # seconds
+ duration = 1
+
+ def setUp(self):
+ class_name = self.__class__.__name__
+ conf = SparkConf().set("spark.default.parallelism", 1)
+ self.sc = SparkContext(appName=class_name, conf=conf)
+ self.sc.setCheckpointDir("/tmp")
+ # TODO: decrease duration to speed up tests
+ self.ssc = StreamingContext(self.sc, self.duration)
+
+ def tearDown(self):
+ self.ssc.stop()
+
+ def wait_for(self, result, n):
+ start_time = time.time()
+ while len(result) < n and time.time() - start_time < self.timeout:
+ time.sleep(0.01)
+ if len(result) < n:
+ print "timeout after", self.timeout
+
+ def _take(self, dstream, n):
+ """
+ Return the first `n` elements in the stream (will start and stop).
+ """
+ results = []
+
+ def take(_, rdd):
+ if rdd and len(results) < n:
+ results.extend(rdd.take(n - len(results)))
+
+ dstream.foreachRDD(take)
+
+ self.ssc.start()
+ self.wait_for(results, n)
+ return results
+
+ def _collect(self, dstream, n, block=True):
+ """
+ Collect each RDDs into the returned list.
+
+ :return: list, which will have the collected items.
+ """
+ result = []
+
+ def get_output(_, rdd):
+ if rdd and len(result) < n:
+ r = rdd.collect()
+ if r:
+ result.append(r)
+
+ dstream.foreachRDD(get_output)
+
+ if not block:
+ return result
+
+ self.ssc.start()
+ self.wait_for(result, n)
+ return result
+
+ def _test_func(self, input, func, expected, sort=False, input2=None):
+ """
+ @param input: dataset for the test. This should be list of lists.
+ @param func: wrapped function. This function should return PythonDStream object.
+ @param expected: expected output for this testcase.
+ """
+ if not isinstance(input[0], RDD):
+ input = [self.sc.parallelize(d, 1) for d in input]
+ input_stream = self.ssc.queueStream(input)
+ if input2 and not isinstance(input2[0], RDD):
+ input2 = [self.sc.parallelize(d, 1) for d in input2]
+ input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
+
+ # Apply test function to stream.
+ if input2:
+ stream = func(input_stream, input_stream2)
+ else:
+ stream = func(input_stream)
+
+ result = self._collect(stream, len(expected))
+ if sort:
+ self._sort_result_based_on_key(result)
+ self._sort_result_based_on_key(expected)
+ self.assertEqual(expected, result)
+
+ def _sort_result_based_on_key(self, outputs):
+ """Sort the list based on first value."""
+ for output in outputs:
+ output.sort(key=lambda x: x[0])
+
+
+class BasicOperationTests(PySparkStreamingTestCase):
+
+ def test_map(self):
+ """Basic operation test for DStream.map."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.map(str)
+ expected = map(lambda x: map(str, x), input)
+ self._test_func(input, func, expected)
+
+ def test_flatMap(self):
+ """Basic operation test for DStream.faltMap."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.flatMap(lambda x: (x, x * 2))
+ expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))),
+ input)
+ self._test_func(input, func, expected)
+
+ def test_filter(self):
+ """Basic operation test for DStream.filter."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.filter(lambda x: x % 2 == 0)
+ expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input)
+ self._test_func(input, func, expected)
+
+ def test_count(self):
+ """Basic operation test for DStream.count."""
+ input = [range(5), range(10), range(20)]
+
+ def func(dstream):
+ return dstream.count()
+ expected = map(lambda x: [len(x)], input)
+ self._test_func(input, func, expected)
+
+ def test_reduce(self):
+ """Basic operation test for DStream.reduce."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+
+ def func(dstream):
+ return dstream.reduce(operator.add)
+ expected = map(lambda x: [reduce(operator.add, x)], input)
+ self._test_func(input, func, expected)
+
+ def test_reduceByKey(self):
+ """Basic operation test for DStream.reduceByKey."""
+ input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)],
+ [("", 1), ("", 1), ("", 1), ("", 1)],
+ [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]]
+
+ def func(dstream):
+ return dstream.reduceByKey(operator.add)
+ expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_mapValues(self):
+ """Basic operation test for DStream.mapValues."""
+ input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
+ [("", 4), (1, 1), (2, 2), (3, 3)],
+ [(1, 1), (2, 1), (3, 1), (4, 1)]]
+
+ def func(dstream):
+ return dstream.mapValues(lambda x: x + 10)
+ expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)],
+ [("", 14), (1, 11), (2, 12), (3, 13)],
+ [(1, 11), (2, 11), (3, 11), (4, 11)]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_flatMapValues(self):
+ """Basic operation test for DStream.flatMapValues."""
+ input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
+ [("", 4), (1, 1), (2, 1), (3, 1)],
+ [(1, 1), (2, 1), (3, 1), (4, 1)]]
+
+ def func(dstream):
+ return dstream.flatMapValues(lambda x: (x, x + 10))
+ expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12),
+ ("c", 1), ("c", 11), ("d", 1), ("d", 11)],
+ [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
+ [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]]
+ self._test_func(input, func, expected)
+
+ def test_glom(self):
+ """Basic operation test for DStream.glom."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+ rdds = [self.sc.parallelize(r, 2) for r in input]
+
+ def func(dstream):
+ return dstream.glom()
+ expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
+ self._test_func(rdds, func, expected)
+
+ def test_mapPartitions(self):
+ """Basic operation test for DStream.mapPartitions."""
+ input = [range(1, 5), range(5, 9), range(9, 13)]
+ rdds = [self.sc.parallelize(r, 2) for r in input]
+
+ def func(dstream):
+ def f(iterator):
+ yield sum(iterator)
+ return dstream.mapPartitions(f)
+ expected = [[3, 7], [11, 15], [19, 23]]
+ self._test_func(rdds, func, expected)
+
+ def test_countByValue(self):
+ """Basic operation test for DStream.countByValue."""
+ input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]]
+
+ def func(dstream):
+ return dstream.countByValue()
+ expected = [[4], [4], [3]]
+ self._test_func(input, func, expected)
+
+ def test_groupByKey(self):
+ """Basic operation test for DStream.groupByKey."""
+ input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
+ [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
+
+ def func(dstream):
+ return dstream.groupByKey().mapValues(list)
+
+ expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])],
+ [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])],
+ [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_combineByKey(self):
+ """Basic operation test for DStream.combineByKey."""
+ input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
+ [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
+
+ def func(dstream):
+ def add(a, b):
+ return a + str(b)
+ return dstream.combineByKey(str, add, add)
+ expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")],
+ [(1, "111"), (2, "11"), (3, "1")],
+ [("a", "11"), ("b", "1"), ("", "111")]]
+ self._test_func(input, func, expected, sort=True)
+
+ def test_repartition(self):
+ input = [range(1, 5), range(5, 9)]
+ rdds = [self.sc.parallelize(r, 2) for r in input]
+
+ def func(dstream):
+ return dstream.repartition(1).glom()
+ expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]
+ self._test_func(rdds, func, expected)
+
+ def test_union(self):
+ input1 = [range(3), range(5), range(6)]
+ input2 = [range(3, 6), range(5, 6)]
+
+ def func(d1, d2):
+ return d1.union(d2)
+
+ expected = [range(6), range(6), range(6)]
+ self._test_func(input1, func, expected, input2=input2)
+
+ def test_cogroup(self):
+ input = [[(1, 1), (2, 1), (3, 1)],
+ [(1, 1), (1, 1), (1, 1), (2, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]]
+ input2 = [[(1, 2)],
+ [(4, 1)],
+ [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]]
+
+ def func(d1, d2):
+ return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs)))
+
+ expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))],
+ [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))],
+ [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]]
+ self._test_func(input, func, expected, sort=True, input2=input2)
+
+ def test_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.join(b)
+
+ expected = [[('b', (2, 3))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_left_outer_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.leftOuterJoin(b)
+
+ expected = [[('a', (1, None)), ('b', (2, 3))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_right_outer_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.rightOuterJoin(b)
+
+ expected = [[('b', (2, 3)), ('c', (None, 4))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_full_outer_join(self):
+ input = [[('a', 1), ('b', 2)]]
+ input2 = [[('b', 3), ('c', 4)]]
+
+ def func(a, b):
+ return a.fullOuterJoin(b)
+
+ expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
+ self._test_func(input, func, expected, True, input2)
+
+ def test_update_state_by_key(self):
+
+ def updater(vs, s):
+ if not s:
+ s = []
+ s.extend(vs)
+ return s
+
+ input = [[('k', i)] for i in range(5)]
+
+ def func(dstream):
+ return dstream.updateStateByKey(updater)
+
+ expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+ expected = [[('k', v)] for v in expected]
+ self._test_func(input, func, expected)
+
+
+class WindowFunctionTests(PySparkStreamingTestCase):
+
+ timeout = 20
+
+ def test_window(self):
+ input = [range(1), range(2), range(3), range(4), range(5)]
+
+ def func(dstream):
+ return dstream.window(3, 1).count()
+
+ expected = [[1], [3], [6], [9], [12], [9], [5]]
+ self._test_func(input, func, expected)
+
+ def test_count_by_window(self):
+ input = [range(1), range(2), range(3), range(4), range(5)]
+
+ def func(dstream):
+ return dstream.countByWindow(3, 1)
+
+ expected = [[1], [3], [6], [9], [12], [9], [5]]
+ self._test_func(input, func, expected)
+
+ def test_count_by_window_large(self):
+ input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+
+ def func(dstream):
+ return dstream.countByWindow(5, 1)
+
+ expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
+ self._test_func(input, func, expected)
+
+ def test_count_by_value_and_window(self):
+ input = [range(1), range(2), range(3), range(4), range(5), range(6)]
+
+ def func(dstream):
+ return dstream.countByValueAndWindow(5, 1)
+
+ expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
+ self._test_func(input, func, expected)
+
+ def test_group_by_key_and_window(self):
+ input = [[('a', i)] for i in range(5)]
+
+ def func(dstream):
+ return dstream.groupByKeyAndWindow(3, 1).mapValues(list)
+
+ expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
+ [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
+ self._test_func(input, func, expected)
+
+ def test_reduce_by_invalid_window(self):
+ input1 = [range(3), range(5), range(1), range(6)]
+ d1 = self.ssc.queueStream(input1)
+ self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1))
+ self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))
+
+
+class StreamingContextTests(PySparkStreamingTestCase):
+
+ duration = 0.1
+
+ def _add_input_stream(self):
+ inputs = map(lambda x: range(1, x), range(101))
+ stream = self.ssc.queueStream(inputs)
+ self._collect(stream, 1, block=False)
+
+ def test_stop_only_streaming_context(self):
+ self._add_input_stream()
+ self.ssc.start()
+ self.ssc.stop(False)
+ self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)
+
+ def test_stop_multiple_times(self):
+ self._add_input_stream()
+ self.ssc.start()
+ self.ssc.stop()
+ self.ssc.stop()
+
+ def test_queue_stream(self):
+ input = [range(i + 1) for i in range(3)]
+ dstream = self.ssc.queueStream(input)
+ result = self._collect(dstream, 3)
+ self.assertEqual(input, result)
+
+ def test_text_file_stream(self):
+ d = tempfile.mkdtemp()
+ self.ssc = StreamingContext(self.sc, self.duration)
+ dstream2 = self.ssc.textFileStream(d).map(int)
+ result = self._collect(dstream2, 2, block=False)
+ self.ssc.start()
+ for name in ('a', 'b'):
+ time.sleep(1)
+ with open(os.path.join(d, name), "w") as f:
+ f.writelines(["%d\n" % i for i in range(10)])
+ self.wait_for(result, 2)
+ self.assertEqual([range(10), range(10)], result)
+
+ def test_union(self):
+ input = [range(i + 1) for i in range(3)]
+ dstream = self.ssc.queueStream(input)
+ dstream2 = self.ssc.queueStream(input)
+ dstream3 = self.ssc.union(dstream, dstream2)
+ result = self._collect(dstream3, 3)
+ expected = [i * 2 for i in input]
+ self.assertEqual(expected, result)
+
+ def test_transform(self):
+ dstream1 = self.ssc.queueStream([[1]])
+ dstream2 = self.ssc.queueStream([[2]])
+ dstream3 = self.ssc.queueStream([[3]])
+
+ def func(rdds):
+ rdd1, rdd2, rdd3 = rdds
+ return rdd2.union(rdd3).union(rdd1)
+
+ dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)
+
+ self.assertEqual([2, 3, 1], self._take(dstream, 3))
+
+
+class CheckpointTests(PySparkStreamingTestCase):
+
+ def setUp(self):
+ pass
+
+ def test_get_or_create(self):
+ inputd = tempfile.mkdtemp()
+ outputd = tempfile.mkdtemp() + "/"
+
+ def updater(vs, s):
+ return sum(vs, s or 0)
+
+ def setup():
+ conf = SparkConf().set("spark.default.parallelism", 1)
+ sc = SparkContext(conf=conf)
+ ssc = StreamingContext(sc, 0.5)
+ dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
+ wc = dstream.updateStateByKey(updater)
+ wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
+ wc.checkpoint(.5)
+ return ssc
+
+ cpd = tempfile.mkdtemp("test_streaming_cps")
+ self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
+ ssc.start()
+
+ def check_output(n):
+ while not os.listdir(outputd):
+ time.sleep(0.1)
+ time.sleep(1) # make sure mtime is larger than the previous one
+ with open(os.path.join(inputd, str(n)), 'w') as f:
+ f.writelines(["%d\n" % i for i in range(10)])
+
+ while True:
+ p = os.path.join(outputd, max(os.listdir(outputd)))
+ if '_SUCCESS' not in os.listdir(p):
+ # not finished
+ time.sleep(0.01)
+ continue
+ ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
+ d = ordd.values().map(int).collect()
+ if not d:
+ time.sleep(0.01)
+ continue
+ self.assertEqual(10, len(d))
+ s = set(d)
+ self.assertEqual(1, len(s))
+ m = s.pop()
+ if n > m:
+ continue
+ self.assertEqual(n, m)
+ break
+
+ check_output(1)
+ check_output(2)
+ ssc.stop(True, True)
+
+ time.sleep(1)
+ self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
+ ssc.start()
+ check_output(3)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
new file mode 100644
index 0000000000..86ee5aa04f
--- /dev/null
+++ b/python/pyspark/streaming/util.py
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from datetime import datetime
+import traceback
+
+from pyspark import SparkContext, RDD
+
+
+class TransformFunction(object):
+ """
+ This class wraps a function RDD[X] -> RDD[Y] that was passed to
+ DStream.transform(), allowing it to be called from Java via Py4J's
+ callback server.
+
+ Java calls this function with a sequence of JavaRDDs and this function
+ returns a single JavaRDD pointer back to Java.
+ """
+ _emptyRDD = None
+
+ def __init__(self, ctx, func, *deserializers):
+ self.ctx = ctx
+ self.func = func
+ self.deserializers = deserializers
+
+ def call(self, milliseconds, jrdds):
+ try:
+ if self.ctx is None:
+ self.ctx = SparkContext._active_spark_context
+ if not self.ctx or not self.ctx._jsc:
+ # stopped
+ return
+
+ # extend deserializers with the first one
+ sers = self.deserializers
+ if len(sers) < len(jrdds):
+ sers += (sers[0],) * (len(jrdds) - len(sers))
+
+ rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
+ for jrdd, ser in zip(jrdds, sers)]
+ t = datetime.fromtimestamp(milliseconds / 1000.0)
+ r = self.func(t, *rdds)
+ if r:
+ return r._jrdd
+ except Exception:
+ traceback.print_exc()
+
+ def __repr__(self):
+ return "TransformFunction(%s)" % self.func
+
+ class Java:
+ implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction']
+
+
+class TransformFunctionSerializer(object):
+ """
+ This class implements a serializer for PythonTransformFunction Java
+ objects.
+
+ This is necessary because the Java PythonTransformFunction objects are
+ actually Py4J references to Python objects and thus are not directly
+ serializable. When Java needs to serialize a PythonTransformFunction,
+ it uses this class to invoke Python, which returns the serialized function
+ as a byte array.
+ """
+ def __init__(self, ctx, serializer, gateway=None):
+ self.ctx = ctx
+ self.serializer = serializer
+ self.gateway = gateway or self.ctx._gateway
+ self.gateway.jvm.PythonDStream.registerSerializer(self)
+
+ def dumps(self, id):
+ try:
+ func = self.gateway.gateway_property.pool[id]
+ return bytearray(self.serializer.dumps((func.func, func.deserializers)))
+ except Exception:
+ traceback.print_exc()
+
+ def loads(self, bytes):
+ try:
+ f, deserializers = self.serializer.loads(str(bytes))
+ return TransformFunction(self.ctx, f, *deserializers)
+ except Exception:
+ traceback.print_exc()
+
+ def __repr__(self):
+ return "TransformFunctionSerializer(%s)" % self.serializer
+
+ class Java:
+ implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer']
+
+
+def rddToFileName(prefix, suffix, timestamp):
+ """
+ Return string prefix-time(.suffix)
+
+ >>> rddToFileName("spark", None, 12345678910)
+ 'spark-12345678910'
+ >>> rddToFileName("spark", "tmp", 12345678910)
+ 'spark-12345678910.tmp'
+ """
+ if isinstance(timestamp, datetime):
+ seconds = time.mktime(timestamp.timetuple())
+ timestamp = long(seconds * 1000) + timestamp.microsecond / 1000
+ if suffix is None:
+ return prefix + "-" + str(timestamp)
+ else:
+ return prefix + "-" + str(timestamp) + "." + suffix
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
diff --git a/python/run-tests b/python/run-tests
index f6a9684117..2f98443c30 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -81,6 +81,11 @@ function run_mllib_tests() {
run_test "pyspark/mllib/tests.py"
}
+function run_streaming_tests() {
+ run_test "pyspark/streaming/util.py"
+ run_test "pyspark/streaming/tests.py"
+}
+
echo "Running PySpark tests. Output is in python/unit-tests.log."
export PYSPARK_PYTHON="python"
@@ -96,6 +101,7 @@ $PYSPARK_PYTHON --version
run_core_tests
run_sql_tests
run_mllib_tests
+run_streaming_tests
# Try to test with PyPy
if [ $(which pypy) ]; then
@@ -105,6 +111,7 @@ if [ $(which pypy) ]; then
run_core_tests
run_sql_tests
+ run_streaming_tests
fi
if [[ $FAILED == 0 ]]; then
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index a6184de4e8..2a7004e56e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -167,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2])
}
- /**
+ /**
* Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
* of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
* of the RDD.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
new file mode 100644
index 0000000000..213dff6a76
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.python
+
+import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.lang.reflect.Proxy
+import java.util.{ArrayList => JArrayList, List => JList}
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import py4j.GatewayServer
+
+import org.apache.spark.api.java._
+import org.apache.spark.api.python._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Interval, Duration, Time}
+import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.api.java._
+
+
+/**
+ * Interface for Python callback function which is used to transform RDDs
+ */
+private[python] trait PythonTransformFunction {
+ def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
+}
+
+/**
+ * Interface for Python Serializer to serialize PythonTransformFunction
+ */
+private[python] trait PythonTransformFunctionSerializer {
+ def dumps(id: String): Array[Byte]
+ def loads(bytes: Array[Byte]): PythonTransformFunction
+}
+
+/**
+ * Wraps a PythonTransformFunction (which is a Python object accessed through Py4J)
+ * so that it looks like a Scala function and can be transparently serialized and
+ * deserialized by Java.
+ */
+private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction)
+ extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] {
+
+ def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+ Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava))
+ .map(_.rdd)
+ }
+
+ def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+ val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava
+ Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd)
+ }
+
+ // for function.Function2
+ def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
+ pfunc.call(time.milliseconds, rdds)
+ }
+
+ private def writeObject(out: ObjectOutputStream): Unit = {
+ val bytes = PythonTransformFunctionSerializer.serialize(pfunc)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ }
+
+ private def readObject(in: ObjectInputStream): Unit = {
+ val length = in.readInt()
+ val bytes = new Array[Byte](length)
+ in.readFully(bytes)
+ pfunc = PythonTransformFunctionSerializer.deserialize(bytes)
+ }
+}
+
+/**
+ * Helpers for PythonTransformFunctionSerializer
+ *
+ * PythonTransformFunctionSerializer is logically a singleton that's happens to be
+ * implemented as a Python object.
+ */
+private[python] object PythonTransformFunctionSerializer {
+
+ /**
+ * A serializer in Python, used to serialize PythonTransformFunction
+ */
+ private var serializer: PythonTransformFunctionSerializer = _
+
+ /*
+ * Register a serializer from Python, should be called during initialization
+ */
+ def register(ser: PythonTransformFunctionSerializer): Unit = {
+ serializer = ser
+ }
+
+ def serialize(func: PythonTransformFunction): Array[Byte] = {
+ assert(serializer != null, "Serializer has not been registered!")
+ // get the id of PythonTransformFunction in py4j
+ val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
+ val f = h.getClass().getDeclaredField("id")
+ f.setAccessible(true)
+ val id = f.get(h).asInstanceOf[String]
+ serializer.dumps(id)
+ }
+
+ def deserialize(bytes: Array[Byte]): PythonTransformFunction = {
+ assert(serializer != null, "Serializer has not been registered!")
+ serializer.loads(bytes)
+ }
+}
+
+/**
+ * Helper functions, which are called from Python via Py4J.
+ */
+private[python] object PythonDStream {
+
+ /**
+ * can not access PythonTransformFunctionSerializer.register() via Py4j
+ * Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM
+ */
+ def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = {
+ PythonTransformFunctionSerializer.register(ser)
+ }
+
+ /**
+ * Update the port of callback client to `port`
+ */
+ def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = {
+ val cl = gws.getCallbackClient
+ val f = cl.getClass.getDeclaredField("port")
+ f.setAccessible(true)
+ f.setInt(cl, port)
+ }
+
+ /**
+ * helper function for DStream.foreachRDD(),
+ * cannot be `foreachRDD`, it will confusing py4j
+ */
+ def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
+ val func = new TransformFunction((pfunc))
+ jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
+ }
+
+ /**
+ * convert list of RDD into queue of RDDs, for ssc.queueStream()
+ */
+ def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
+ val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
+ rdds.forall(queue.add(_))
+ queue
+ }
+}
+
+/**
+ * Base class for PythonDStream with some common methods
+ */
+private[python] abstract class PythonDStream(
+ parent: DStream[_],
+ @transient pfunc: PythonTransformFunction)
+ extends DStream[Array[Byte]] (parent.ssc) {
+
+ val func = new TransformFunction(pfunc)
+
+ override def dependencies = List(parent)
+
+ override def slideDuration: Duration = parent.slideDuration
+
+ val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+/**
+ * Transformed DStream in Python.
+ */
+private[python] class PythonTransformedDStream (
+ parent: DStream[_],
+ @transient pfunc: PythonTransformFunction)
+ extends PythonDStream(parent, pfunc) {
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val rdd = parent.getOrCompute(validTime)
+ if (rdd.isDefined) {
+ func(rdd, validTime)
+ } else {
+ None
+ }
+ }
+}
+
+/**
+ * Transformed from two DStreams in Python.
+ */
+private[python] class PythonTransformed2DStream(
+ parent: DStream[_],
+ parent2: DStream[_],
+ @transient pfunc: PythonTransformFunction)
+ extends DStream[Array[Byte]] (parent.ssc) {
+
+ val func = new TransformFunction(pfunc)
+
+ override def dependencies = List(parent, parent2)
+
+ override def slideDuration: Duration = parent.slideDuration
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val empty: RDD[_] = ssc.sparkContext.emptyRDD
+ val rdd1 = parent.getOrCompute(validTime).getOrElse(empty)
+ val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty)
+ func(Some(rdd1), Some(rdd2), validTime)
+ }
+
+ val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+/**
+ * similar to StateDStream
+ */
+private[python] class PythonStateDStream(
+ parent: DStream[Array[Byte]],
+ @transient reduceFunc: PythonTransformFunction)
+ extends PythonDStream(parent, reduceFunc) {
+
+ super.persist(StorageLevel.MEMORY_ONLY)
+ override val mustCheckpoint = true
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val lastState = getOrCompute(validTime - slideDuration)
+ val rdd = parent.getOrCompute(validTime)
+ if (rdd.isDefined) {
+ func(lastState, rdd, validTime)
+ } else {
+ lastState
+ }
+ }
+}
+
+/**
+ * similar to ReducedWindowedDStream
+ */
+private[python] class PythonReducedWindowedDStream(
+ parent: DStream[Array[Byte]],
+ @transient preduceFunc: PythonTransformFunction,
+ @transient pinvReduceFunc: PythonTransformFunction,
+ _windowDuration: Duration,
+ _slideDuration: Duration)
+ extends PythonDStream(parent, preduceFunc) {
+
+ super.persist(StorageLevel.MEMORY_ONLY)
+ override val mustCheckpoint = true
+
+ val invReduceFunc = new TransformFunction(pinvReduceFunc)
+
+ def windowDuration: Duration = _windowDuration
+ override def slideDuration: Duration = _slideDuration
+ override def parentRememberDuration: Duration = rememberDuration + windowDuration
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val currentTime = validTime
+ val current = new Interval(currentTime - windowDuration, currentTime)
+ val previous = current - slideDuration
+
+ // _____________________________
+ // | previous window _________|___________________
+ // |___________________| current window | --------------> Time
+ // |_____________________________|
+ //
+ // |________ _________| |________ _________|
+ // | |
+ // V V
+ // old RDDs new RDDs
+ //
+ val previousRDD = getOrCompute(previous.endTime)
+
+ // for small window, reduce once will be better than twice
+ if (pinvReduceFunc != null && previousRDD.isDefined
+ && windowDuration >= slideDuration * 5) {
+
+ // subtract the values from old RDDs
+ val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime)
+ val subtracted = if (oldRDDs.size > 0) {
+ invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime)
+ } else {
+ previousRDD
+ }
+
+ // add the RDDs of the reduced values in "new time steps"
+ val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime)
+ if (newRDDs.size > 0) {
+ func(subtracted, Some(ssc.sc.union(newRDDs)), validTime)
+ } else {
+ subtracted
+ }
+ } else {
+ // Get the RDDs of the reduced values in current window
+ val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime)
+ if (currentRDDs.size > 0) {
+ func(None, Some(ssc.sc.union(currentRDDs)), validTime)
+ } else {
+ None
+ }
+ }
+ }
+}