aboutsummaryrefslogtreecommitdiff
path: root/streaming/src/main
diff options
context:
space:
mode:
authorShixiong Zhu <shixiong@databricks.com>2015-11-12 17:48:43 -0800
committerTathagata Das <tathagata.das1565@gmail.com>2015-11-12 17:48:43 -0800
commit0f1d00a905614bb5eebf260566dbcb831158d445 (patch)
tree5b46386a0c742cd035549fa26c08da296010e86d /streaming/src/main
parent41bbd2300472501d69ed46f0407d5ed7cbede4a8 (diff)
downloadspark-0f1d00a905614bb5eebf260566dbcb831158d445.tar.gz
spark-0f1d00a905614bb5eebf260566dbcb831158d445.tar.bz2
spark-0f1d00a905614bb5eebf260566dbcb831158d445.zip
[SPARK-11663][STREAMING] Add Java API for trackStateByKey
TODO - [x] Add Java API - [x] Add API tests - [x] Add a function test Author: Shixiong Zhu <shixiong@databricks.com> Closes #9636 from zsxwing/java-track.
Diffstat (limited to 'streaming/src/main')
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/State.scala25
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala84
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala46
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala44
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala1
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala6
7 files changed, 183 insertions, 27 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
index 7dd1b72f80..604e64fc61 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -50,9 +50,30 @@ import org.apache.spark.annotation.Experimental
*
* }}}
*
- * Java example:
+ * Java example of using `State`:
* {{{
- * TODO(@zsxwing)
+ * // A tracking function that maintains an integer state and return a String
+ * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
+ * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ *
+ * @Override
+ * public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ * if (state.exists()) {
+ * int existingState = state.get(); // Get the existing state
+ * boolean shouldRemove = ...; // Decide whether to remove the state
+ * if (shouldRemove) {
+ * state.remove(); // Remove the state
+ * } else {
+ * int newState = ...;
+ * state.update(newState); // Set the new state
+ * }
+ * } else {
+ * int initialState = ...; // Set the initial state
+ * state.update(initialState);
+ * }
+ * // return something
+ * }
+ * };
* }}}
*/
@Experimental
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index c9fe35e74c..bea5b9df20 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -17,15 +17,14 @@
package org.apache.spark.streaming
-import scala.reflect.ClassTag
-
+import com.google.common.base.Optional
import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
+import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.ClosureCleaner
import org.apache.spark.{HashPartitioner, Partitioner}
-
/**
* :: Experimental ::
* Abstract class representing all the specifications of the DStream transformation
@@ -49,12 +48,12 @@ import org.apache.spark.{HashPartitioner, Partitioner}
*
* Example in Java:
* {{{
- * StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
- * StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
+ * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
* .numPartition(10);
*
- * JavaDStream[EmittedDataType] emittedRecordDStream =
- * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
+ * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
+ * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
* }}}
*/
@Experimental
@@ -92,6 +91,7 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
/**
* :: Experimental ::
* Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
+ * that is used for specifying the parameters of the DStream transformation `trackStateByKey`
* that is used for specifying the parameters of the DStream transformation
* `trackStateByKey` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
@@ -103,28 +103,27 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
* ...
* }
*
- * val spec = StateSpec.function(trackingFunction).numPartitions(10)
- *
- * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
+ * StateSpec.function(trackingFunction).numPartitions(10))
* }}}
*
* Example in Java:
* {{{
- * StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
- * StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
+ * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
* .numPartition(10);
*
- * JavaDStream[EmittedDataType] emittedRecordDStream =
- * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
+ * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
+ * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
* }}}
*/
@Experimental
object StateSpec {
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * `trackStateByKey` operation on a
- * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * of the `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+ *
* @param trackingFunction The function applied on every data item to manage the associated state
* and generate the emitted data
* @tparam KeyType Class of the keys
@@ -141,9 +140,9 @@ object StateSpec {
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * `trackStateByKey` operation on a
- * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * of the `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+ *
* @param trackingFunction The function applied on every data item to manage the associated state
* and generate the emitted data
* @tparam ValueType Class of the values
@@ -160,6 +159,48 @@ object StateSpec {
}
new StateSpecImpl(wrappedFunction)
}
+
+ /**
+ * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all
+ * the specifications of the `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
+ *
+ * @param javaTrackingFunction The function applied on every data item to manage the associated
+ * state and generate the emitted data
+ * @tparam KeyType Class of the keys
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the states data
+ * @tparam EmittedType Class of the emitted data
+ */
+ def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction:
+ JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]):
+ StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+ val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+ val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
+ Option(t.orNull)
+ }
+ StateSpec.function(trackingFunc)
+ }
+
+ /**
+ * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
+ * of the `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
+ *
+ * @param javaTrackingFunction The function applied on every data item to manage the associated
+ * state and generate the emitted data
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the states data
+ * @tparam EmittedType Class of the emitted data
+ */
+ def function[KeyType, ValueType, StateType, EmittedType](
+ javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
+ StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+ val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
+ javaTrackingFunction.call(Optional.fromNullable(v.get), s)
+ }
+ StateSpec.function(trackingFunc)
+ }
}
@@ -184,7 +225,6 @@ case class StateSpecImpl[K, V, S, T](
this
}
-
override def numPartitions(numPartitions: Int): this.type = {
this.partitioner(new HashPartitioner(numPartitions))
this
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index e2aec6c2f6..70e32b383e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -28,8 +28,10 @@ import com.google.common.base.Optional
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.{JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+
import org.apache.spark.Partitioner
-import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils}
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
@@ -426,6 +428,48 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
)
}
+ /**
+ * :: Experimental ::
+ * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream
+ * with a continuously updated per-key state. The user-provided state tracking function is
+ * applied on each keyed data item along with its corresponding state. The function can choose to
+ * update/remove the state and return a transformed data, which forms the
+ * [[JavaTrackStateDStream]].
+ *
+ * The specifications of this transformation is made through the
+ * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
+ * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
+ * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
+ *
+ * Example of using `trackStateByKey`:
+ * {{{
+ * // A tracking function that maintains an integer state and return a String
+ * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
+ * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ *
+ * @Override
+ * public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ * // Check if state exists, accordingly update/remove state and return transformed data
+ * }
+ * };
+ *
+ * JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream =
+ * keyValueDStream.<Integer, String>trackStateByKey(
+ * StateSpec.function(trackStateFunc).numPartitions(10));
+ * }}}
+ *
+ * @param spec Specification of this transformation
+ * @tparam StateType Class type of the state
+ * @tparam EmittedType Class type of the tranformed data return by the tracking function
+ */
+ @Experimental
+ def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]):
+ JavaTrackStateDStream[K, V, StateType, EmittedType] = {
+ new JavaTrackStateDStream(dstream.trackStateByKey(spec)(
+ JavaSparkContext.fakeClassTag,
+ JavaSparkContext.fakeClassTag))
+ }
+
private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]):
(Seq[V], Option[S]) => Option[S] = {
val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
new file mode 100644
index 0000000000..f459930d06
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.java
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.streaming.dstream.TrackStateDStream
+
+/**
+ * :: Experimental ::
+ * [[JavaDStream]] representing the stream of records emitted by the tracking function in the
+ * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the
+ * stream of state snapshots, that is, the state data of all keys after a batch has updated them.
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the state value
+ * @tparam StateType Class of the state
+ * @tparam EmittedType Class of the emitted records
+ */
+@Experimental
+class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
+ dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
+ extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
+
+ def stateSnapshots(): JavaPairDStream[KeyType, StateType] =
+ new JavaPairDStream(dstream.stateSnapshots())(
+ JavaSparkContext.fakeClassTag,
+ JavaSparkContext.fakeClassTag)
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
index 58d89c93bc..98e881e6ae 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
@@ -35,6 +35,7 @@ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
* all keys after a batch has updated them.
*
* @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the state value
* @tparam StateType Class of the state data
* @tparam EmittedType Class of the emitted records
*/
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
index ed7cea26d0..fc51496be4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -70,12 +70,14 @@ private[streaming] class TrackStateRDDPartition(
* in the `prevStateRDD` to create `this` RDD
* @param trackingFunction The function that will be used to update state and return new data
* @param batchTime The time of the batch to which this RDD belongs to. Use to update
+ * @param timeoutThresholdTime The time to indicate which keys are timeout
*/
private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
private var partitionedDataRDD: RDD[(K, V)],
trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
- batchTime: Time, timeoutThresholdTime: Option[Long]
+ batchTime: Time,
+ timeoutThresholdTime: Option[Long]
) extends RDD[TrackStateRDDRecord[K, S, T]](
partitionedDataRDD.sparkContext,
List(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index ed622ef7bf..34287c3e00 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -267,7 +267,11 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
// Read the data of the delta
val deltaMapSize = inputStream.readInt()
- deltaMap = new OpenHashMap[K, StateInfo[S]]()
+ deltaMap = if (deltaMapSize != 0) {
+ new OpenHashMap[K, StateInfo[S]](deltaMapSize)
+ } else {
+ new OpenHashMap[K, StateInfo[S]](initialCapacity)
+ }
var deltaMapCount = 0
while (deltaMapCount < deltaMapSize) {
val key = inputStream.readObject().asInstanceOf[K]