author     Tathagata Das <tathagata.das1565@gmail.com>    2015-12-09 20:47:15 -0800
committer  Shixiong Zhu <shixiong@databricks.com>    2015-12-09 20:47:15 -0800
commit     bd2cd4f53d1ca10f4896bd39b0e180d4929867a2 (patch)
tree       308b2c7239d67191f95bd5673ab98f916b40bd58 /streaming/src/main
parent     2166c2a75083c2262e071a652dd52b1a33348b6e (diff)
[SPARK-12244][SPARK-12245][STREAMING] Rename trackStateByKey to mapWithState and change tracking function signature
SPARK-12244: Based on feedback from early users and personal experience attempting to explain it, the name trackStateByKey had two problems. First, "trackState" is a completely new term that gives no intuition about what the operation does. Second, the resultant data stream of objects returned by the function is referred to in the docs as the "emitted" data, for lack of a better term. "mapWithState" makes sense because the API is like a mapping function of the form (Key, Value) => T, with State as an additional parameter; the resultant data stream is simply the "mapped data". So both problems are solved.

SPARK-12245: From initial experience, not having the key in the function makes it hard to return mapped records, as the whole information of the record is not there. The user is essentially restricted to doing something like mapValues() instead of map(). So the key is added as a parameter.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #10224 from tdas/rename.
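To make the signature change concrete, here is a minimal sketch, assuming a hypothetical DStream[(String, Int)] of per-key counts (stream, type, and function names are illustrative only, not part of the patch):

    import org.apache.spark.streaming.{State, StateSpec}

    // Before: the function received only the value and the state, so the key
    // was unavailable and only mapValues()-style output was practical.
    //   def trackingFunction(value: Option[Int], state: State[Int]): String = ...
    //   keyValueDStream.trackStateByKey(StateSpec.function(trackingFunction _))

    // After: the key is an explicit parameter, so the function can produce
    // full map()-style output that mentions the key.
    def mappingFunction(key: String, value: Option[Int], state: State[Int]): String = {
      val sum = value.getOrElse(0) + state.getOption().getOrElse(0)
      state.update(sum)          // maintain a per-key running sum
      s"$key -> $sum"            // the "mapped" record
    }

    // val mappedDStream = keyValueDStream.mapWithState(StateSpec.function(mappingFunction _))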
Diffstat (limited to 'streaming/src/main')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/State.scala      |  20
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala  | 160
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala)  |  20
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala  |  50
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala)  |  61
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala  |  41
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala)  |  99
7 files changed, 230 insertions(+), 221 deletions(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
index 604e64fc61..b47bdda2c2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -23,14 +23,14 @@ import org.apache.spark.annotation.Experimental
/**
* :: Experimental ::
- * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of
- * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * Abstract class for getting and updating the state in mapping function used in the `mapWithState`
+ * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala)
+ * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
*
* Scala example of using `State`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
* // Check if state exists
* if (state.exists) {
* val existingState = state.get // Get the existing state
@@ -52,12 +52,12 @@ import org.apache.spark.annotation.Experimental
*
* Java example of using `State`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
- * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ * // A mapping function that maintains an integer state and returns a String
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
*
* @Override
- * public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ * public String call(String key, Optional<Integer> value, State<Integer> state) {
* if (state.exists()) {
* int existingState = state.get(); // Get the existing state
* boolean shouldRemove = ...; // Decide whether to remove the state
@@ -75,6 +75,8 @@ import org.apache.spark.annotation.Experimental
* }
* };
* }}}
+ *
+ * @tparam S Class of the state
*/
@Experimental
sealed abstract class State[S] {
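Filling in the elided branches of the doc example above, a complete Scala mapping function using the State API might look like the sketch below (the removal condition and returned strings are hypothetical placeholders):

    // A mapping function that maintains an integer state and returns a String.
    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
      if (state.exists) {
        val existingState = state.get          // Get the existing state
        val shouldRemove = existingState > 100 // Hypothetical removal condition
        if (shouldRemove) {
          state.remove()                       // Remove the state
          Some(s"$key removed at $existingState")
        } else {
          val newState = existingState + value.getOrElse(0)
          state.update(newState)               // Set the new state
          Some(s"$key -> $newState")
        }
      } else {
        val initialState = value.getOrElse(0)
        state.update(initialState)             // Set the initial state
        Some(s"$key -> $initialState")
      }
    }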
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index bea5b9df20..9f6f95223f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.streaming
import com.google.common.base.Optional
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4}
+import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.ClosureCleaner
import org.apache.spark.{HashPartitioner, Partitioner}
@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Partitioner}
/**
* :: Experimental ::
* Abstract class representing all the specifications of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
* Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or
@@ -37,50 +37,63 @@ import org.apache.spark.{HashPartitioner, Partitioner}
*
* Example in Scala:
* {{{
- * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- * ...
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
* }
*
- * val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
*
- * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
* }}}
*
* Example in Java:
* {{{
- * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- * .numPartition(10);
+ * // A mapping function that maintains an integer state and returns a string
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ * @Override
+ *    public String call(String key, Optional<Integer> value, State<Integer> state) {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
+ * }
+ * };
*
- * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *    keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
* }}}
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the state value
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped elements
*/
@Experimental
-sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
+sealed abstract class StateSpec[KeyType, ValueType, StateType, MappedType] extends Serializable {
- /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+ /** Set the RDD containing the initial states that will be used by `mapWithState` */
def initialState(rdd: RDD[(KeyType, StateType)]): this.type
- /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+ /** Set the RDD containing the initial states that will be used by `mapWithState` */
def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type
/**
- * Set the number of partitions by which the state RDDs generated by `trackStateByKey`
+ * Set the number of partitions by which the state RDDs generated by `mapWithState`
* will be partitioned. Hash partitioning will be used.
*/
def numPartitions(numPartitions: Int): this.type
/**
- * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be
- * be partitioned.
+ * Set the partitioner by which the state RDDs generated by `mapWithState` will be
+ * partitioned.
*/
def partitioner(partitioner: Partitioner): this.type
/**
 * Set the duration after which the state of an idle key will be removed. A key and its state are
- * considered idle if it has not received any data for at least the given duration. The state
- * tracking function will be called one final time on the idle states that are going to be
+ * considered idle if it has not received any data for at least the given duration. The
+ * mapping function will be called one final time on the idle states that are going to be
 * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] will be set
 * to `true` in that call.
*/
@@ -91,115 +104,124 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
/**
* :: Experimental ::
* Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
- * that is used for specifying the parameters of the DStream transformation `trackStateByKey`
- * that is used for specifying the parameters of the DStream transformation
- * `trackStateByKey` operation of a
+ * that is used for specifying the parameters of the DStream transformation
+ * `mapWithState` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
*
* Example in Scala:
* {{{
- * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- * ...
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
* }
*
- * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
- * StateSpec.function(trackingFunction).numPartitions(10))
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
+ *
+ * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
* }}}
*
* Example in Java:
* {{{
- * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- * .numPartition(10);
+ * // A mapping function that maintains an integer state and returns a string
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ * @Override
+ *    public String call(String key, Optional<Integer> value, State<Integer> state) {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
+ * }
+ * };
*
- * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
- * }}}
+ * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *        keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
+ * }}}
*/
@Experimental
object StateSpec {
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
*
- * @param trackingFunction The function applied on every data item to manage the associated state
- * and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated state
+ * and generate the mapped data
* @tparam KeyType Class of the keys
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
- ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- ClosureCleaner.clean(trackingFunction, checkSerializable = true)
- new StateSpecImpl(trackingFunction)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType]
+ ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ ClosureCleaner.clean(mappingFunction, checkSerializable = true)
+ new StateSpecImpl(mappingFunction)
}
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
*
- * @param trackingFunction The function applied on every data item to manage the associated state
- * and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated state
+ * and generate the mapped data
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
- ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType
+ ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ ClosureCleaner.clean(mappingFunction, checkSerializable = true)
val wrappedFunction =
- (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => {
- Some(trackingFunction(value, state))
+ (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => {
+ Some(mappingFunction(key, value, state))
}
new StateSpecImpl(wrappedFunction)
}
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all
- * the specifications of the `trackStateByKey` operation on a
+ * the specifications of the `mapWithState` operation on a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
*
- * @param javaTrackingFunction The function applied on every data item to manage the associated
- * state and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated
+ * state and generate the mapped data
* @tparam KeyType Class of the keys
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction:
- JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]):
- StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
- val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
+ def function[KeyType, ValueType, StateType, MappedType](mappingFunction:
+ JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[MappedType]]):
+ StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+ val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
Option(t.orNull)
}
- StateSpec.function(trackingFunc)
+ StateSpec.function(wrappedFunc)
}
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
*
- * @param javaTrackingFunction The function applied on every data item to manage the associated
- * state and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated
+ * state and generate the mapped data
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
- StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
- javaTrackingFunction.call(Optional.fromNullable(v.get), s)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]):
+ StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+ mappingFunction.call(k, JavaUtils.optionToOptional(v), s)
}
- StateSpec.function(trackingFunc)
+ StateSpec.function(wrappedFunc)
}
}
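Taken together, the setters above compose into a fluent spec. A hedged sketch of building one follows; the `timeout` setter name is assumed to match the duration-based behavior described in the doc above, and initialStateRDD is a hypothetical pre-seeded RDD[(String, Int)]:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.streaming.{Seconds, State, StateSpec}

    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
      val sum = value.getOrElse(0) + state.getOption().getOrElse(0)
      state.update(sum)
      Some(s"$key -> $sum")
    }

    def buildSpec(initialStateRDD: RDD[(String, Int)]) =
      StateSpec.function(mappingFunction _)
        .initialState(initialStateRDD)  // seed the state from an existing RDD
        .numPartitions(10)              // hash-partition the state RDDs
        .timeout(Seconds(3600))         // remove keys idle for an hour (name assumed)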
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
index f459930d06..16c0d6fff8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
@@ -19,23 +19,23 @@ package org.apache.spark.streaming.api.java
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.streaming.dstream.TrackStateDStream
+import org.apache.spark.streaming.dstream.MapWithStateDStream
/**
* :: Experimental ::
- * [[JavaDStream]] representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the
+ * DStream representing the stream of data generated by `mapWithState` operation on a
+ * [[JavaPairDStream]]. Additionally, it also gives access to the
* stream of state snapshots, that is, the state data of all keys after a batch has updated them.
*
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
- * @tparam StateType Class of the state
- * @tparam EmittedType Class of the emitted records
+ * @tparam KeyType Class of the keys
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped data
*/
@Experimental
-class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
- dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
- extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
+class JavaMapWithStateDStream[KeyType, ValueType, StateType, MappedType] private[streaming](
+ dstream: MapWithStateDStream[KeyType, ValueType, StateType, MappedType])
+ extends JavaDStream[MappedType](dstream)(JavaSparkContext.fakeClassTag) {
def stateSnapshots(): JavaPairDStream[KeyType, StateType] =
new JavaPairDStream(dstream.stateSnapshots())(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 70e32b383e..42ddd63f0f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -430,42 +430,36 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
/**
* :: Experimental ::
- * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream
- * with a continuously updated per-key state. The user-provided state tracking function is
- * applied on each keyed data item along with its corresponding state. The function can choose to
- * update/remove the state and return a transformed data, which forms the
- * [[JavaTrackStateDStream]].
+ * Return a [[JavaMapWithStateDStream]] by applying a function to every key-value element of
+ * `this` stream, while maintaining some state data for each unique key. The mapping function
+ * and other specifications (e.g. partitioners, timeouts, initial state data, etc.) of this
+ * transformation can be specified using the [[StateSpec]] class. The state data is accessible
+ * as a parameter of type [[State]] in the mapping function.
*
- * The specifications of this transformation is made through the
- * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
- * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
- * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
- *
- * Example of using `trackStateByKey`:
+ * Example of using `mapWithState`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
- * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
- *
- * @Override
- * public Optional<String> call(Optional<Integer> one, State<Integer> state) {
- * // Check if state exists, accordingly update/remove state and return transformed data
- * }
+ * // A mapping function that maintains an integer state and returns a string
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ * @Override
+ *    public String call(String key, Optional<Integer> value, State<Integer> state) {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
+ * }
* };
*
- * JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream =
- * keyValueDStream.<Integer, String>trackStateByKey(
- * StateSpec.function(trackStateFunc).numPartitions(10));
- * }}}
+ * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ *        keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
+ * }}}
*
* @param spec Specification of this transformation
- * @tparam StateType Class type of the state
- * @tparam EmittedType Class type of the tranformed data return by the tracking function
+ * @tparam StateType Class type of the state data
+ * @tparam MappedType Class type of the mapped data
*/
@Experimental
- def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]):
- JavaTrackStateDStream[K, V, StateType, EmittedType] = {
- new JavaTrackStateDStream(dstream.trackStateByKey(spec)(
+ def mapWithState[StateType, MappedType](spec: StateSpec[K, V, StateType, MappedType]):
+ JavaMapWithStateDStream[K, V, StateType, MappedType] = {
+ new JavaMapWithStateDStream(dstream.mapWithState(spec)(
JavaSparkContext.fakeClassTag,
JavaSparkContext.fakeClassTag))
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
index ea6213420e..706465d4e2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
@@ -24,53 +24,52 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
-import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
-import org.apache.spark.streaming.dstream.InternalTrackStateDStream._
+import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord}
+import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._
/**
* :: Experimental ::
- * DStream representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a
+ * DStream representing the stream of data generated by `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
* Additionally, it also gives access to the stream of state snapshots, that is, the state data of
* all keys after a batch has updated them.
*
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
+ * @tparam KeyType Class of the key
+ * @tparam ValueType Class of the value
* @tparam StateType Class of the state data
- * @tparam EmittedType Class of the emitted records
+ * @tparam MappedType Class of the mapped data
*/
@Experimental
-sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag](
- ssc: StreamingContext) extends DStream[EmittedType](ssc) {
+sealed abstract class MapWithStateDStream[KeyType, ValueType, StateType, MappedType: ClassTag](
+ ssc: StreamingContext) extends DStream[MappedType](ssc) {
/** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
def stateSnapshots(): DStream[(KeyType, StateType)]
}
-/** Internal implementation of the [[TrackStateDStream]] */
-private[streaming] class TrackStateDStreamImpl[
- KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag](
+/** Internal implementation of the [[MapWithStateDStream]] */
+private[streaming] class MapWithStateDStreamImpl[
+ KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
dataStream: DStream[(KeyType, ValueType)],
- spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType])
- extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) {
+ spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
+ extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {
private val internalStream =
- new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec)
+ new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
override def slideDuration: Duration = internalStream.slideDuration
override def dependencies: List[DStream[_]] = List(internalStream)
- override def compute(validTime: Time): Option[RDD[EmittedType]] = {
- internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } }
+ override def compute(validTime: Time): Option[RDD[MappedType]] = {
+ internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
}
/**
* Forward the checkpoint interval to the internal DStream that computes the state maps. This
* to make sure that this DStream does not get checkpointed, only the internal stream.
*/
- override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = {
+ override def checkpoint(checkpointInterval: Duration): DStream[MappedType] = {
internalStream.checkpoint(checkpointInterval)
this
}
@@ -87,32 +86,32 @@ private[streaming] class TrackStateDStreamImpl[
def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass
- def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass
+ def mappedClass: Class[_] = implicitly[ClassTag[MappedType]].runtimeClass
}
/**
 * A DStream that allows per-key state to be maintained, and arbitrary records to be generated
- * based on updates to the state. This is the main DStream that implements the `trackStateByKey`
+ * based on updates to the state. This is the main DStream that implements the `mapWithState`
* operation on DStreams.
*
* @param parent Parent (key, value) stream that is the source
- * @param spec Specifications of the trackStateByKey operation
+ * @param spec Specifications of the mapWithState operation
* @tparam K Key type
* @tparam V Value type
* @tparam S Type of the state maintained
- * @tparam E Type of the emitted data
+ * @tparam E Type of the mapped data
*/
private[streaming]
-class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E])
- extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) {
+ extends DStream[MapWithStateRDDRecord[K, S, E]](parent.context) {
persist(StorageLevel.MEMORY_ONLY)
private val partitioner = spec.getPartitioner().getOrElse(
new HashPartitioner(ssc.sc.defaultParallelism))
- private val trackingFunction = spec.getFunction()
+ private val mappingFunction = spec.getFunction()
override def slideDuration: Duration = parent.slideDuration
@@ -130,7 +129,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
}
/** Method that generates a RDD for the given time */
- override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
+ override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
// Get the previous state or create a new empty state RDD
val prevStateRDD = getOrCompute(validTime - slideDuration) match {
case Some(rdd) =>
@@ -138,13 +137,13 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
// If the RDD is not partitioned the right way, let us repartition it using the
// partition index as the key. This is to ensure that state RDD is always partitioned
// before creating another state RDD using it
- TrackStateRDD.createFromRDD[K, V, S, E](
+ MapWithStateRDD.createFromRDD[K, V, S, E](
rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
} else {
rdd
}
case None =>
- TrackStateRDD.createFromPairRDD[K, V, S, E](
+ MapWithStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner,
validTime
@@ -161,11 +160,11 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
- Some(new TrackStateRDD(
- prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
+ Some(new MapWithStateRDD(
+ prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}
}
-private[streaming] object InternalTrackStateDStream {
+private[streaming] object InternalMapWithStateDStream {
private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10
}
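Since the mapped stream and the state snapshots are both exposed, a typical pattern is to consume them side by side. A minimal sketch (the stream construction and the concrete type parameters are hypothetical):

    import org.apache.spark.streaming.Time
    import org.apache.spark.streaming.dstream.MapWithStateDStream

    // Print the complete key -> state table after every batch.
    def printSnapshots(stream: MapWithStateDStream[String, Int, Int, Option[String]]): Unit = {
      stream.stateSnapshots().foreachRDD { (rdd, time: Time) =>
        println(s"State at $time: ${rdd.collect().toSeq}")
      }
    }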
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 2762309134..a64a1fe93f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -352,39 +352,36 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
/**
* :: Experimental ::
- * Return a new DStream of data generated by combining the key-value data in `this` stream
- * with a continuously updated per-key state. The user-provided state tracking function is
- * applied on each keyed data item along with its corresponding state. The function can choose to
- * update/remove the state and return a transformed data, which forms the
- * [[org.apache.spark.streaming.dstream.TrackStateDStream]].
+ * Return a [[MapWithStateDStream]] by applying a function to every key-value element of
+ * `this` stream, while maintaining some state data for each unique key. The mapping function
+ * and other specifications (e.g. partitioners, timeouts, initial state data, etc.) of this
+ * transformation can be specified using the [[StateSpec]] class. The state data is accessible
+ * as a parameter of type [[State]] in the mapping function.
*
- * The specifications of this transformation is made through the
- * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
- * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
- * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details.
- *
- * Example of using `trackStateByKey`:
+ * Example of using `mapWithState`:
* {{{
- * def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = {
- * // Check if state exists, accordingly update/remove state and return transformed data
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
* }
*
- * val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
*
- * val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec)
+ * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
* }}}
*
* @param spec Specification of this transformation
- * @tparam StateType Class type of the state
- * @tparam EmittedType Class type of the tranformed data return by the tracking function
+ * @tparam StateType Class type of the state data
+ * @tparam MappedType Class type of the mapped data
*/
@Experimental
- def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag](
- spec: StateSpec[K, V, StateType, EmittedType]
- ): TrackStateDStream[K, V, StateType, EmittedType] = {
- new TrackStateDStreamImpl[K, V, StateType, EmittedType](
+ def mapWithState[StateType: ClassTag, MappedType: ClassTag](
+ spec: StateSpec[K, V, StateType, MappedType]
+ ): MapWithStateDStream[K, V, StateType, MappedType] = {
+ new MapWithStateDStreamImpl[K, V, StateType, MappedType](
self,
- spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]]
+ spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
)
}
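An end-to-end sketch of the new operation, assuming a socket source on localhost:9999 emitting words (the application name, checkpoint path, and source are illustrative only):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

    object MapWithStateSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("MapWithStateSketch").setMaster("local[2]")
        val ssc = new StreamingContext(conf, Seconds(1))
        ssc.checkpoint("/tmp/checkpoint")  // stateful DStreams require checkpointing

        val wordPairs = ssc.socketTextStream("localhost", 9999)
          .flatMap(_.split(" "))
          .map(word => (word, 1))

        // Maintain a running count per word; emit "word -> count" each batch.
        def countWords(key: String, value: Option[Int], state: State[Int]): Option[String] = {
          val newCount = value.getOrElse(0) + state.getOption().getOrElse(0)
          state.update(newCount)
          Some(s"$key -> $newCount")
        }

        wordPairs.mapWithState(StateSpec.function(countWords _)).print()

        ssc.start()
        ssc.awaitTermination()
      }
    }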
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
index 30aafcf146..ed95171f73 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
@@ -29,60 +29,60 @@ import org.apache.spark.util.Utils
import org.apache.spark._
/**
- * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
- * sequence of records returned by the tracking function of `trackStateByKey`.
+ * Record storing the keyed state of [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a
+ * sequence of records returned by the mapping function of `mapWithState`.
*/
-private[streaming] case class TrackStateRDDRecord[K, S, E](
- var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
+private[streaming] case class MapWithStateRDDRecord[K, S, E](
+ var stateMap: StateMap[K, S], var mappedData: Seq[E])
-private[streaming] object TrackStateRDDRecord {
+private[streaming] object MapWithStateRDDRecord {
def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
- prevRecord: Option[TrackStateRDDRecord[K, S, E]],
+ prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
dataIterator: Iterator[(K, V)],
- updateFunction: (Time, K, Option[V], State[S]) => Option[E],
+ mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
batchTime: Time,
timeoutThresholdTime: Option[Long],
removeTimedoutData: Boolean
- ): TrackStateRDDRecord[K, S, E] = {
+ ): MapWithStateRDDRecord[K, S, E] = {
// Create a new state map by cloning the previous one (if it exists) or by creating an empty one
val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }
- val emittedRecords = new ArrayBuffer[E]
+ val mappedData = new ArrayBuffer[E]
val wrappedState = new StateImpl[S]()
- // Call the tracking function on each record in the data iterator, and accordingly
- // update the states touched, and collect the data returned by the tracking function
+ // Call the mapping function on each record in the data iterator, and accordingly
+ // update the states touched, and collect the data returned by the mapping function
dataIterator.foreach { case (key, value) =>
wrappedState.wrap(newStateMap.get(key))
- val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
+ val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
if (wrappedState.isRemoved) {
newStateMap.remove(key)
} else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
}
- emittedRecords ++= emittedRecord
+ mappedData ++= returned
}
- // Get the timed out state records, call the tracking function on each and collect the
+ // Get the timed out state records, call the mapping function on each and collect the
// data returned
if (removeTimedoutData && timeoutThresholdTime.isDefined) {
newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
wrappedState.wrapTiminoutState(state)
- val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
- emittedRecords ++= emittedRecord
+ val returned = mappingFunction(batchTime, key, None, wrappedState)
+ mappedData ++= returned
newStateMap.remove(key)
}
}
- TrackStateRDDRecord(newStateMap, emittedRecords)
+ MapWithStateRDDRecord(newStateMap, mappedData)
}
}
/**
- * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
+ * Partition of the [[MapWithStateRDD]], which depends on corresponding partitions of prev state
* RDD, and a partitioned keyed-data RDD
*/
-private[streaming] class TrackStateRDDPartition(
+private[streaming] class MapWithStateRDDPartition(
idx: Int,
@transient private var prevStateRDD: RDD[_],
@transient private var partitionedDataRDD: RDD[_]) extends Partition {
@@ -104,27 +104,28 @@ private[streaming] class TrackStateRDDPartition(
/**
- * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records.
- * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a
- * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking
- * function of `trackStateByKey`.
- * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created
+ * RDD storing the keyed states of the `mapWithState` operation and corresponding mapped data.
+ * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a
+ * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping
+ * function of `mapWithState`.
+ * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD
+ * will be created
 * @param partitionedDataRDD The partitioned data RDD which is used to update the previous StateMaps
* in the `prevStateRDD` to create `this` RDD
- * @param trackingFunction The function that will be used to update state and return new data
+ * @param mappingFunction The function that will be used to update state and return new data
 * @param batchTime The time of the batch to which this RDD belongs. Used to update the state.
 * @param timeoutThresholdTime The threshold time used to decide which keys have timed out
*/
-private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
- private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
+private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+ private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]],
private var partitionedDataRDD: RDD[(K, V)],
- trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
+ mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
batchTime: Time,
timeoutThresholdTime: Option[Long]
- ) extends RDD[TrackStateRDDRecord[K, S, E]](
+ ) extends RDD[MapWithStateRDDRecord[K, S, E]](
partitionedDataRDD.sparkContext,
List(
- new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
+ new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD),
new OneToOneDependency(partitionedDataRDD))
) {
@@ -141,19 +142,19 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
}
override def compute(
- partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {
+ partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
- val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
+ val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
val prevStateRDDIterator = prevStateRDD.iterator(
stateRDDPartition.previousSessionRDDPartition, context)
val dataIterator = partitionedDataRDD.iterator(
stateRDDPartition.partitionedDataRDDPartition, context)
val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
- val newRecord = TrackStateRDDRecord.updateRecordWithData(
+ val newRecord = MapWithStateRDDRecord.updateRecordWithData(
prevRecord,
dataIterator,
- trackingFunction,
+ mappingFunction,
batchTime,
timeoutThresholdTime,
removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
@@ -163,7 +164,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
override protected def getPartitions: Array[Partition] = {
Array.tabulate(prevStateRDD.partitions.length) { i =>
- new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
+ new MapWithStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
}
override def clearDependencies(): Unit = {
@@ -177,52 +178,46 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
}
}
-private[streaming] object TrackStateRDD {
+private[streaming] object MapWithStateRDD {
def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
pairRDD: RDD[(K, S)],
partitioner: Partitioner,
- updateTime: Time): TrackStateRDD[K, V, S, E] = {
+ updateTime: Time): MapWithStateRDD[K, V, S, E] = {
- val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
+ val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
- Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+ Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
- new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+ new MapWithStateRDD[K, V, S, E](
+ stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
}
def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
rdd: RDD[(K, S, Long)],
partitioner: Partitioner,
- updateTime: Time): TrackStateRDD[K, V, S, E] = {
+ updateTime: Time): MapWithStateRDD[K, V, S, E] = {
val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
- val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
+ val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, (state, updateTime)) =>
stateMap.put(key, state, updateTime)
}
- Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+ Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
- new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
- }
-}
-
-private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
- parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) {
- override protected def getPartitions: Array[Partition] = parent.partitions
- override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
- parent.compute(partition, context).flatMap { _.emittedRecords }
+ new MapWithStateRDD[K, V, S, E](
+ stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
}
}
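The update loop in MapWithStateRDDRecord.updateRecordWithData can be summarized in a standalone sketch, with a plain mutable.Map standing in for the optimized StateMap and timeout handling omitted (all names below are illustrative, not the private Spark API):

    import scala.collection.mutable

    // Minimal model of the per-partition State wrapper: records whether the
    // mapping function updated or removed the state for the current key.
    class SimpleState[S](initial: Option[S]) {
      private var state = initial
      var isUpdated = false
      var isRemoved = false
      def exists: Boolean = state.isDefined
      def get(): S = state.get
      def update(s: S): Unit = { state = Some(s); isUpdated = true }
      def remove(): Unit = { state = None; isRemoved = true }
    }

    // Apply the mapping function to each (key, value), apply the resulting
    // state changes to the map, and collect the mapped records.
    def updateRecord[K, V, S, E](
        stateMap: mutable.Map[K, S],
        data: Iterator[(K, V)],
        mappingFunction: (K, Option[V], SimpleState[S]) => Option[E]): Seq[E] = {
      val mappedData = mutable.ArrayBuffer.empty[E]
      data.foreach { case (key, value) =>
        val wrapped = new SimpleState[S](stateMap.get(key))
        val returned = mappingFunction(key, Some(value), wrapped)
        if (wrapped.isRemoved) stateMap.remove(key)                   // state.remove() was called
        else if (wrapped.isUpdated) stateMap.put(key, wrapped.get())  // state.update(...) was called
        mappedData ++= returned                                       // collect the mapped records
      }
      mappedData
    }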