-rw-r--r--  examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java | 16
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala | 12
-rw-r--r--  extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java | 18
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/State.scala | 20
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala | 160
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala) | 20
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala | 50
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala) | 61
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala | 41
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala) | 99
-rw-r--r--  streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java (renamed from streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java) | 48
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala (renamed from streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala) | 112
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala (renamed from streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala) | 114
13 files changed, 389 insertions, 382 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index c400e4237a..14997c64d5 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -65,7 +65,7 @@ public class JavaStatefulNetworkWordCount {
JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
ssc.checkpoint(".");
- // Initial RDD input to trackStateByKey
+ // Initial state RDD input to mapWithState
@SuppressWarnings("unchecked")
List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
new Tuple2<String, Integer>("world", 1));
@@ -90,21 +90,21 @@ public class JavaStatefulNetworkWordCount {
});
// Update the cumulative count function
- final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>> trackStateFunc =
- new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>>() {
+ final Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>> mappingFunc =
+ new Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>>() {
@Override
- public Optional<Tuple2<String, Integer>> call(Time time, String word, Optional<Integer> one, State<Integer> state) {
+ public Tuple2<String, Integer> call(String word, Optional<Integer> one, State<Integer> state) {
int sum = one.or(0) + (state.exists() ? state.get() : 0);
Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
state.update(sum);
- return Optional.of(output);
+ return output;
}
};
- // This will give a Dstream made of state (which is the cumulative count of the words)
- JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
- wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));
+ // DStream made of cumulative counts that get updated in every batch
+ JavaMapWithStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
+ wordsDstream.mapWithState(StateSpec.function(mappingFunc).initialState(initialRDD));
stateDstream.print();
ssc.start();
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index a4f847f118..2dce1820d9 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -49,7 +49,7 @@ object StatefulNetworkWordCount {
val ssc = new StreamingContext(sparkConf, Seconds(1))
ssc.checkpoint(".")
- // Initial RDD input to trackStateByKey
+ // Initial state RDD for mapWithState operation
val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
// Create a ReceiverInputDStream on target ip:port and count the
@@ -58,17 +58,17 @@ object StatefulNetworkWordCount {
val words = lines.flatMap(_.split(" "))
val wordDstream = words.map(x => (x, 1))
- // Update the cumulative count using updateStateByKey
+ // Update the cumulative count using mapWithState
// This will give a DStream made of state (which is the cumulative count of the words)
- val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => {
+ val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
val output = (word, sum)
state.update(sum)
- Some(output)
+ output
}
- val stateDstream = wordDstream.trackStateByKey(
- StateSpec.function(trackStateFunc).initialState(initialRDD))
+ val stateDstream = wordDstream.mapWithState(
+ StateSpec.function(mappingFunc).initialState(initialRDD))
stateDstream.print()
ssc.start()
ssc.awaitTermination()
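For context, the renamed Scala API shown in the example above can be exercised end to end roughly as follows. This is a minimal sketch, not part of this patch; the master, host, port and app name are placeholder values, and a checkpoint directory is set as in the example.

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

    object MapWithStateWordCountSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("MapWithStateSketch")
        val ssc = new StreamingContext(conf, Seconds(1))
        ssc.checkpoint(".")  // state is kept across batches, so checkpointing is enabled

        // Seed state, mirroring the example's initialRDD
        val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

        val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" ")).map((_, 1))

        // Mapping function: update the running count for the key and return the mapped record
        val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
          val sum = one.getOrElse(0) + state.getOption().getOrElse(0)
          state.update(sum)
          (word, sum)
        }

        val stateDstream = words.mapWithState(
          StateSpec.function(mappingFunc).initialState(initialRDD))
        stateDstream.print()

        ssc.start()
        ssc.awaitTermination()
      }
    }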
diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
index 4eee97bc89..89e0c7fdf7 100644
--- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
+++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java
@@ -32,12 +32,10 @@ import org.apache.spark.Accumulator;
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.Function4;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;
/**
* Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8
@@ -863,12 +861,12 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
/**
* This test is only for testing the APIs. It's not necessary to run it.
*/
- public void testTrackStateByAPI() {
+ public void testMapWithStateAPI() {
JavaPairRDD<String, Boolean> initialRDD = null;
JavaPairDStream<String, Integer> wordsDstream = null;
- JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
- wordsDstream.trackStateByKey(
+ JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
+ wordsDstream.mapWithState(
StateSpec.<String, Integer, Boolean, Double> function((time, key, value, state) -> {
// Use all State's methods here
state.exists();
@@ -884,9 +882,9 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
- JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
- wordsDstream.trackStateByKey(
- StateSpec.<String, Integer, Boolean, Double>function((value, state) -> {
+ JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+ wordsDstream.mapWithState(
+ StateSpec.<String, Integer, Boolean, Double>function((key, value, state) -> {
state.exists();
state.get();
state.isTimingOut();
@@ -898,6 +896,6 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ
.partitioner(new HashPartitioner(10))
.timeout(Durations.seconds(10)));
- JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
+ JavaPairDStream<String, Boolean> mappedDStream = stateDstream2.stateSnapshots();
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
index 604e64fc61..b47bdda2c2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -23,14 +23,14 @@ import org.apache.spark.annotation.Experimental
/**
* :: Experimental ::
- * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of
- * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * Abstract class for getting and updating the state in mapping function used in the `mapWithState`
+ * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala)
+ * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
*
* Scala example of using `State`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
* // Check if state exists
* if (state.exists) {
* val existingState = state.get // Get the existing state
@@ -52,12 +52,12 @@ import org.apache.spark.annotation.Experimental
*
* Java example of using `State`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
- * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ * // A mapping function that maintains an integer state and returns a String
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
*
* @Override
- * public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ * public String call(String key, Optional<Integer> value, State<Integer> state) {
* if (state.exists()) {
* int existingState = state.get(); // Get the existing state
* boolean shouldRemove = ...; // Decide whether to remove the state
@@ -75,6 +75,8 @@ import org.apache.spark.annotation.Experimental
* }
* };
* }}}
+ *
+ * @tparam S Class of the state
*/
@Experimental
sealed abstract class State[S] {
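To illustrate the renamed terminology, here is a minimal sketch (illustrative, not taken from this patch) of a mapping function that exercises the State[S] methods referenced in the scaladoc above: getOption, update, remove and isTimingOut. The threshold and output strings are arbitrary.

    import org.apache.spark.streaming.State

    // Running total per key; drops the state once it exceeds a threshold, and emits
    // a final record when the key is being timed out.
    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
      if (state.isTimingOut()) {
        // Called one last time for an idle key; the state can no longer be updated here
        Some(s"$key timed out with state ${state.getOption().getOrElse(0)}")
      } else {
        val newSum = value.getOrElse(0) + state.getOption().getOrElse(0)
        if (newSum > 1000) {
          state.remove()          // discard the state for this key
          None
        } else {
          state.update(newSum)    // keep the running total
          Some(s"$key -> $newSum")
        }
      }
    }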
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index bea5b9df20..9f6f95223f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.streaming
import com.google.common.base.Optional
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4}
+import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.ClosureCleaner
import org.apache.spark.{HashPartitioner, Partitioner}
@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Partitioner}
/**
* :: Experimental ::
* Abstract class representing all the specifications of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
* Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or
@@ -37,50 +37,63 @@ import org.apache.spark.{HashPartitioner, Partitioner}
*
* Example in Scala:
* {{{
- * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- * ...
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
* }
*
- * val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
*
- * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
* }}}
*
* Example in Java:
* {{{
- * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- * .numPartition(10);
+ * // A mapping function that maintains an integer state and returns a string
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ * @Override
+ *    public String call(String key, Optional<Integer> value, State<Integer> state) {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
+ * }
+ * };
*
- * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *    keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
* }}}
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the state value
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped elements
*/
@Experimental
-sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
+sealed abstract class StateSpec[KeyType, ValueType, StateType, MappedType] extends Serializable {
- /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+ /** Set the RDD containing the initial states that will be used by `mapWithState` */
def initialState(rdd: RDD[(KeyType, StateType)]): this.type
- /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+ /** Set the RDD containing the initial states that will be used by `mapWithState` */
def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type
/**
- * Set the number of partitions by which the state RDDs generated by `trackStateByKey`
+ * Set the number of partitions by which the state RDDs generated by `mapWithState`
* will be partitioned. Hash partitioning will be used.
*/
def numPartitions(numPartitions: Int): this.type
/**
- * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be
+ * Set the partitioner by which the state RDDs generated by `mapWithState` will be
* be partitioned.
*/
def partitioner(partitioner: Partitioner): this.type
/**
* Set the duration after which the state of an idle key will be removed. A key and its state is
- * considered idle if it has not received any data for at least the given duration. The state
- * tracking function will be called one final time on the idle states that are going to be
+ * considered idle if it has not received any data for at least the given duration. The
+ * mapping function will be called one final time on the idle states that are going to be
* removed; [[org.apache.spark.streaming.State State.isTimingOut()]] set
* to `true` in that call.
*/
@@ -91,115 +104,124 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
/**
* :: Experimental ::
* Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
- * that is used for specifying the parameters of the DStream transformation `trackStateByKey`
* that is used for specifying the parameters of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
*
* Example in Scala:
* {{{
- * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- * ...
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
* }
*
- * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
- * StateSpec.function(trackingFunction).numPartitions(10))
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
+ *
+ * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
* }}}
*
* Example in Java:
* {{{
- * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- * .numPartition(10);
+ * // A mapping function that maintains an integer state and returns a string
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ * @Override
+ *    public String call(String key, Optional<Integer> value, State<Integer> state) {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
+ * }
+ * };
*
- * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
- * }}}
+ * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *    keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
+ * }}}
*/
@Experimental
object StateSpec {
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
*
- * @param trackingFunction The function applied on every data item to manage the associated state
- * and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated state
+ * and generate the mapped data
* @tparam KeyType Class of the keys
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
- ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- ClosureCleaner.clean(trackingFunction, checkSerializable = true)
- new StateSpecImpl(trackingFunction)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType]
+ ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ ClosureCleaner.clean(mappingFunction, checkSerializable = true)
+ new StateSpecImpl(mappingFunction)
}
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
*
- * @param trackingFunction The function applied on every data item to manage the associated state
- * and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated state
+ * and generate the mapped data
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
- ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType
+ ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ ClosureCleaner.clean(mappingFunction, checkSerializable = true)
val wrappedFunction =
- (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => {
- Some(trackingFunction(value, state))
+ (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => {
+ Some(mappingFunction(key, value, state))
}
new StateSpecImpl(wrappedFunction)
}
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all
- * the specifications of the `trackStateByKey` operation on a
+ * the specifications of the `mapWithState` operation on a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
*
- * @param javaTrackingFunction The function applied on every data item to manage the associated
- * state and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated
+ * state and generate the mapped data
* @tparam KeyType Class of the keys
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction:
- JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]):
- StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
- val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
+ def function[KeyType, ValueType, StateType, MappedType](mappingFunction:
+ JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[MappedType]]):
+ StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+ val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
Option(t.orNull)
}
- StateSpec.function(trackingFunc)
+ StateSpec.function(wrappedFunc)
}
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
*
- * @param javaTrackingFunction The function applied on every data item to manage the associated
- * state and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated
+ * state and generate the mapped data
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
- StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
- javaTrackingFunction.call(Optional.fromNullable(v.get), s)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]):
+ StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+ mappingFunction.call(k, Optional.fromNullable(v.get), s)
}
- StateSpec.function(trackingFunc)
+ StateSpec.function(wrappedFunc)
}
}
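The builder methods renamed above chain in the usual way; below is a sketch of putting them together (illustrative values, not from this patch).

    import org.apache.spark.rdd.RDD
    import org.apache.spark.streaming.{Seconds, State, StateSpec}

    def buildSpec(initial: RDD[(String, Int)]): StateSpec[String, Int, Int, (String, Int)] = {
      val mappingFunc = (key: String, value: Option[Int], state: State[Int]) => {
        val sum = value.getOrElse(0) + state.getOption().getOrElse(0)
        state.update(sum)
        (key, sum)
      }
      StateSpec.function(mappingFunc)
        .initialState(initial)    // states present before the first batch
        .numPartitions(10)        // alternatively: .partitioner(new HashPartitioner(10))
        .timeout(Seconds(30))     // idle keys get one final call with isTimingOut() == true
    }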
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
index f459930d06..16c0d6fff8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
@@ -19,23 +19,23 @@ package org.apache.spark.streaming.api.java
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.streaming.dstream.TrackStateDStream
+import org.apache.spark.streaming.dstream.MapWithStateDStream
/**
* :: Experimental ::
- * [[JavaDStream]] representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the
+ * DStream representing the stream of data generated by the `mapWithState` operation on a
+ * [[JavaPairDStream]]. Additionally, it also gives access to the
* stream of state snapshots, that is, the state data of all keys after a batch has updated them.
*
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
- * @tparam StateType Class of the state
- * @tparam EmittedType Class of the emitted records
+ * @tparam KeyType Class of the keys
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped data
*/
@Experimental
-class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
- dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
- extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
+class JavaMapWithStateDStream[KeyType, ValueType, StateType, MappedType] private[streaming](
+ dstream: MapWithStateDStream[KeyType, ValueType, StateType, MappedType])
+ extends JavaDStream[MappedType](dstream)(JavaSparkContext.fakeClassTag) {
def stateSnapshots(): JavaPairDStream[KeyType, StateType] =
new JavaPairDStream(dstream.stateSnapshots())(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 70e32b383e..42ddd63f0f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -430,42 +430,36 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
/**
* :: Experimental ::
- * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream
- * with a continuously updated per-key state. The user-provided state tracking function is
- * applied on each keyed data item along with its corresponding state. The function can choose to
- * update/remove the state and return a transformed data, which forms the
- * [[JavaTrackStateDStream]].
+ * Return a [[JavaMapWithStateDStream]] by applying a function to every key-value element of
+ * `this` stream, while maintaining some state data for each unique key. The mapping function
+ * and other specifications (e.g. partitioners, timeouts, initial state data, etc.) of this
+ * transformation can be specified using the [[StateSpec]] class. The state data is accessible
+ * as a parameter of type [[State]] in the mapping function.
*
- * The specifications of this transformation is made through the
- * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
- * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
- * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
- *
- * Example of using `trackStateByKey`:
+ * Example of using `mapWithState`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
- * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
- *
- * @Override
- * public Optional<String> call(Optional<Integer> one, State<Integer> state) {
- * // Check if state exists, accordingly update/remove state and return transformed data
- * }
+ * // A mapping function that maintains an integer state and returns a string
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ * @Override
+ *    public String call(String key, Optional<Integer> value, State<Integer> state) {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
+ * }
* };
*
- * JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream =
- * keyValueDStream.<Integer, String>trackStateByKey(
- * StateSpec.function(trackStateFunc).numPartitions(10));
- * }}}
+ * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *    keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
+ * }}}
*
* @param spec Specification of this transformation
- * @tparam StateType Class type of the state
- * @tparam EmittedType Class type of the tranformed data return by the tracking function
+ * @tparam StateType Class type of the state data
+ * @tparam MappedType Class type of the mapped data
*/
@Experimental
- def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]):
- JavaTrackStateDStream[K, V, StateType, EmittedType] = {
- new JavaTrackStateDStream(dstream.trackStateByKey(spec)(
+ def mapWithState[StateType, MappedType](spec: StateSpec[K, V, StateType, MappedType]):
+ JavaMapWithStateDStream[K, V, StateType, MappedType] = {
+ new JavaMapWithStateDStream(dstream.mapWithState(spec)(
JavaSparkContext.fakeClassTag,
JavaSparkContext.fakeClassTag))
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
index ea6213420e..706465d4e2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
@@ -24,53 +24,52 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
-import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
-import org.apache.spark.streaming.dstream.InternalTrackStateDStream._
+import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord}
+import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._
/**
* :: Experimental ::
- * DStream representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a
+ * DStream representing the stream of data generated by the `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
* Additionally, it also gives access to the stream of state snapshots, that is, the state data of
* all keys after a batch has updated them.
*
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
+ * @tparam KeyType Class of the key
+ * @tparam ValueType Class of the value
* @tparam StateType Class of the state data
- * @tparam EmittedType Class of the emitted records
+ * @tparam MappedType Class of the mapped data
*/
@Experimental
-sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag](
- ssc: StreamingContext) extends DStream[EmittedType](ssc) {
+sealed abstract class MapWithStateDStream[KeyType, ValueType, StateType, MappedType: ClassTag](
+ ssc: StreamingContext) extends DStream[MappedType](ssc) {
/** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
def stateSnapshots(): DStream[(KeyType, StateType)]
}
-/** Internal implementation of the [[TrackStateDStream]] */
-private[streaming] class TrackStateDStreamImpl[
- KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag](
+/** Internal implementation of the [[MapWithStateDStream]] */
+private[streaming] class MapWithStateDStreamImpl[
+ KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
dataStream: DStream[(KeyType, ValueType)],
- spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType])
- extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) {
+ spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
+ extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {
private val internalStream =
- new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec)
+ new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
override def slideDuration: Duration = internalStream.slideDuration
override def dependencies: List[DStream[_]] = List(internalStream)
- override def compute(validTime: Time): Option[RDD[EmittedType]] = {
- internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } }
+ override def compute(validTime: Time): Option[RDD[MappedType]] = {
+ internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
}
/**
* Forward the checkpoint interval to the internal DStream that computes the state maps. This
* to make sure that this DStream does not get checkpointed, only the internal stream.
*/
- override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = {
+ override def checkpoint(checkpointInterval: Duration): DStream[MappedType] = {
internalStream.checkpoint(checkpointInterval)
this
}
@@ -87,32 +86,32 @@ private[streaming] class TrackStateDStreamImpl[
def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass
- def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass
+ def mappedClass: Class[_] = implicitly[ClassTag[MappedType]].runtimeClass
}
/**
 * A DStream that allows per-key state to be maintained, and arbitrary records to be generated
- * based on updates to the state. This is the main DStream that implements the `trackStateByKey`
+ * based on updates to the state. This is the main DStream that implements the `mapWithState`
* operation on DStreams.
*
* @param parent Parent (key, value) stream that is the source
- * @param spec Specifications of the trackStateByKey operation
+ * @param spec Specifications of the mapWithState operation
* @tparam K Key type
* @tparam V Value type
* @tparam S Type of the state maintained
- * @tparam E Type of the emitted data
+ * @tparam E Type of the mapped data
*/
private[streaming]
-class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E])
- extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) {
+ extends DStream[MapWithStateRDDRecord[K, S, E]](parent.context) {
persist(StorageLevel.MEMORY_ONLY)
private val partitioner = spec.getPartitioner().getOrElse(
new HashPartitioner(ssc.sc.defaultParallelism))
- private val trackingFunction = spec.getFunction()
+ private val mappingFunction = spec.getFunction()
override def slideDuration: Duration = parent.slideDuration
@@ -130,7 +129,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
}
/** Method that generates a RDD for the given time */
- override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
+ override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
// Get the previous state or create a new empty state RDD
val prevStateRDD = getOrCompute(validTime - slideDuration) match {
case Some(rdd) =>
@@ -138,13 +137,13 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
// If the RDD is not partitioned the right way, let us repartition it using the
// partition index as the key. This is to ensure that state RDD is always partitioned
// before creating another state RDD using it
- TrackStateRDD.createFromRDD[K, V, S, E](
+ MapWithStateRDD.createFromRDD[K, V, S, E](
rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
} else {
rdd
}
case None =>
- TrackStateRDD.createFromPairRDD[K, V, S, E](
+ MapWithStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner,
validTime
@@ -161,11 +160,11 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
- Some(new TrackStateRDD(
- prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
+ Some(new MapWithStateRDD(
+ prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}
}
-private[streaming] object InternalTrackStateDStream {
+private[streaming] object InternalMapWithStateDStream {
private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10
}
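A usage sketch (illustrative, not from this patch) of the two outputs this DStream exposes, plus the checkpoint-interval forwarding described above:

    import org.apache.spark.streaming.{Seconds, State, StateSpec}
    import org.apache.spark.streaming.dstream.{DStream, MapWithStateDStream}

    def wire(wordCounts: DStream[(String, Int)]): Unit = {
      val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
        val sum = one.getOrElse(0) + state.getOption().getOrElse(0)
        state.update(sum)
        (word, sum)
      }
      val mapped: MapWithStateDStream[String, Int, Int, (String, Int)] =
        wordCounts.mapWithState(StateSpec.function(mappingFunc))

      mapped.print()                  // mapped records produced by the current batch only
      mapped.stateSnapshots().print() // state of all keys after each batch
      mapped.checkpoint(Seconds(10))  // forwarded to the internal state-tracking DStream
    }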
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 2762309134..a64a1fe93f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -352,39 +352,36 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
/**
* :: Experimental ::
- * Return a new DStream of data generated by combining the key-value data in `this` stream
- * with a continuously updated per-key state. The user-provided state tracking function is
- * applied on each keyed data item along with its corresponding state. The function can choose to
- * update/remove the state and return a transformed data, which forms the
- * [[org.apache.spark.streaming.dstream.TrackStateDStream]].
+ * Return a [[MapWithStateDStream]] by applying a function to every key-value element of
+ * `this` stream, while maintaining some state data for each unique key. The mapping function
+ * and other specifications (e.g. partitioners, timeouts, initial state data, etc.) of this
+ * transformation can be specified using the [[StateSpec]] class. The state data is accessible
+ * as a parameter of type [[State]] in the mapping function.
*
- * The specifications of this transformation is made through the
- * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
- * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
- * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details.
- *
- * Example of using `trackStateByKey`:
+ * Example of using `mapWithState`:
* {{{
- * def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = {
- * // Check if state exists, accordingly update/remove state and return transformed data
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
* }
*
- * val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
*
- * val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec)
+ * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
* }}}
*
* @param spec Specification of this transformation
- * @tparam StateType Class type of the state
- * @tparam EmittedType Class type of the tranformed data return by the tracking function
+ * @tparam StateType Class type of the state data
+ * @tparam MappedType Class type of the mapped data
*/
@Experimental
- def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag](
- spec: StateSpec[K, V, StateType, EmittedType]
- ): TrackStateDStream[K, V, StateType, EmittedType] = {
- new TrackStateDStreamImpl[K, V, StateType, EmittedType](
+ def mapWithState[StateType: ClassTag, MappedType: ClassTag](
+ spec: StateSpec[K, V, StateType, MappedType]
+ ): MapWithStateDStream[K, V, StateType, MappedType] = {
+ new MapWithStateDStreamImpl[K, V, StateType, MappedType](
self,
- spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]]
+ spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
index 30aafcf146..ed95171f73 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
@@ -29,60 +29,60 @@ import org.apache.spark.util.Utils
import org.apache.spark._
/**
- * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
- * sequence of records returned by the tracking function of `trackStateByKey`.
+ * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a
+ * sequence of records returned by the mapping function of `mapWithState`.
*/
-private[streaming] case class TrackStateRDDRecord[K, S, E](
- var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
+private[streaming] case class MapWithStateRDDRecord[K, S, E](
+ var stateMap: StateMap[K, S], var mappedData: Seq[E])
-private[streaming] object TrackStateRDDRecord {
+private[streaming] object MapWithStateRDDRecord {
def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
- prevRecord: Option[TrackStateRDDRecord[K, S, E]],
+ prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
dataIterator: Iterator[(K, V)],
- updateFunction: (Time, K, Option[V], State[S]) => Option[E],
+ mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
batchTime: Time,
timeoutThresholdTime: Option[Long],
removeTimedoutData: Boolean
- ): TrackStateRDDRecord[K, S, E] = {
+ ): MapWithStateRDDRecord[K, S, E] = {
// Create a new state map by cloning the previous one (if it exists) or by creating an empty one
val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }
- val emittedRecords = new ArrayBuffer[E]
+ val mappedData = new ArrayBuffer[E]
val wrappedState = new StateImpl[S]()
- // Call the tracking function on each record in the data iterator, and accordingly
- // update the states touched, and collect the data returned by the tracking function
+ // Call the mapping function on each record in the data iterator, and accordingly
+ // update the states touched, and collect the data returned by the mapping function
dataIterator.foreach { case (key, value) =>
wrappedState.wrap(newStateMap.get(key))
- val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
+ val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
if (wrappedState.isRemoved) {
newStateMap.remove(key)
} else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
}
- emittedRecords ++= emittedRecord
+ mappedData ++= returned
}
- // Get the timed out state records, call the tracking function on each and collect the
+ // Get the timed out state records, call the mapping function on each and collect the
// data returned
if (removeTimedoutData && timeoutThresholdTime.isDefined) {
newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
wrappedState.wrapTiminoutState(state)
- val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
- emittedRecords ++= emittedRecord
+ val returned = mappingFunction(batchTime, key, None, wrappedState)
+ mappedData ++= returned
newStateMap.remove(key)
}
}
- TrackStateRDDRecord(newStateMap, emittedRecords)
+ MapWithStateRDDRecord(newStateMap, mappedData)
}
}
/**
- * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
+ * Partition of the [[MapWithStateRDD]], which depends on corresponding partitions of prev state
* RDD, and a partitioned keyed-data RDD
*/
-private[streaming] class TrackStateRDDPartition(
+private[streaming] class MapWithStateRDDPartition(
idx: Int,
@transient private var prevStateRDD: RDD[_],
@transient private var partitionedDataRDD: RDD[_]) extends Partition {
@@ -104,27 +104,28 @@ private[streaming] class TrackStateRDDPartition(
/**
- * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records.
- * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a
- * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking
- * function of `trackStateByKey`.
- * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created
+ * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data.
+ * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a
+ * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping
+ * function of `mapWithState`.
+ * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD
+ * will be created
 * @param partitionedDataRDD The partitioned data RDD which is used to update the previous StateMaps
* in the `prevStateRDD` to create `this` RDD
- * @param trackingFunction The function that will be used to update state and return new data
+ * @param mappingFunction The function that will be used to update state and return new data
* @param batchTime The time of the batch to which this RDD belongs to. Use to update
* @param timeoutThresholdTime The time to indicate which keys are timeout
*/
-private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
- private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
+private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+ private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]],
private var partitionedDataRDD: RDD[(K, V)],
- trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
+ mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
batchTime: Time,
timeoutThresholdTime: Option[Long]
- ) extends RDD[TrackStateRDDRecord[K, S, E]](
+ ) extends RDD[MapWithStateRDDRecord[K, S, E]](
partitionedDataRDD.sparkContext,
List(
- new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
+ new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD),
new OneToOneDependency(partitionedDataRDD))
) {
@@ -141,19 +142,19 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
}
override def compute(
- partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {
+ partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
- val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
+ val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
val prevStateRDDIterator = prevStateRDD.iterator(
stateRDDPartition.previousSessionRDDPartition, context)
val dataIterator = partitionedDataRDD.iterator(
stateRDDPartition.partitionedDataRDDPartition, context)
val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
- val newRecord = TrackStateRDDRecord.updateRecordWithData(
+ val newRecord = MapWithStateRDDRecord.updateRecordWithData(
prevRecord,
dataIterator,
- trackingFunction,
+ mappingFunction,
batchTime,
timeoutThresholdTime,
removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
@@ -163,7 +164,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
override protected def getPartitions: Array[Partition] = {
Array.tabulate(prevStateRDD.partitions.length) { i =>
- new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
+ new MapWithStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
}
override def clearDependencies(): Unit = {
@@ -177,52 +178,46 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
}
}
-private[streaming] object TrackStateRDD {
+private[streaming] object MapWithStateRDD {
def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
pairRDD: RDD[(K, S)],
partitioner: Partitioner,
- updateTime: Time): TrackStateRDD[K, V, S, E] = {
+ updateTime: Time): MapWithStateRDD[K, V, S, E] = {
- val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
+ val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
- Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+ Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
- new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+ new MapWithStateRDD[K, V, S, E](
+ stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
}
def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
rdd: RDD[(K, S, Long)],
partitioner: Partitioner,
- updateTime: Time): TrackStateRDD[K, V, S, E] = {
+ updateTime: Time): MapWithStateRDD[K, V, S, E] = {
val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
- val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
+ val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, (state, updateTime)) =>
stateMap.put(key, state, updateTime)
}
- Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+ Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
- new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
- }
-}
-
-private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
- parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) {
- override protected def getPartitions: Array[Partition] = parent.partitions
- override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
- parent.compute(partition, context).flatMap { _.emittedRecords }
+ new MapWithStateRDD[K, V, S, E](
+ stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
}
}
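To make the renamed record-update path easier to follow, the following is a heavily simplified, self-contained model of what `MapWithStateRDDRecord.updateRecordWithData` does per partition: clone the previous state, run the mapping function over the batch's records while collecting the mapped output, then make one final pass over timed-out keys. It substitutes a plain mutable Map and a hard-coded summing function for Spark's StateMap and State, so it is illustrative only.

    import scala.collection.mutable

    // Simplified stand-in for MapWithStateRDDRecord: per-key (state, lastUpdatedMs) plus mapped output
    case class SimpleRecord[K](stateMap: mutable.Map[K, (Int, Long)], mappedData: Seq[(K, Int)])

    def updateRecord[K](
        prev: SimpleRecord[K],
        batch: Iterator[(K, Int)],
        batchTimeMs: Long,
        timeoutThresholdMs: Option[Long]): SimpleRecord[K] = {

      val newStateMap = prev.stateMap.clone()                  // copy the previous state map
      val mappedData = mutable.ArrayBuffer.empty[(K, Int)]

      // Apply the "mapping function" (here: a running sum) to every record of the batch
      batch.foreach { case (key, value) =>
        val sum = newStateMap.get(key).map(_._1).getOrElse(0) + value
        newStateMap.put(key, (sum, batchTimeMs))
        mappedData += ((key, sum))
      }

      // One final emission for timed-out keys, then drop their state
      timeoutThresholdMs.foreach { threshold =>
        val timedOut = newStateMap.filter { case (_, (_, lastUpdated)) => lastUpdated < threshold }
        timedOut.foreach { case (key, (state, _)) =>
          mappedData += ((key, state))
          newStateMap.remove(key)
        }
      }
      SimpleRecord(newStateMap, mappedData)
    }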
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java
index 89d0bb7b61..bc4bc2eb42 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java
@@ -37,12 +37,12 @@ import org.junit.Test;
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.Function3;
import org.apache.spark.api.java.function.Function4;
import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaTrackStateDStream;
+import org.apache.spark.streaming.api.java.JavaMapWithStateDStream;
-public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable {
+public class JavaMapWithStateSuite extends LocalJavaStreamingContext implements Serializable {
/**
* This test is only for testing the APIs. It's not necessary to run it.
@@ -52,7 +52,7 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
JavaPairDStream<String, Integer> wordsDstream = null;
final Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>
- trackStateFunc =
+ mappingFunc =
new Function4<Time, String, Optional<Integer>, State<Boolean>, Optional<Double>>() {
@Override
@@ -68,21 +68,21 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
}
};
- JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream =
- wordsDstream.trackStateByKey(
- StateSpec.function(trackStateFunc)
+ JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream =
+ wordsDstream.mapWithState(
+ StateSpec.function(mappingFunc)
.initialState(initialRDD)
.numPartitions(10)
.partitioner(new HashPartitioner(10))
.timeout(Durations.seconds(10)));
- JavaPairDStream<String, Boolean> emittedRecords = stateDstream.stateSnapshots();
+ JavaPairDStream<String, Boolean> stateSnapshots = stateDstream.stateSnapshots();
- final Function2<Optional<Integer>, State<Boolean>, Double> trackStateFunc2 =
- new Function2<Optional<Integer>, State<Boolean>, Double>() {
+ final Function3<String, Optional<Integer>, State<Boolean>, Double> mappingFunc2 =
+ new Function3<String, Optional<Integer>, State<Boolean>, Double>() {
@Override
- public Double call(Optional<Integer> one, State<Boolean> state) {
+ public Double call(String key, Optional<Integer> one, State<Boolean> state) {
// Use all State's methods here
state.exists();
state.get();
@@ -93,15 +93,15 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
}
};
- JavaTrackStateDStream<String, Integer, Boolean, Double> stateDstream2 =
- wordsDstream.trackStateByKey(
- StateSpec.<String, Integer, Boolean, Double>function(trackStateFunc2)
+ JavaMapWithStateDStream<String, Integer, Boolean, Double> stateDstream2 =
+ wordsDstream.mapWithState(
+ StateSpec.<String, Integer, Boolean, Double>function(mappingFunc2)
.initialState(initialRDD)
.numPartitions(10)
.partitioner(new HashPartitioner(10))
.timeout(Durations.seconds(10)));
- JavaPairDStream<String, Boolean> emittedRecords2 = stateDstream2.stateSnapshots();
+ JavaPairDStream<String, Boolean> stateSnapshots2 = stateDstream2.stateSnapshots();
}
@Test
@@ -148,11 +148,11 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
new Tuple2<String, Integer>("c", 1))
);
- Function2<Optional<Integer>, State<Integer>, Integer> trackStateFunc =
- new Function2<Optional<Integer>, State<Integer>, Integer>() {
+ Function3<String, Optional<Integer>, State<Integer>, Integer> mappingFunc =
+ new Function3<String, Optional<Integer>, State<Integer>, Integer>() {
@Override
- public Integer call(Optional<Integer> value, State<Integer> state) throws Exception {
+ public Integer call(String key, Optional<Integer> value, State<Integer> state) throws Exception {
int sum = value.or(0) + (state.exists() ? state.get() : 0);
state.update(sum);
return sum;
@@ -160,29 +160,29 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
};
testOperation(
inputData,
- StateSpec.<String, Integer, Integer, Integer>function(trackStateFunc),
+ StateSpec.<String, Integer, Integer, Integer>function(mappingFunc),
outputData,
stateData);
}
private <K, S, T> void testOperation(
List<List<K>> input,
- StateSpec<K, Integer, S, T> trackStateSpec,
+ StateSpec<K, Integer, S, T> mapWithStateSpec,
List<Set<T>> expectedOutputs,
List<Set<Tuple2<K, S>>> expectedStateSnapshots) {
int numBatches = expectedOutputs.size();
JavaDStream<K> inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2);
- JavaTrackStateDStream<K, Integer, S, T> trackeStateStream =
+ JavaMapWithStateDStream<K, Integer, S, T> mapWithStateDStream =
JavaPairDStream.fromJavaDStream(inputStream.map(new Function<K, Tuple2<K, Integer>>() {
@Override
public Tuple2<K, Integer> call(K x) throws Exception {
return new Tuple2<K, Integer>(x, 1);
}
- })).trackStateByKey(trackStateSpec);
+ })).mapWithState(mapWithStateSpec);
final List<Set<T>> collectedOutputs =
Collections.synchronizedList(Lists.<Set<T>>newArrayList());
- trackeStateStream.foreachRDD(new Function<JavaRDD<T>, Void>() {
+ mapWithStateDStream.foreachRDD(new Function<JavaRDD<T>, Void>() {
@Override
public Void call(JavaRDD<T> rdd) throws Exception {
collectedOutputs.add(Sets.newHashSet(rdd.collect()));
@@ -191,7 +191,7 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen
});
final List<Set<Tuple2<K, S>>> collectedStateSnapshots =
Collections.synchronizedList(Lists.<Set<Tuple2<K, S>>>newArrayList());
- trackeStateStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() {
+ mapWithStateDStream.stateSnapshots().foreachRDD(new Function<JavaPairRDD<K, S>, Void>() {
@Override
public Void call(JavaPairRDD<K, S> rdd) throws Exception {
collectedStateSnapshots.add(Sets.newHashSet(rdd.collect()));
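For context, the recurring change in the Java suite above is the new mapping-function shape: the key is now passed to the function, so the simple variant moves from Function2<Optional<V>, State<S>, T> to Function3<K, Optional<V>, State<S>, T>. The following minimal Scala sketch shows the same signature change; it is not part of the patch and the names are illustrative.

import org.apache.spark.streaming.State

object MappingFunctionSignatures {
  // Old trackStateByKey "simple" function: only the new value and the state were passed in.
  val oldStyle: (Option[Int], State[Int]) => Int =
    (value, state) => value.getOrElse(0) + state.getOption.getOrElse(0)

  // New mapWithState "simple" function: the key is now the first argument.
  val newStyle: (String, Option[Int], State[Int]) => Int =
    (key, value, state) => value.getOrElse(0) + state.getOption.getOrElse(0)
}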
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
index 1fc320d31b..4b08085e09 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala
@@ -25,11 +25,11 @@ import scala.reflect.ClassTag
import org.scalatest.PrivateMethodTester._
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
+import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl}
import org.apache.spark.util.{ManualClock, Utils}
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-class TrackStateByKeySuite extends SparkFunSuite
+class MapWithStateSuite extends SparkFunSuite
with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {
private var sc: SparkContext = null
@@ -49,7 +49,7 @@ class TrackStateByKeySuite extends SparkFunSuite
}
override def beforeAll(): Unit = {
- val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite")
+ val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite")
conf.set("spark.streaming.clock", classOf[ManualClock].getName())
sc = new SparkContext(conf)
}
@@ -129,7 +129,7 @@ class TrackStateByKeySuite extends SparkFunSuite
testState(Some(3), shouldBeTimingOut = true)
}
- test("trackStateByKey - basic operations with simple API") {
+ test("mapWithState - basic operations with simple API") {
val inputData =
Seq(
Seq(),
@@ -164,17 +164,17 @@ class TrackStateByKeySuite extends SparkFunSuite
)
// state maintains running count, and updated count is returned
- val trackStateFunc = (value: Option[Int], state: State[Int]) => {
+ val mappingFunc = (key: String, value: Option[Int], state: State[Int]) => {
val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
state.update(sum)
sum
}
testOperation[String, Int, Int](
- inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+ inputData, StateSpec.function(mappingFunc), outputData, stateData)
}
- test("trackStateByKey - basic operations with advanced API") {
+ test("mapWithState - basic operations with advanced API") {
val inputData =
Seq(
Seq(),
@@ -209,65 +209,65 @@ class TrackStateByKeySuite extends SparkFunSuite
)
// state maintains running count, key string doubled and returned
- val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
+ val mappingFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
state.update(sum)
Some(key * 2)
}
- testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+ testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData)
}
- test("trackStateByKey - type inferencing and class tags") {
+ test("mapWithState - type inferencing and class tags") {
- // Simple track state function with value as Int, state as Double and emitted type as Double
- val simpleFunc = (value: Option[Int], state: State[Double]) => {
+ // Simple mapping function with value as Int, state as Double and mapped type as Long
+ val simpleFunc = (key: String, value: Option[Int], state: State[Double]) => {
0L
}
// Advanced track state function with key as String, value as Int, state as Double and
- // emitted type as Double
+ // mapped type as Long
val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => {
Some(0L)
}
- def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = {
- val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]]
+ def testTypes(dstream: MapWithStateDStream[_, _, _, _]): Unit = {
+ val dstreamImpl = dstream.asInstanceOf[MapWithStateDStreamImpl[_, _, _, _]]
assert(dstreamImpl.keyClass === classOf[String])
assert(dstreamImpl.valueClass === classOf[Int])
assert(dstreamImpl.stateClass === classOf[Double])
- assert(dstreamImpl.emittedClass === classOf[Long])
+ assert(dstreamImpl.mappedClass === classOf[Long])
}
val ssc = new StreamingContext(sc, batchDuration)
val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
- // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
- val simpleFunctionStateStream1 = inputStream.trackStateByKey(
+ // Defining StateSpec inline with mapWithState and simple function implicitly gets the types
+ val simpleFunctionStateStream1 = inputStream.mapWithState(
StateSpec.function(simpleFunc).numPartitions(1))
testTypes(simpleFunctionStateStream1)
// Separately defining StateSpec with simple function requires explicitly specifying types
val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc)
- val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec)
+ val simpleFunctionStateStream2 = inputStream.mapWithState(simpleFuncSpec)
testTypes(simpleFunctionStateStream2)
// Separately defining StateSpec with advanced function implicitly gets the types
val advFuncSpec1 = StateSpec.function(advancedFunc)
- val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1)
+ val advFunctionStateStream1 = inputStream.mapWithState(advFuncSpec1)
testTypes(advFunctionStateStream1)
- // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
- val advFunctionStateStream2 = inputStream.trackStateByKey(
+ // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types
+ val advFunctionStateStream2 = inputStream.mapWithState(
StateSpec.function(simpleFunc).numPartitions(1))
testTypes(advFunctionStateStream2)
- // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
+ // Defining StateSpec inline with mapWithState and advanced func implicitly gets the types
val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc)
- val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2)
+ val advFunctionStateStream3 = inputStream.mapWithState[Double, Long](advFuncSpec2)
testTypes(advFunctionStateStream3)
}
- test("trackStateByKey - states as emitted records") {
+ test("mapWithState - states as mapped data") {
val inputData =
Seq(
Seq(),
@@ -301,17 +301,17 @@ class TrackStateByKeySuite extends SparkFunSuite
Seq(("a", 5), ("b", 3), ("c", 1))
)
- val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+ val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
val output = (key, sum)
state.update(sum)
Some(output)
}
- testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+ testOperation(inputData, StateSpec.function(mappingFunc), outputData, stateData)
}
- test("trackStateByKey - initial states, with nothing emitted") {
+ test("mapWithState - initial states, with nothing returned as from mapping function") {
val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0))
@@ -339,18 +339,18 @@ class TrackStateByKeySuite extends SparkFunSuite
Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0))
)
- val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+ val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
val output = (key, sum)
state.update(sum)
None.asInstanceOf[Option[Int]]
}
- val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState))
- testOperation(inputData, trackStateSpec, outputData, stateData)
+ val mapWithStateSpec = StateSpec.function(mappingFunc).initialState(sc.makeRDD(initialState))
+ testOperation(inputData, mapWithStateSpec, outputData, stateData)
}
- test("trackStateByKey - state removing") {
+ test("mapWithState - state removing") {
val inputData =
Seq(
Seq(),
@@ -388,7 +388,7 @@ class TrackStateByKeySuite extends SparkFunSuite
Seq()
)
- val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+ val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
if (state.exists) {
state.remove()
Some(key)
@@ -399,10 +399,10 @@ class TrackStateByKeySuite extends SparkFunSuite
}
testOperation(
- inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData)
+ inputData, StateSpec.function(mappingFunc).numPartitions(1), outputData, stateData)
}
- test("trackStateByKey - state timing out") {
+ test("mapWithState - state timing out") {
val inputData =
Seq(
Seq("a", "b", "c"),
@@ -413,7 +413,7 @@ class TrackStateByKeySuite extends SparkFunSuite
Seq("a") // a will not time out
) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active
- val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+ val mappingFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
if (value.isDefined) {
state.update(1)
}
@@ -425,9 +425,9 @@ class TrackStateByKeySuite extends SparkFunSuite
}
val (collectedOutputs, collectedStateSnapshots) = getOperationOutput(
- inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20)
+ inputData, StateSpec.function(mappingFunc).timeout(Seconds(3)), 20)
- // b and c should be emitted once each, when they were marked as expired
+ // b and c should be returned once each, when they were marked as expired
assert(collectedOutputs.flatten.sorted === Seq("b", "c"))
// States for a, b, c should be defined at one point of time
@@ -439,8 +439,8 @@ class TrackStateByKeySuite extends SparkFunSuite
assert(collectedStateSnapshots.last.toSet === Set(("a", 1)))
}
- test("trackStateByKey - checkpoint durations") {
- val privateMethod = PrivateMethod[InternalTrackStateDStream[_, _, _, _]]('internalStream)
+ test("mapWithState - checkpoint durations") {
+ val privateMethod = PrivateMethod[InternalMapWithStateDStream[_, _, _, _]]('internalStream)
def testCheckpointDuration(
batchDuration: Duration,
@@ -451,18 +451,18 @@ class TrackStateByKeySuite extends SparkFunSuite
try {
val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1)
- val dummyFunc = (value: Option[Int], state: State[Int]) => 0
- val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc))
- val internalTrackStateStream = trackStateStream invokePrivate privateMethod()
+ val dummyFunc = (key: Int, value: Option[Int], state: State[Int]) => 0
+ val mapWithStateStream = inputStream.mapWithState(StateSpec.function(dummyFunc))
+ val internalMapWithStateStream = mapWithStateStream invokePrivate privateMethod()
explicitCheckpointDuration.foreach { d =>
- trackStateStream.checkpoint(d)
+ mapWithStateStream.checkpoint(d)
}
- trackStateStream.register()
+ mapWithStateStream.register()
ssc.checkpoint(checkpointDir.toString)
ssc.start() // should initialize all the checkpoint durations
- assert(trackStateStream.checkpointDuration === null)
- assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration)
+ assert(mapWithStateStream.checkpointDuration === null)
+ assert(internalMapWithStateStream.checkpointDuration === expectedCheckpointDuration)
} finally {
ssc.stop(stopSparkContext = false)
}
@@ -478,7 +478,7 @@ class TrackStateByKeySuite extends SparkFunSuite
}
- test("trackStateByKey - driver failure recovery") {
+ test("mapWithState - driver failure recovery") {
val inputData =
Seq(
Seq(),
@@ -505,16 +505,16 @@ class TrackStateByKeySuite extends SparkFunSuite
val checkpointDuration = batchDuration * (stateData.size / 2)
- val runningCount = (value: Option[Int], state: State[Int]) => {
+ val runningCount = (key: String, value: Option[Int], state: State[Int]) => {
state.update(state.getOption().getOrElse(0) + value.getOrElse(0))
state.get()
}
- val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey(
+ val mapWithStateStream = dstream.map { _ -> 1 }.mapWithState(
StateSpec.function(runningCount))
// Set interval to make sure there is one RDD checkpointing
- trackStateStream.checkpoint(checkpointDuration)
- trackStateStream.stateSnapshots()
+ mapWithStateStream.checkpoint(checkpointDuration)
+ mapWithStateStream.stateSnapshots()
}
testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2,
@@ -523,28 +523,28 @@ class TrackStateByKeySuite extends SparkFunSuite
private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
input: Seq[Seq[K]],
- trackStateSpec: StateSpec[K, Int, S, T],
+ mapWithStateSpec: StateSpec[K, Int, S, T],
expectedOutputs: Seq[Seq[T]],
expectedStateSnapshots: Seq[Seq[(K, S)]]
): Unit = {
require(expectedOutputs.size == expectedStateSnapshots.size)
val (collectedOutputs, collectedStateSnapshots) =
- getOperationOutput(input, trackStateSpec, expectedOutputs.size)
+ getOperationOutput(input, mapWithStateSpec, expectedOutputs.size)
assert(expectedOutputs, collectedOutputs, "outputs")
assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots")
}
private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag](
input: Seq[Seq[K]],
- trackStateSpec: StateSpec[K, Int, S, T],
+ mapWithStateSpec: StateSpec[K, Int, S, T],
numBatches: Int
): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
// Setup the stream computation
val ssc = new StreamingContext(sc, Seconds(1))
val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
- val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec)
+ val trackeStateStream = inputStream.map(x => (x, 1)).mapWithState(mapWithStateSpec)
val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
val outputStream = new TestOutputStream(trackeStateStream, collectedOutputs)
val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]]
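For readers skimming MapWithStateSuite above, this is roughly the end-user pattern those tests exercise, written as a standalone sketch rather than taken from the patch; `words` and `runningWordCounts` are illustrative names, and the owning StreamingContext must have a checkpoint directory set (ssc.checkpoint(...)) because mapWithState checkpoints its internal state stream.

import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.{Seconds, State, StateSpec}

object MapWithStateUsageSketch {
  // Running word count: the returned (word, sum) pairs are the mapped data,
  // state.update(sum) carries the count forward to the next batch.
  def runningWordCounts(words: DStream[String]): DStream[(String, Int)] = {
    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
      state.update(sum)
      (word, sum)
    }

    val stateStream = words.map(w => (w, 1)).mapWithState(
      StateSpec.function(mappingFunc)
        .numPartitions(2)          // optional knobs, the same ones the suite exercises
        .timeout(Seconds(10)))     // idle keys are dropped after 10s without data

    stateStream.stateSnapshots()   // latest (word, count) for every key still tracked
  }
}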
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
index 3b2d43f2ce..aa95bd33dd 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala
@@ -30,14 +30,14 @@ import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
import org.apache.spark.streaming.{State, Time}
import org.apache.spark.util.Utils
-class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
+class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
private var sc: SparkContext = null
private var checkpointDir: File = _
override def beforeAll(): Unit = {
sc = new SparkContext(
- new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite"))
+ new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite"))
checkpointDir = Utils.createTempDir()
sc.setCheckpointDir(checkpointDir.toString)
}
@@ -54,7 +54,7 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
test("creation from pair RDD") {
val data = Seq((1, "1"), (2, "2"), (3, "3"))
val partitioner = new HashPartitioner(10)
- val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int](
+ val rdd = MapWithStateRDD.createFromPairRDD[Int, Int, String, Int](
sc.parallelize(data), partitioner, Time(123))
assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123)}.toSet, Set.empty)
assert(rdd.partitions.size === partitioner.numPartitions)
@@ -62,7 +62,7 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
assert(rdd.partitioner === Some(partitioner))
}
- test("updating state and generating emitted data in TrackStateRecord") {
+ test("updating state and generating mapped data in MapWithStateRDDRecord") {
val initialTime = 1000L
val updatedTime = 2000L
@@ -71,7 +71,7 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
/**
* Assert that applying given data on a prior record generates correct updated record, with
- * correct state map and emitted data
+ * correct state map and mapped data
*/
def assertRecordUpdate(
initStates: Iterable[Int],
@@ -86,18 +86,18 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
functionCalled = false
- val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
+ val record = MapWithStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
val dataIterator = data.map { v => ("key", v) }.iterator
val removedStates = new ArrayBuffer[Int]
val timingOutStates = new ArrayBuffer[Int]
/**
- * Tracking function that updates/removes state based on instructions in the data, and
+ * Mapping function that updates/removes state based on instructions in the data, and
* return state (when instructed or when state is timing out).
*/
def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
functionCalled = true
- assert(t.milliseconds === updatedTime, "tracking func called with wrong time")
+ assert(t.milliseconds === updatedTime, "mapping func called with wrong time")
data match {
case Some("noop") =>
@@ -120,22 +120,22 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
}
}
- val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
+ val updatedRecord = MapWithStateRDDRecord.updateRecordWithData[String, String, Int, Int](
Some(record), dataIterator, testFunc,
Time(updatedTime), timeoutThreshold, removeTimedoutData)
val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
assert(updatedStateData.toSet === expectedStates.toSet,
- "states do not match after updating the TrackStateRecord")
+ "states do not match after updating the MapWithStateRDDRecord")
- assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
- "emitted data do not match after updating the TrackStateRecord")
+ assert(updatedRecord.mappedData.toSet === expectedOutput.toSet,
+ "mapped data do not match after updating the MapWithStateRDDRecord")
assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
- "match those that were expected to do so while updating the TrackStateRecord")
+ "match those that were expected to do so while updating the MapWithStateRDDRecord")
assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
- "match those that were expected to do so while updating the TrackStateRecord")
+ "match those that were expected to do so while updating the MapWithStateRDDRecord")
}
@@ -187,12 +187,12 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
}
- test("states generated by TrackStateRDD") {
+ test("states generated by MapWithStateRDD") {
val initStates = Seq(("k1", 0), ("k2", 0))
val initTime = 123
val initStateWthTime = initStates.map { x => (x._1, x._2, initTime) }.toSet
val partitioner = new HashPartitioner(2)
- val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int](
+ val initStateRDD = MapWithStateRDD.createFromPairRDD[String, Int, Int, Int](
sc.parallelize(initStates), partitioner, Time(initTime)).persist()
assertRDD(initStateRDD, initStateWthTime, Set.empty)
@@ -203,21 +203,21 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
* creates a new state RDD with expected states
*/
def testStateUpdates(
- testStateRDD: TrackStateRDD[String, Int, Int, Int],
+ testStateRDD: MapWithStateRDD[String, Int, Int, Int],
testData: Seq[(String, Int)],
- expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = {
+ expectedStates: Set[(String, Int, Int)]): MapWithStateRDD[String, Int, Int, Int] = {
- // Persist the test TrackStateRDD so that its not recomputed while doing the next operation.
- // This is to make sure that we only track which state keys are being touched in the next op.
+ // Persist the test MapWithStateRDD so that its not recomputed while doing the next operation.
+ // This is to make sure that we only track which state keys are being touched in the next op.
testStateRDD.persist().count()
// To track which keys are being touched
- TrackStateRDDSuite.touchedStateKeys.clear()
+ MapWithStateRDDSuite.touchedStateKeys.clear()
- val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => {
+ val mappingFunction = (time: Time, key: String, data: Option[Int], state: State[Int]) => {
// Track the key that has been touched
- TrackStateRDDSuite.touchedStateKeys += key
+ MapWithStateRDDSuite.touchedStateKeys += key
// If the data is 0, do not do anything with the state
// else if the data is 1, increment the state if it exists, or set new state to 0
@@ -236,12 +236,12 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
// Assert that the new state RDD has expected state data
val newStateRDD = assertOperation(
- testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty)
+ testStateRDD, newDataRDD, mappingFunction, updateTime, expectedStates, Set.empty)
// Assert that the function was called only for the keys present in the data
- assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size,
+ assert(MapWithStateRDDSuite.touchedStateKeys.size === testData.size,
"More number of keys are being touched than that is expected")
- assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
+ assert(MapWithStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
"Keys not in the data are being touched unexpectedly")
// Assert that the test RDD's data has not changed
@@ -289,19 +289,19 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
test("checkpointing") {
/**
- * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs -
- * the data RDD and the parent TrackStateRDD.
+ * This tests whether the MapWithStateRDD correctly truncates any references to its parent RDDs
+ * - the data RDD and the parent MapWithStateRDD.
*/
- def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]])
+ def rddCollectFunc(rdd: RDD[MapWithStateRDDRecord[Int, Int, Int]])
: Set[(List[(Int, Int, Long)], List[Int])] = {
- rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) }
+ rdd.map { record => (record.stateMap.getAll().toList, record.mappedData.toList) }
.collect.toSet
}
- /** Generate TrackStateRDD with data RDD having a long lineage */
+ /** Generate MapWithStateRDD with data RDD having a long lineage */
def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int])
- : TrackStateRDD[Int, Int, Int, Int] = {
- TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0))
+ : MapWithStateRDD[Int, Int, Int, Int] = {
+ MapWithStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0))
}
testRDD(
@@ -309,15 +309,15 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
testRDDPartitions(
makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
- /** Generate TrackStateRDD with parent state RDD having a long lineage */
+ /** Generate MapWithStateRDD with parent state RDD having a long lineage */
def makeStateRDDWithLongLineageParenttateRDD(
- longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = {
+ longLineageRDD: RDD[Int]): MapWithStateRDD[Int, Int, Int, Int] = {
- // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage
+ // Create a MapWithStateRDD that has a long lineage using the data RDD with a long lineage
val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD)
- // Create a new TrackStateRDD, with the lineage lineage TrackStateRDD as the parent
- new TrackStateRDD[Int, Int, Int, Int](
+ // Create a new MapWithStateRDD, with the long lineage MapWithStateRDD as the parent
+ new MapWithStateRDD[Int, Int, Int, Int](
stateRDDWithLongLineage,
stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner),
(time: Time, key: Int, value: Option[Int], state: State[Int]) => None,
@@ -333,25 +333,25 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
}
test("checkpointing empty state RDD") {
- val emptyStateRDD = TrackStateRDD.createFromPairRDD[Int, Int, Int, Int](
+ val emptyStateRDD = MapWithStateRDD.createFromPairRDD[Int, Int, Int, Int](
sc.emptyRDD[(Int, Int)], new HashPartitioner(10), Time(0))
emptyStateRDD.checkpoint()
assert(emptyStateRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
- val cpRDD = sc.checkpointFile[TrackStateRDDRecord[Int, Int, Int]](
+ val cpRDD = sc.checkpointFile[MapWithStateRDDRecord[Int, Int, Int]](
emptyStateRDD.getCheckpointFile.get)
assert(cpRDD.flatMap { _.stateMap.getAll() }.collect().isEmpty)
}
- /** Assert whether the `trackStateByKey` operation generates expected results */
+ /** Assert whether the `mapWithState` operation generates expected results */
private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
- testStateRDD: TrackStateRDD[K, V, S, T],
+ testStateRDD: MapWithStateRDD[K, V, S, T],
newDataRDD: RDD[(K, V)],
- trackStateFunc: (Time, K, Option[V], State[S]) => Option[T],
+ mappingFunction: (Time, K, Option[V], State[S]) => Option[T],
currentTime: Long,
expectedStates: Set[(K, S, Int)],
- expectedEmittedRecords: Set[T],
+ expectedMappedData: Set[T],
doFullScan: Boolean = false
- ): TrackStateRDD[K, V, S, T] = {
+ ): MapWithStateRDD[K, V, S, T] = {
val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) {
newDataRDD.partitionBy(testStateRDD.partitioner.get)
@@ -359,31 +359,31 @@ class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with Bef
newDataRDD
}
- val newStateRDD = new TrackStateRDD[K, V, S, T](
- testStateRDD, newDataRDD, trackStateFunc, Time(currentTime), None)
+ val newStateRDD = new MapWithStateRDD[K, V, S, T](
+ testStateRDD, newDataRDD, mappingFunction, Time(currentTime), None)
if (doFullScan) newStateRDD.setFullScan()
// Persist to make sure that it gets computed only once and we can track precisely how many
// state keys the computing touched
newStateRDD.persist().count()
- assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
+ assertRDD(newStateRDD, expectedStates, expectedMappedData)
newStateRDD
}
- /** Assert whether the [[TrackStateRDD]] has the expected state ad emitted records */
+ /** Assert whether the [[MapWithStateRDD]] has the expected state and mapped data */
private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
- trackStateRDD: TrackStateRDD[K, V, S, T],
+ stateRDD: MapWithStateRDD[K, V, S, T],
expectedStates: Set[(K, S, Int)],
- expectedEmittedRecords: Set[T]): Unit = {
- val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
- val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
+ expectedMappedData: Set[T]): Unit = {
+ val states = stateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
+ val mappedData = stateRDD.flatMap { _.mappedData }.collect().toSet
assert(states === expectedStates,
- "states after track state operation were not as expected")
- assert(emittedRecords === expectedEmittedRecords,
- "emitted records after track state operation were not as expected")
+ "states after mapWithState operation were not as expected")
+ assert(mappedData === expectedMappedData,
+ "mapped data after mapWithState operation were not as expected")
}
}
-object TrackStateRDDSuite {
+object MapWithStateRDDSuite {
private val touchedStateKeys = new ArrayBuffer[String]()
}
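The MapWithStateRDDSuite changes above revolve around MapWithStateRDDRecord.updateRecordWithData, which folds a batch of (key, value) pairs into a partition's state map. The sketch below mirrors what assertRecordUpdate drives, with illustrative values; these classes are private[streaming], so it only compiles alongside the streaming package, and it is not part of the patch.

package org.apache.spark.streaming.rdd  // MapWithStateRDDRecord is package-private to streaming

import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
import org.apache.spark.streaming.{State, Time}

object RecordUpdateSketch {
  def example(): Unit = {
    // Prior state: "key" -> 3, last updated at t = 1000
    val stateMap = new OpenHashMapBasedStateMap[String, Int]()
    stateMap.put("key", 3, 1000L)
    val record = MapWithStateRDDRecord[String, Int, Int](stateMap, Seq.empty)

    // Mapping function: bump the count and return it as the mapped data
    val mappingFunc = (t: Time, key: String, value: Option[String], state: State[Int]) => {
      state.update(state.getOption().getOrElse(0) + 1)
      Option(state.get())
    }

    val updated = MapWithStateRDDRecord.updateRecordWithData[String, String, Int, Int](
      Some(record), Iterator(("key", "new data")), mappingFunc,
      Time(2000L), None, false)  // no timeout threshold, no timed-out removal

    // Expected: updated.stateMap.getAll() yields ("key", 4, 2000); updated.mappedData is Seq(4)
  }
}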