aboutsummaryrefslogtreecommitdiff
path: root/streaming
diff options
context:
space:
mode:
authorShixiong Zhu <shixiong@databricks.com>2015-11-12 17:48:43 -0800
committerTathagata Das <tathagata.das1565@gmail.com>2015-11-12 17:48:43 -0800
commit0f1d00a905614bb5eebf260566dbcb831158d445 (patch)
tree5b46386a0c742cd035549fa26c08da296010e86d /streaming
parent41bbd2300472501d69ed46f0407d5ed7cbede4a8 (diff)
downloadspark-0f1d00a905614bb5eebf260566dbcb831158d445.tar.gz
spark-0f1d00a905614bb5eebf260566dbcb831158d445.tar.bz2
spark-0f1d00a905614bb5eebf260566dbcb831158d445.zip
[SPARK-11663][STREAMING] Add Java API for trackStateByKey
TODO - [x] Add Java API - [x] Add API tests - [x] Add a function test Author: Shixiong Zhu <shixiong@databricks.com> Closes #9636 from zsxwing/java-track.
Diffstat (limited to 'streaming')
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/State.scala25
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala84
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala46
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala44
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala1
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala4
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala6
-rw-r--r--streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java210
8 files changed, 393 insertions, 27 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
index 7dd1b72f80..604e64fc61 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -50,9 +50,30 @@ import org.apache.spark.annotation.Experimental
*
* }}}
*
- * Java example:
+ * Java example of using `State`:
* {{{
- * TODO(@zsxwing)
+ * // A tracking function that maintains an integer state and returns a String
+ * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
+ * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ *
+ * @Override
+ * public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ * if (state.exists()) {
+ * int existingState = state.get(); // Get the existing state
+ * boolean shouldRemove = ...; // Decide whether to remove the state
+ * if (shouldRemove) {
+ * state.remove(); // Remove the state
+ * } else {
+ * int newState = ...;
+ * state.update(newState); // Set the new state
+ * }
+ * } else {
+ * int initialState = ...; // Set the initial state
+ * state.update(initialState);
+ * }
+ * // return something
+ * }
+ * };
* }}}
*/
@Experimental
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index c9fe35e74c..bea5b9df20 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -17,15 +17,14 @@
package org.apache.spark.streaming
-import scala.reflect.ClassTag
-
+import com.google.common.base.Optional
import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
+import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.ClosureCleaner
import org.apache.spark.{HashPartitioner, Partitioner}
-
/**
* :: Experimental ::
* Abstract class representing all the specifications of the DStream transformation
@@ -49,12 +48,12 @@ import org.apache.spark.{HashPartitioner, Partitioner}
*
* Example in Java:
* {{{
- * StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
- * StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
+ * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
* .numPartitions(10);
*
- * JavaDStream[EmittedDataType] emittedRecordDStream =
- * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
+ * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedDataType> emittedRecordDStream =
+ * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
* }}}
*/
@Experimental
@@ -92,6 +91,7 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
/**
* :: Experimental ::
* Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
* that is used for specifying the parameters of the DStream transformation
* `trackStateByKey` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
@@ -103,28 +103,27 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
* ...
* }
*
- * val spec = StateSpec.function(trackingFunction).numPartitions(10)
- *
- * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
+ * StateSpec.function(trackingFunction).numPartitions(10))
* }}}
*
* Example in Java:
* {{{
- * StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec =
- * StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction)
+ * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
* .numPartitions(10);
*
- * JavaDStream[EmittedDataType] emittedRecordDStream =
- * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec);
+ * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedDataType> emittedRecordDStream =
+ * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
* }}}
*/
@Experimental
object StateSpec {
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * `trackStateByKey` operation on a
- * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * of the `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+ *
* @param trackingFunction The function applied on every data item to manage the associated state
* and generate the emitted data
* @tparam KeyType Class of the keys
@@ -141,9 +140,9 @@ object StateSpec {
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * `trackStateByKey` operation on a
- * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * of the `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+ *
* @param trackingFunction The function applied on every data item to manage the associated state
* and generate the emitted data
* @tparam ValueType Class of the values
@@ -160,6 +159,48 @@ object StateSpec {
}
new StateSpecImpl(wrappedFunction)
}
+
+ /**
+  * Build a [[org.apache.spark.streaming.StateSpec StateSpec]] describing the
+  * `trackStateByKey` operation on a
+  * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] from a Java
+  * tracking function that receives the time and the key in addition to the value and state.
+  *
+  * The Java function is adapted to the Scala API: the Scala `Option` value is passed to it as
+  * a Guava `Optional`, and the `Optional` it returns is turned back into an `Option`
+  * (an absent result becomes `None`).
+  *
+  * @param javaTrackingFunction The function applied on every data item to manage the associated
+  *                             state and generate the emitted data
+  * @tparam KeyType Class of the keys
+  * @tparam ValueType Class of the values
+  * @tparam StateType Class of the state data
+  * @tparam EmittedType Class of the emitted data
+  */
+ def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction:
+     JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]):
+   StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+   val scalaTrackingFunc =
+     (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => {
+       val javaValue = JavaUtils.optionToOptional(value)
+       val emitted = javaTrackingFunction.call(time, key, javaValue, state)
+       // orNull gives null for an absent Optional, which Option(...) maps to None
+       Option(emitted.orNull)
+     }
+   StateSpec.function(scalaTrackingFunc)
+ }
+
+ /**
+  * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
+  * of the `trackStateByKey` operation on a
+  * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
+  *
+  * The value is handed to the Java function as a Guava `Optional`. It is absent whenever the
+  * tracking function is invoked without a value (`None`) -- presumably, for example, while a
+  * key's state is timing out -- or with a `null` value.
+  *
+  * @param javaTrackingFunction The function applied on every data item to manage the associated
+  *                             state and generate the emitted data
+  * @tparam KeyType Class of the keys
+  * @tparam ValueType Class of the values
+  * @tparam StateType Class of the state data
+  * @tparam EmittedType Class of the emitted data
+  */
+ def function[KeyType, ValueType, StateType, EmittedType](
+     javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
+   StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+   val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
+     // Never call v.get here: it throws NoSuchElementException when the value is None.
+     // Some(null) is still mapped to an absent Optional, as fromNullable did before.
+     val javaValue: Optional[ValueType] = v match {
+       case Some(value) => Optional.fromNullable(value)
+       case None => Optional.absent[ValueType]()
+     }
+     javaTrackingFunction.call(javaValue, s)
+   }
+   StateSpec.function(trackingFunc)
+ }
}
@@ -184,7 +225,6 @@ case class StateSpecImpl[K, V, S, T](
this
}
-
override def numPartitions(numPartitions: Int): this.type = {
this.partitioner(new HashPartitioner(numPartitions))
this
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index e2aec6c2f6..70e32b383e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -28,8 +28,10 @@ import com.google.common.base.Optional
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.{JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
+
import org.apache.spark.Partitioner
-import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils}
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2}
@@ -426,6 +428,48 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
)
}
+ /**
+  * :: Experimental ::
+  * Combine the key-value data of `this` stream with a per-key state that is continuously
+  * updated, and return the result as a [[JavaTrackStateDStream]]. The state tracking function
+  * supplied through `spec` is applied on each keyed data item along with its corresponding
+  * state. The function can choose to update/remove the state, and the data it returns forms
+  * the returned stream.
+  *
+  * The transformation is specified through the
+  * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
+  * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
+  * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
+  *
+  * Example of using `trackStateByKey`:
+  * {{{
+  *   // A tracking function that maintains an integer state and returns a String
+  *   Function2<Optional<Integer>, State<Integer>, String> trackStateFunc =
+  *       new Function2<Optional<Integer>, State<Integer>, String>() {
+  *
+  *         @Override
+  *         public String call(Optional<Integer> one, State<Integer> state) {
+  *           // Check if state exists, accordingly update/remove state and return transformed data
+  *         }
+  *       };
+  *
+  *   JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream =
+  *       keyValueDStream.trackStateByKey(
+  *           StateSpec.<Integer, Integer, Integer, String>function(trackStateFunc)
+  *               .numPartitions(10));
+  * }}}
+  *
+  * @param spec Specification of this transformation
+  * @tparam StateType Class type of the state
+  * @tparam EmittedType Class type of the transformed data returned by the tracking function
+  */
+ @Experimental
+ def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]):
+   JavaTrackStateDStream[K, V, StateType, EmittedType] = {
+   // The Scala API needs ClassTags for the state and emitted types; the Java API supplies
+   // the placeholder tags from JavaSparkContext instead.
+   val stateTag = JavaSparkContext.fakeClassTag[StateType]
+   val emittedTag = JavaSparkContext.fakeClassTag[EmittedType]
+   val scalaStream = dstream.trackStateByKey(spec)(stateTag, emittedTag)
+   new JavaTrackStateDStream(scalaStream)
+ }
+
+
private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]):
(Seq[V], Option[S]) => Option[S] = {
val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
new file mode 100644
index 0000000000..f459930d06
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.api.java
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.streaming.dstream.TrackStateDStream
+
+/**
+ * :: Experimental ::
+ * A [[JavaDStream]] of the records emitted by the tracking function of the `trackStateByKey`
+ * operation on a [[JavaPairDStream]]. Besides the emitted records, it exposes the stream of
+ * state snapshots, that is, the state data of all keys after a batch has updated them.
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the values in the pair stream the operation was applied on
+ * @tparam StateType Class of the state
+ * @tparam EmittedType Class of the emitted records
+ */
+@Experimental
+class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
+    dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
+  extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
+
+  /** Return the underlying stream of state snapshots wrapped as a [[JavaPairDStream]]. */
+  def stateSnapshots(): JavaPairDStream[KeyType, StateType] = {
+    val keyTag = JavaSparkContext.fakeClassTag[KeyType]
+    val stateTag = JavaSparkContext.fakeClassTag[StateType]
+    new JavaPairDStream(dstream.stateSnapshots())(keyTag, stateTag)
+  }
+}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
index 58d89c93bc..98e881e6ae 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
@@ -35,6 +35,7 @@ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
* all keys after a batch has updated them.
*
* @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the state value
* @tparam StateType Class of the state data
* @tparam EmittedType Class of the emitted records
*/
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
index ed7cea26d0..fc51496be4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -70,12 +70,14 @@ private[streaming] class TrackStateRDDPartition(
* in the `prevStateRDD` to create `this` RDD
* @param trackingFunction The function that will be used to update state and return new data
* @param batchTime The time of the batch to which this RDD belongs to. Use to update
+ * @param timeoutThresholdTime The time used to indicate which keys have timed out
*/
private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
private var partitionedDataRDD: RDD[(K, V)],
trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
- batchTime: Time, timeoutThresholdTime: Option[Long]
+ batchTime: Time,
+ timeoutThresholdTime: Option[Long]
) extends RDD[TrackStateRDDRecord[K, S, T]](
partitionedDataRDD.sparkContext,
List(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index ed622ef7bf..34287c3e00 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -267,7 +267,11 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
// Read the data of the delta
val deltaMapSize = inputStream.readInt()
- deltaMap = new OpenHashMap[K, StateInfo[S]]()
+ deltaMap = if (deltaMapSize != 0) {
+ new OpenHashMap[K, StateInfo[S]](deltaMapSize)
+ } else {
+ new OpenHashMap[K, StateInfo[S]](initialCapacity)
+ }
var deltaMapCount = 0
while (deltaMapCount < deltaMapSize) {
val key = inputStream.readObject().asInstanceOf[K]
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
new file mode 100644
index 0000000000..eac4cdd14a
--- /dev/null
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+import scala.Tuple2;
+
+import com.google.common.base.Optional;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.util.ManualClock;
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.Function4;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+
+public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable {
+
+  /**
+   * Compile-time check of the Java API only. It is not annotated with {@code @Test} and is not
+   * meant to be run: the input stream and the initial RDD are {@code null}.
+   */
+  public void testAPI() {
+    JavaPairRDD<String, Boolean> initialRDD = null;
+    JavaPairDStream<String, Integer> wordsDstream = null;
+
+    // Variant 1: tracking function that also receives the time and the key
+    final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>
+        timedTrackStateFunc =
+        new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() {
+
+          @Override
+          public Optional<Double> call(
+              Time time, String word, Optional<Integer> one, State<Boolean> state) {
+            // Use all State's methods here
+            state.exists();
+            state.get();
+            state.isTimingOut();
+            state.remove();
+            state.update(true);
+            return Optional.of(2.0);
+          }
+        };
+
+    StateSpec<String, Integer, Boolean, Double> timedSpec =
+        StateSpec.function(timedTrackStateFunc)
+            .initialState(initialRDD)
+            .numPartitions(10)
+            .partitioner(new HashPartitioner(10))
+            .timeout(Durations.seconds(10));
+    JavaTrackStateDStream<String, Integer, Boolean, Double> timedStateDstream =
+        wordsDstream.trackStateByKey(timedSpec);
+
+    JavaPairDStream<String, Boolean> timedSnapshots = timedStateDstream.stateSnapshots();
+
+    // Variant 2: tracking function that only receives the value and the state
+    final Function2<Optional<Integer>, State<Boolean>, Double> simpleTrackStateFunc =
+        new Function2<Optional<Integer>, State<Boolean>, Double>() {
+
+          @Override
+          public Double call(Optional<Integer> one, State<Boolean> state) {
+            // Use all State's methods here
+            state.exists();
+            state.get();
+            state.isTimingOut();
+            state.remove();
+            state.update(true);
+            return 2.0;
+          }
+        };
+
+    StateSpec<String, Integer, Boolean, Double> simpleSpec =
+        StateSpec.<String, Integer, Boolean, Double> function(simpleTrackStateFunc)
+            .initialState(initialRDD)
+            .numPartitions(10)
+            .partitioner(new HashPartitioner(10))
+            .timeout(Durations.seconds(10));
+    JavaTrackStateDStream<String, Integer, Boolean, Double> simpleStateDstream =
+        wordsDstream.trackStateByKey(simpleSpec);
+
+    JavaPairDStream<String, Boolean> simpleSnapshots = simpleStateDstream.stateSnapshots();
+  }
+
+  /**
+   * Runs a running-count tracking function over seven batches of words and checks, per batch,
+   * both the emitted counts and the snapshot of every key's state.
+   */
+  @Test
+  public void testBasicFunction() {
+    List<List<String>> inputData = Arrays.asList(
+        Collections.<String>emptyList(),
+        Arrays.asList("a"),
+        Arrays.asList("a", "b"),
+        Arrays.asList("a", "b", "c"),
+        Arrays.asList("a", "b"),
+        Arrays.asList("a"),
+        Collections.<String>emptyList()
+    );
+
+    // Counts emitted in each batch: one updated count per word occurrence in that batch
+    List<Set<Integer>> outputData = Arrays.asList(
+        Collections.<Integer>emptySet(),
+        Sets.newHashSet(1),
+        Sets.newHashSet(2, 1),
+        Sets.newHashSet(3, 2, 1),
+        Sets.newHashSet(4, 3),
+        Sets.newHashSet(5),
+        Collections.<Integer>emptySet()
+    );
+
+    // State of all keys after each batch; keys absent from a batch keep their last count
+    List<Set<Tuple2<String, Integer>>> stateData = Arrays.asList(
+        Collections.<Tuple2<String, Integer>>emptySet(),
+        Sets.newHashSet(count("a", 1)),
+        Sets.newHashSet(count("a", 2), count("b", 1)),
+        Sets.newHashSet(count("a", 3), count("b", 2), count("c", 1)),
+        Sets.newHashSet(count("a", 4), count("b", 3), count("c", 1)),
+        Sets.newHashSet(count("a", 5), count("b", 3), count("c", 1)),
+        Sets.newHashSet(count("a", 5), count("b", 3), count("c", 1))
+    );
+
+    Function2<Optional<Integer>, State<Integer>, Integer> runningCountFunc =
+        new Function2<Optional<Integer>, State<Integer>, Integer>() {
+
+          @Override
+          public Integer call(Optional<Integer> value, State<Integer> state) throws Exception {
+            int previous = state.exists() ? state.get() : 0;
+            int updated = value.or(0) + previous;
+            state.update(updated);
+            return updated;
+          }
+        };
+
+    testOperation(
+        inputData,
+        StateSpec.<String, Integer, Integer, Integer>function(runningCountFunc),
+        outputData,
+        stateData);
+  }
+
+  /** Builds the (word, count) pair used in the expected state snapshots. */
+  private static Tuple2<String, Integer> count(String word, int occurrences) {
+    return new Tuple2<String, Integer>(word, occurrences);
+  }
+
+  /**
+   * Maps every input element to an (element, 1) pair, applies {@code trackStateByKey} with the
+   * given spec, runs one batch per expected output and compares the collected emitted records
+   * and state snapshots with the expectations.
+   *
+   * @param input one list of keys per batch
+   * @param trackStateSpec specification passed to {@code trackStateByKey}
+   * @param expectedOutputs expected emitted records, one set per batch
+   * @param expectedStateSnapshots expected state snapshot, one set per batch
+   */
+  private <K, S, T> void testOperation(
+      List<List<K>> input,
+      StateSpec<K, Integer, S, T> trackStateSpec,
+      List<Set<T>> expectedOutputs,
+      List<Set<Tuple2<K, S>>> expectedStateSnapshots) {
+    int numBatches = expectedOutputs.size();
+
+    // NOTE(review): the last argument (2) is presumably the number of partitions -- confirm
+    JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2);
+    Function<K, Tuple2<K, Integer>> pairWithOne = new Function<K, Tuple2<K, Integer>>() {
+      @Override
+      public Tuple2<K, Integer> call(K x) throws Exception {
+        return new Tuple2<K, Integer>(x, 1);
+      }
+    };
+    JavaTrackStateDStream<K, Integer, S, T> trackStateStream =
+        JavaPairDStream.fromJavaDStream(inputStream.map(pairWithOne))
+            .trackStateByKey(trackStateSpec);
+
+    // Record the emitted data of every batch
+    final List<Set<T>> collectedOutputs =
+        Collections.synchronizedList(Lists.<Set<T>>newArrayList());
+    trackStateStream.foreachRDD(new Function<JavaRDD<T>, Void>() {
+      @Override
+      public Void call(JavaRDD<T> rdd) throws Exception {
+        collectedOutputs.add(Sets.newHashSet(rdd.collect()));
+        return null;
+      }
+    });
+
+    // Record the state snapshot of every batch
+    final List<Set<Tuple2<K, S>>> collectedStateSnapshots =
+        Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList());
+    trackStateStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() {
+      @Override
+      public Void call(JavaPairRDD<K, S> rdd) throws Exception {
+        collectedStateSnapshots.add(Sets.newHashSet(rdd.collect()));
+        return null;
+      }
+    });
+
+    BatchCounter batchCounter = new BatchCounter(ssc.ssc());
+    ssc.start();
+
+    // Move the manual clock just past numBatches batch durations
+    ManualClock clock = (ManualClock) ssc.ssc().scheduler().clock();
+    long advanceBy = ssc.ssc().progressListener().batchDuration() * numBatches + 1;
+    clock.advance(advanceBy);
+    // NOTE(review): 10000 is presumably a timeout in milliseconds -- confirm
+    batchCounter.waitUntilBatchesCompleted(numBatches, 10000);
+
+    Assert.assertEquals(expectedOutputs, collectedOutputs);
+    Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots);
+  }
+}