Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala    2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala  316
2 files changed, 317 insertions(+), 1 deletion(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index a6184de4e8..2a7004e56e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -167,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2])
}
- /**
+ /**
* Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
* of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
* of the RDD.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
new file mode 100644
index 0000000000..213dff6a76
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.python
+
+import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.lang.reflect.Proxy
+import java.util.{ArrayList => JArrayList, List => JList}
+import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.language.existentials
+
+import py4j.GatewayServer
+
+import org.apache.spark.api.java._
+import org.apache.spark.api.python._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{Interval, Duration, Time}
+import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.api.java._
+
+
+/**
+ * Interface for Python callback function which is used to transform RDDs
+ */
+private[python] trait PythonTransformFunction {
+ def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
+}
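+
+/*
+ * For illustration only (a sketch, not used by PySpark itself): a JVM-side
+ * implementation of PythonTransformFunction that returns its first RDD unchanged.
+ * In practice the object implementing this trait is a Py4J proxy backed by a
+ * Python callable; the name below is hypothetical.
+ */
+private[python] object IdentityTransformFunction extends PythonTransformFunction {
+ def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] =
+ rdds.get(0).asInstanceOf[JavaRDD[Array[Byte]]]
+}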
+
+/**
+ * Interface for Python Serializer to serialize PythonTransformFunction
+ */
+private[python] trait PythonTransformFunctionSerializer {
+ def dumps(id: String): Array[Byte]
+ def loads(bytes: Array[Byte]): PythonTransformFunction
+}
+
+/**
+ * Wraps a PythonTransformFunction (which is a Python object accessed through Py4J)
+ * so that it looks like a Scala function and can be transparently serialized and
+ * deserialized by Java.
+ */
+private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction)
+ extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] {
+
+ def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+ Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava))
+ .map(_.rdd)
+ }
+
+ def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+ val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava
+ Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd)
+ }
+
+ // for function.Function2
+ def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
+ pfunc.call(time.milliseconds, rdds)
+ }
+
+ private def writeObject(out: ObjectOutputStream): Unit = {
+ val bytes = PythonTransformFunctionSerializer.serialize(pfunc)
+ out.writeInt(bytes.length)
+ out.write(bytes)
+ }
+
+ private def readObject(in: ObjectInputStream): Unit = {
+ val length = in.readInt()
+ val bytes = new Array[Byte](length)
+ in.readFully(bytes)
+ pfunc = PythonTransformFunctionSerializer.deserialize(bytes)
+ }
+}
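+
+/*
+ * A rough sketch (illustration only; the helper name is hypothetical) of how
+ * TransformFunction survives Java serialization: writeObject() above stores the
+ * Python function's serialized bytes, and readObject() restores the Py4J proxy,
+ * assuming a serializer has already been registered from Python.
+ */
+private[python] object TransformFunctionRoundTrip {
+ def roundTrip(func: TransformFunction): TransformFunction = {
+ val buffer = new java.io.ByteArrayOutputStream()
+ val out = new ObjectOutputStream(buffer)
+ out.writeObject(func) // triggers TransformFunction.writeObject()
+ out.close()
+ val in = new ObjectInputStream(new java.io.ByteArrayInputStream(buffer.toByteArray))
+ in.readObject().asInstanceOf[TransformFunction] // triggers readObject()
+ }
+}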
+
+/**
+ * Helpers for PythonTransformFunctionSerializer
+ *
+ * PythonTransformFunctionSerializer is logically a singleton that happens to be
+ * implemented as a Python object.
+ */
+private[python] object PythonTransformFunctionSerializer {
+
+ /**
+ * The serializer registered from Python, used to serialize PythonTransformFunction
+ */
+ private var serializer: PythonTransformFunctionSerializer = _
+
+ /**
+ * Register a serializer from Python; should be called during initialization
+ */
+ def register(ser: PythonTransformFunctionSerializer): Unit = {
+ serializer = ser
+ }
+
+ def serialize(func: PythonTransformFunction): Array[Byte] = {
+ assert(serializer != null, "Serializer has not been registered!")
+ // get the id of PythonTransformFunction in py4j
+ val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
+ val f = h.getClass().getDeclaredField("id")
+ f.setAccessible(true)
+ val id = f.get(h).asInstanceOf[String]
+ serializer.dumps(id)
+ }
+
+ def deserialize(bytes: Array[Byte]): PythonTransformFunction = {
+ assert(serializer != null, "Serializer has not been registered!")
+ serializer.loads(bytes)
+ }
+}
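+
+/*
+ * Illustration only: a minimal JVM-side stand-in for the Python serializer,
+ * sketching the expected dumps()/loads() contract. The real implementation lives
+ * in PySpark and pickles the function itself; the id-based lookup below is an
+ * assumption made purely for demonstration.
+ */
+private[python] class StubTransformFunctionSerializer(
+ lookup: String => PythonTransformFunction)
+ extends PythonTransformFunctionSerializer {
+
+ def dumps(id: String): Array[Byte] = id.getBytes("UTF-8")
+ def loads(bytes: Array[Byte]): PythonTransformFunction = lookup(new String(bytes, "UTF-8"))
+}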
+
+/**
+ * Helper functions, which are called from Python via Py4J.
+ */
+private[python] object PythonDStream {
+
+ /**
+ * PythonTransformFunctionSerializer.register() cannot be called directly via Py4J;
+ * doing so fails with "Py4JError: PythonTransformFunctionSerializerregister does not
+ * exist in the JVM". Python registers its serializer through this wrapper instead.
+ */
+ def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = {
+ PythonTransformFunctionSerializer.register(ser)
+ }
+
+ /**
+ * Update the port of the Py4J callback client to `port`
+ */
+ def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = {
+ val cl = gws.getCallbackClient
+ val f = cl.getClass.getDeclaredField("port")
+ f.setAccessible(true)
+ f.setInt(cl, port)
+ }
+
+ /**
+ * Helper function for DStream.foreachRDD().
+ * It cannot be named `foreachRDD`, since that name would confuse Py4J.
+ */
+ def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction): Unit = {
+ val func = new TransformFunction(pfunc)
+ jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
+ }
+
+ /**
+ * Convert a list of RDDs into a queue of RDDs, for use with ssc.queueStream()
+ */
+ def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
+ val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
+ rdds.foreach(queue.add(_))
+ queue
+ }
+}
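+
+/*
+ * Example usage of the helper above (a sketch; `rdd1` and `rdd2` are hypothetical
+ * JavaRDD[Array[Byte]] values):
+ *
+ *   val rdds = new JArrayList[JavaRDD[Array[Byte]]]()
+ *   rdds.add(rdd1)
+ *   rdds.add(rdd2)
+ *   val queue = PythonDStream.toRDDQueue(rdds) // FIFO, preserves list order
+ */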
+
+/**
+ * Base class for PythonDStream with some common methods
+ */
+private[python] abstract class PythonDStream(
+ parent: DStream[_],
+ @transient pfunc: PythonTransformFunction)
+ extends DStream[Array[Byte]](parent.ssc) {
+
+ val func = new TransformFunction(pfunc)
+
+ override def dependencies = List(parent)
+
+ override def slideDuration: Duration = parent.slideDuration
+
+ val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+/**
+ * A DStream whose RDDs are transformed from its parent DStream by a Python function.
+ */
+private[python] class PythonTransformedDStream(
+ parent: DStream[_],
+ @transient pfunc: PythonTransformFunction)
+ extends PythonDStream(parent, pfunc) {
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val rdd = parent.getOrCompute(validTime)
+ if (rdd.isDefined) {
+ func(rdd, validTime)
+ } else {
+ None
+ }
+ }
+}
+
+/**
+ * A DStream whose RDDs are computed from two parent DStreams by a Python function.
+ */
+private[python] class PythonTransformed2DStream(
+ parent: DStream[_],
+ parent2: DStream[_],
+ @transient pfunc: PythonTransformFunction)
+ extends DStream[Array[Byte]](parent.ssc) {
+
+ val func = new TransformFunction(pfunc)
+
+ override def dependencies = List(parent, parent2)
+
+ override def slideDuration: Duration = parent.slideDuration
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val empty: RDD[_] = ssc.sparkContext.emptyRDD
+ val rdd1 = parent.getOrCompute(validTime).getOrElse(empty)
+ val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty)
+ func(Some(rdd1), Some(rdd2), validTime)
+ }
+
+ val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
+/**
+ * Similar to StateDStream: merges the previous state with each new batch via the
+ * Python reduce function.
+ */
+private[python] class PythonStateDStream(
+ parent: DStream[Array[Byte]],
+ @transient reduceFunc: PythonTransformFunction)
+ extends PythonDStream(parent, reduceFunc) {
+
+ super.persist(StorageLevel.MEMORY_ONLY)
+ override val mustCheckpoint = true
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val lastState = getOrCompute(validTime - slideDuration)
+ val rdd = parent.getOrCompute(validTime)
+ if (rdd.isDefined) {
+ func(lastState, rdd, validTime)
+ } else {
+ lastState
+ }
+ }
+}
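+
+/*
+ * For example (hypothetical batches b1, b2, b3), the state above evolves as
+ *   s1 = reduceFunc(None, b1), s2 = reduceFunc(s1, b2), s3 = reduceFunc(s2, b3),
+ * and an empty batch simply carries the previous state forward. The lineage of
+ * the state RDD grows with every batch, which is why mustCheckpoint is set.
+ */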
+
+/**
+ * Similar to ReducedWindowedDStream: reduces over a sliding window, optionally using
+ * an inverse reduce function to update the previous window's result incrementally.
+ */
+private[python] class PythonReducedWindowedDStream(
+ parent: DStream[Array[Byte]],
+ @transient preduceFunc: PythonTransformFunction,
+ @transient pinvReduceFunc: PythonTransformFunction,
+ _windowDuration: Duration,
+ _slideDuration: Duration)
+ extends PythonDStream(parent, preduceFunc) {
+
+ super.persist(StorageLevel.MEMORY_ONLY)
+ override val mustCheckpoint = true
+
+ val invReduceFunc = new TransformFunction(pinvReduceFunc)
+
+ def windowDuration: Duration = _windowDuration
+ override def slideDuration: Duration = _slideDuration
+ override def parentRememberDuration: Duration = rememberDuration + windowDuration
+
+ override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+ val currentTime = validTime
+ val current = new Interval(currentTime - windowDuration, currentTime)
+ val previous = current - slideDuration
+
+ // _____________________________
+ // | previous window _________|___________________
+ // |___________________| current window | --------------> Time
+ // |_____________________________|
+ //
+ // |________ _________| |________ _________|
+ // | |
+ // V V
+ // old RDDs new RDDs
+ //
+ val previousRDD = getOrCompute(previous.endTime)
+
+ // For a small window, reducing the whole window once is cheaper than the incremental
+ // subtract-then-add, so the incremental path is only taken when an inverse function
+ // exists and the window is large relative to the slide interval.
+ if (pinvReduceFunc != null && previousRDD.isDefined
+ && windowDuration >= slideDuration * 5) {
+
+ // subtract the values from old RDDs
+ val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime)
+ val subtracted = if (oldRDDs.size > 0) {
+ invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime)
+ } else {
+ previousRDD
+ }
+
+ // add the RDDs of the reduced values in "new time steps"
+ val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime)
+ if (newRDDs.size > 0) {
+ func(subtracted, Some(ssc.sc.union(newRDDs)), validTime)
+ } else {
+ subtracted
+ }
+ } else {
+ // Get the RDDs of the reduced values in current window
+ val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime)
+ if (currentRDDs.size > 0) {
+ func(None, Some(ssc.sc.union(currentRDDs)), validTime)
+ } else {
+ None
+ }
+ }
+ }
+}
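+
+/*
+ * Worked example of the interval arithmetic above (hypothetical durations): with a
+ * parent batch interval of 10s, windowDuration = 60s and slideDuration = 10s, at
+ * validTime = t:
+ *   current  = [t - 60s, t]            previous = [t - 70s, t - 10s]
+ *   oldRDDs  = parent batches in (t - 70s, t - 60s]  (subtracted via invReduceFunc)
+ *   newRDDs  = parent batches in (t - 10s, t]        (added via func)
+ * Since 60s >= 5 * 10s, the incremental path applies; a 30s window with the same
+ * slide would fall back to recomputing the whole window.
+ */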