Diffstat (limited to 'streaming/src/main')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/State.scala                                |  20
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala                            | 160
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala) |  20
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala             |  50
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala) |  61
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala         |  41
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala)               |  99
7 files changed, 230 insertions, 221 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
index 604e64fc61..b47bdda2c2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -23,14 +23,14 @@ import org.apache.spark.annotation.Experimental
/**
* :: Experimental ::
- * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of
- * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
- * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * Abstract class for getting and updating the state in mapping function used in the `mapWithState`
+ * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala)
+ * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
*
* Scala example of using `State`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
* // Check if state exists
* if (state.exists) {
* val existingState = state.get // Get the existing state
@@ -52,12 +52,12 @@ import org.apache.spark.annotation.Experimental
*
* Java example of using `State`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
- * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
+ * // A mapping function that maintains an integer state and returns a String
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
*
* @Override
- * public Optional<String> call(Optional<Integer> one, State<Integer> state) {
+ * public String call(String key, Optional<Integer> value, State<Integer> state) {
* if (state.exists()) {
* int existingState = state.get(); // Get the existing state
* boolean shouldRemove = ...; // Decide whether to remove the state
@@ -75,6 +75,8 @@ import org.apache.spark.annotation.Experimental
* }
* };
* }}}
+ *
+ * @tparam S Class of the state
*/
@Experimental
sealed abstract class State[S] {
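
Editor's note: for reference, a minimal sketch of a mapping function written against the State API documented above. The running-count logic and names are illustrative only, not part of this patch; it assumes the `getOption()`, `update()` and `isTimingOut()` members of `State`.

{{{
// A mapping function over State[S]: reads, updates and times out per-key state.
def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
  if (state.isTimingOut()) {
    // Final call for an idle key; the state is removed after this returns.
    Some(s"$key timed out with count ${state.getOption().getOrElse(0)}")
  } else {
    val newCount = state.getOption().getOrElse(0) + value.getOrElse(0)
    state.update(newCount)  // persist the new state for the next batch
    Some(s"$key -> $newCount")
  }
}
}}}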
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index bea5b9df20..9f6f95223f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.streaming
import com.google.common.base.Optional
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.{JavaPairRDD, JavaUtils}
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4}
+import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.ClosureCleaner
import org.apache.spark.{HashPartitioner, Partitioner}
@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Partitioner}
/**
* :: Experimental ::
* Abstract class representing all the specifications of the DStream transformation
- * `trackStateByKey` operation of a
+ * `mapWithState` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
* Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or
@@ -37,50 +37,63 @@ import org.apache.spark.{HashPartitioner, Partitioner}
*
* Example in Scala:
* {{{
- * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- * ...
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
* }
*
- * val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
*
- * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ * val mapWithStateDStream = keyValueDStream.mapWithState[Int, String](spec)
* }}}
*
* Example in Java:
* {{{
- * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- * .numPartition(10);
+ * // A mapping function that maintains an integer state and returns a string
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ * @Override
+ * public String call(String key, Optional<Integer> value, State<Integer> state) {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
+ * }
+ * };
*
- * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ * keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
* }}}
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the state value
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped elements
*/
@Experimental
-sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
+sealed abstract class StateSpec[KeyType, ValueType, StateType, MappedType] extends Serializable {
- /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+ /** Set the RDD containing the initial states that will be used by `mapWithState` */
def initialState(rdd: RDD[(KeyType, StateType)]): this.type
- /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+ /** Set the RDD containing the initial states that will be used by `mapWithState` */
def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type
/**
- * Set the number of partitions by which the state RDDs generated by `trackStateByKey`
+ * Set the number of partitions by which the state RDDs generated by `mapWithState`
* will be partitioned. Hash partitioning will be used.
*/
def numPartitions(numPartitions: Int): this.type
/**
- * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be
+ * Set the partitioner by which the state RDDs generated by `mapWithState` will be
 * partitioned.
*/
def partitioner(partitioner: Partitioner): this.type
/**
* Set the duration after which the state of an idle key will be removed. A key and its state is
- * considered idle if it has not received any data for at least the given duration. The state
- * tracking function will be called one final time on the idle states that are going to be
+ * considered idle if it has not received any data for at least the given duration. The
+ * mapping function will be called one final time on the idle states that are going to be
 * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] will be set
* to `true` in that call.
*/
@@ -91,115 +104,124 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte
/**
* :: Experimental ::
* Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
- * that is used for specifying the parameters of the DStream transformation `trackStateByKey`
- * that is used for specifying the parameters of the DStream transformation
- * `trackStateByKey` operation of a
+ * that is used for specifying the parameters of the DStream transformation
+ * `mapWithState` operation of a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
*
* Example in Scala:
* {{{
- * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
- * ...
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
* }
*
- * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](
- * StateSpec.function(trackingFunction).numPartitions(10))
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
+ *
+ * val mapWithStateDStream = keyValueDStream.mapWithState[Int, String](spec)
* }}}
*
* Example in Java:
* {{{
- * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
- * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
- * .numPartition(10);
+ * // A mapping function that maintains an integer state and returns a string
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ * @Override
+ * public String call(String key, Optional<Integer> value, State<Integer> state) {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
+ * }
+ * };
*
- * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream =
- * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
- * }}}
+ * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ * keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
+ * }}}
*/
@Experimental
object StateSpec {
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
*
- * @param trackingFunction The function applied on every data item to manage the associated state
- * and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated state
+ * and generate the mapped data
* @tparam KeyType Class of the keys
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
- ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- ClosureCleaner.clean(trackingFunction, checkSerializable = true)
- new StateSpecImpl(trackingFunction)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType]
+ ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ ClosureCleaner.clean(mappingFunction, checkSerializable = true)
+ new StateSpecImpl(mappingFunction)
}
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
*
- * @param trackingFunction The function applied on every data item to manage the associated state
- * and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated state
+ * and generate the mapped data
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
- ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType
+ ): StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ ClosureCleaner.clean(mappingFunction, checkSerializable = true)
val wrappedFunction =
- (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => {
- Some(trackingFunction(value, state))
+ (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => {
+ Some(mappingFunction(key, value, state))
}
new StateSpecImpl(wrappedFunction)
}
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all
- * the specifications of the `trackStateByKey` operation on a
+ * the specifications of the `mapWithState` operation on a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
*
- * @param javaTrackingFunction The function applied on every data item to manage the associated
- * state and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated
+ * state and generate the mapped data
* @tparam KeyType Class of the keys
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction:
- JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]):
- StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
- val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
+ def function[KeyType, ValueType, StateType, MappedType](mappingFunction:
+ JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[MappedType]]):
+ StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+ val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s)
Option(t.orNull)
}
- StateSpec.function(trackingFunc)
+ StateSpec.function(wrappedFunc)
}
/**
* Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
- * of the `trackStateByKey` operation on a
+ * of the `mapWithState` operation on a
* [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
*
- * @param javaTrackingFunction The function applied on every data item to manage the associated
- * state and generate the emitted data
+ * @param mappingFunction The function applied on every data item to manage the associated
+ * state and generate the mapped data
* @tparam ValueType Class of the values
* @tparam StateType Class of the states data
- * @tparam EmittedType Class of the emitted data
+ * @tparam MappedType Class of the mapped data
*/
- def function[KeyType, ValueType, StateType, EmittedType](
- javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]):
- StateSpec[KeyType, ValueType, StateType, EmittedType] = {
- val trackingFunc = (v: Option[ValueType], s: State[StateType]) => {
- javaTrackingFunction.call(Optional.fromNullable(v.get), s)
+ def function[KeyType, ValueType, StateType, MappedType](
+ mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]):
+ StateSpec[KeyType, ValueType, StateType, MappedType] = {
+ val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => {
+ mappingFunction.call(k, JavaUtils.optionToOptional(v), s)
}
- StateSpec.function(trackingFunc)
+ StateSpec.function(wrappedFunc)
}
}
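
Editor's note: a hedged sketch of chaining the renamed builder methods end to end. It assumes an existing StreamingContext `ssc` and a `mappingFunction` like the one sketched earlier; none of these names are introduced by this patch.

{{{
// Compose the StateSpec: initial state, partitioning, and an idle timeout.
val initialRDD = ssc.sparkContext.parallelize(Seq(("a", 0), ("b", 0)))

val spec = StateSpec.function(mappingFunction _)
  .initialState(initialRDD)   // seed per-key state before the first batch
  .numPartitions(10)          // hash-partition the generated state RDDs
  .timeout(Seconds(30))       // idle keys get one final isTimingOut() call
}}}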
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
index f459930d06..16c0d6fff8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala
@@ -19,23 +19,23 @@ package org.apache.spark.streaming.api.java
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.streaming.dstream.TrackStateDStream
+import org.apache.spark.streaming.dstream.MapWithStateDStream
/**
* :: Experimental ::
- * [[JavaDStream]] representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the
+ * DStream representing the stream of data generated by the `mapWithState` operation on a
+ * [[JavaPairDStream]]. Additionally, it also gives access to the
* stream of state snapshots, that is, the state data of all keys after a batch has updated them.
*
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
- * @tparam StateType Class of the state
- * @tparam EmittedType Class of the emitted records
+ * @tparam KeyType Class of the keys
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the state data
+ * @tparam MappedType Class of the mapped data
*/
@Experimental
-class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType](
- dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType])
- extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) {
+class JavaMapWithStateDStream[KeyType, ValueType, StateType, MappedType] private[streaming](
+ dstream: MapWithStateDStream[KeyType, ValueType, StateType, MappedType])
+ extends JavaDStream[MappedType](dstream)(JavaSparkContext.fakeClassTag) {
def stateSnapshots(): JavaPairDStream[KeyType, StateType] =
new JavaPairDStream(dstream.stateSnapshots())(
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 70e32b383e..42ddd63f0f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -430,42 +430,36 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
/**
* :: Experimental ::
- * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream
- * with a continuously updated per-key state. The user-provided state tracking function is
- * applied on each keyed data item along with its corresponding state. The function can choose to
- * update/remove the state and return a transformed data, which forms the
- * [[JavaTrackStateDStream]].
+ * Return a [[JavaMapWithStateDStream]] by applying a function to every key-value element of
+ * `this` stream, while maintaining some state data for each unique key. The mapping function
+ * and other specifications (e.g. partitioners, timeouts, initial state data, etc.) of this
+ * transformation can be specified using the [[StateSpec]] class. The state data is accessible
+ * as a parameter of type [[State]] in the mapping function.
*
- * The specifications of this transformation is made through the
- * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
- * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
- * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details.
- *
- * Example of using `trackStateByKey`:
+ * Example of using `mapWithState`:
* {{{
- * // A tracking function that maintains an integer state and return a String
- * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc =
- * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() {
- *
- * @Override
- * public Optional<String> call(Optional<Integer> one, State<Integer> state) {
- * // Check if state exists, accordingly update/remove state and return transformed data
- * }
+ * // A mapping function that maintains an integer state and returns a string
+ * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction =
+ * new Function3<String, Optional<Integer>, State<Integer>, String>() {
+ * @Override
+ * public String call(String key, Optional<Integer> value, State<Integer> state) {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
+ * }
* };
*
- * JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream =
- * keyValueDStream.<Integer, String>trackStateByKey(
- * StateSpec.function(trackStateFunc).numPartitions(10));
- * }}}
+ * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream =
+ * keyValueDStream.mapWithState(StateSpec.function(mappingFunction));
+ * }}}
*
* @param spec Specification of this transformation
- * @tparam StateType Class type of the state
- * @tparam EmittedType Class type of the tranformed data return by the tracking function
+ * @tparam StateType Class type of the state data
+ * @tparam MappedType Class type of the mapped data
*/
@Experimental
- def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]):
- JavaTrackStateDStream[K, V, StateType, EmittedType] = {
- new JavaTrackStateDStream(dstream.trackStateByKey(spec)(
+ def mapWithState[StateType, MappedType](spec: StateSpec[K, V, StateType, MappedType]):
+ JavaMapWithStateDStream[K, V, StateType, MappedType] = {
+ new JavaMapWithStateDStream(dstream.mapWithState(spec)(
JavaSparkContext.fakeClassTag,
JavaSparkContext.fakeClassTag))
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
index ea6213420e..706465d4e2 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
@@ -24,53 +24,52 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming._
-import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
-import org.apache.spark.streaming.dstream.InternalTrackStateDStream._
+import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord}
+import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._
/**
* :: Experimental ::
- * DStream representing the stream of records emitted by the tracking function in the
- * `trackStateByKey` operation on a
+ * DStream representing the stream of data generated by the `mapWithState` operation on a
* [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
* Additionally, it also gives access to the stream of state snapshots, that is, the state data of
* all keys after a batch has updated them.
*
- * @tparam KeyType Class of the state key
- * @tparam ValueType Class of the state value
+ * @tparam KeyType Class of the key
+ * @tparam ValueType Class of the value
* @tparam StateType Class of the state data
- * @tparam EmittedType Class of the emitted records
+ * @tparam MappedType Class of the mapped data
*/
@Experimental
-sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag](
- ssc: StreamingContext) extends DStream[EmittedType](ssc) {
+sealed abstract class MapWithStateDStream[KeyType, ValueType, StateType, MappedType: ClassTag](
+ ssc: StreamingContext) extends DStream[MappedType](ssc) {
/** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
def stateSnapshots(): DStream[(KeyType, StateType)]
}
-/** Internal implementation of the [[TrackStateDStream]] */
-private[streaming] class TrackStateDStreamImpl[
- KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag](
+/** Internal implementation of the [[MapWithStateDStream]] */
+private[streaming] class MapWithStateDStreamImpl[
+ KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
dataStream: DStream[(KeyType, ValueType)],
- spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType])
- extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) {
+ spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
+ extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {
private val internalStream =
- new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec)
+ new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)
override def slideDuration: Duration = internalStream.slideDuration
override def dependencies: List[DStream[_]] = List(internalStream)
- override def compute(validTime: Time): Option[RDD[EmittedType]] = {
- internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } }
+ override def compute(validTime: Time): Option[RDD[MappedType]] = {
+ internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
}
/**
 * Forward the checkpoint interval to the internal DStream that computes the state maps. This is
* to make sure that this DStream does not get checkpointed, only the internal stream.
*/
- override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = {
+ override def checkpoint(checkpointInterval: Duration): DStream[MappedType] = {
internalStream.checkpoint(checkpointInterval)
this
}
@@ -87,32 +86,32 @@ private[streaming] class TrackStateDStreamImpl[
def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass
- def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass
+ def mappedClass: Class[_] = implicitly[ClassTag[MappedType]].runtimeClass
}
/**
 * A DStream that allows per-key state to be maintained, and arbitrary records to be generated
- * based on updates to the state. This is the main DStream that implements the `trackStateByKey`
+ * based on updates to the state. This is the main DStream that implements the `mapWithState`
* operation on DStreams.
*
* @param parent Parent (key, value) stream that is the source
- * @param spec Specifications of the trackStateByKey operation
+ * @param spec Specifications of the mapWithState operation
* @tparam K Key type
* @tparam V Value type
* @tparam S Type of the state maintained
- * @tparam E Type of the emitted data
+ * @tparam E Type of the mapped data
*/
private[streaming]
-class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E])
- extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) {
+ extends DStream[MapWithStateRDDRecord[K, S, E]](parent.context) {
persist(StorageLevel.MEMORY_ONLY)
private val partitioner = spec.getPartitioner().getOrElse(
new HashPartitioner(ssc.sc.defaultParallelism))
- private val trackingFunction = spec.getFunction()
+ private val mappingFunction = spec.getFunction()
override def slideDuration: Duration = parent.slideDuration
@@ -130,7 +129,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
}
/** Method that generates a RDD for the given time */
- override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
+ override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
// Get the previous state or create a new empty state RDD
val prevStateRDD = getOrCompute(validTime - slideDuration) match {
case Some(rdd) =>
@@ -138,13 +137,13 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
// If the RDD is not partitioned the right way, let us repartition it using the
// partition index as the key. This is to ensure that state RDD is always partitioned
// before creating another state RDD using it
- TrackStateRDD.createFromRDD[K, V, S, E](
+ MapWithStateRDD.createFromRDD[K, V, S, E](
rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
} else {
rdd
}
case None =>
- TrackStateRDD.createFromPairRDD[K, V, S, E](
+ MapWithStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner,
validTime
@@ -161,11 +160,11 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
- Some(new TrackStateRDD(
- prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
+ Some(new MapWithStateRDD(
+ prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
}
}
-private[streaming] object InternalTrackStateDStream {
+private[streaming] object InternalMapWithStateDStream {
private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10
}
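
Editor's note: a short sketch of how checkpointing flows through the renamed classes, per the overridden checkpoint() above: the call is forwarded so that only the internal state stream is checkpointed. `keyValueDStream` and `spec` are assumed to exist.

{{{
val mapped = keyValueDStream.mapWithState(spec)
mapped.checkpoint(Seconds(100))          // marks InternalMapWithStateDStream, not `mapped`
val snapshots = mapped.stateSnapshots()  // DStream[(K, S)] of all current per-key states
}}}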
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 2762309134..a64a1fe93f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -352,39 +352,36 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
/**
* :: Experimental ::
- * Return a new DStream of data generated by combining the key-value data in `this` stream
- * with a continuously updated per-key state. The user-provided state tracking function is
- * applied on each keyed data item along with its corresponding state. The function can choose to
- * update/remove the state and return a transformed data, which forms the
- * [[org.apache.spark.streaming.dstream.TrackStateDStream]].
+ * Return a [[MapWithStateDStream]] by applying a function to every key-value element of
+ * `this` stream, while maintaining some state data for each unique key. The mapping function
+ * and other specifications (e.g. partitioners, timeouts, initial state data, etc.) of this
+ * transformation can be specified using the [[StateSpec]] class. The state data is accessible
+ * as a parameter of type [[State]] in the mapping function.
*
- * The specifications of this transformation is made through the
- * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
- * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
- * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details.
- *
- * Example of using `trackStateByKey`:
+ * Example of using `mapWithState`:
* {{{
- * def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = {
- * // Check if state exists, accordingly update/remove state and return transformed data
+ * // A mapping function that maintains an integer state and returns a String
+ * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
+ * // Use state.exists(), state.get(), state.update() and state.remove()
+ * // to manage state, and return the necessary string
* }
*
- * val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ * val spec = StateSpec.function(mappingFunction).numPartitions(10)
*
- * val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec)
+ * val mapWithStateDStream = keyValueDStream.mapWithState[Int, String](spec)
* }}}
*
* @param spec Specification of this transformation
- * @tparam StateType Class type of the state
- * @tparam EmittedType Class type of the tranformed data return by the tracking function
+ * @tparam StateType Class type of the state data
+ * @tparam MappedType Class type of the mapped data
*/
@Experimental
- def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag](
- spec: StateSpec[K, V, StateType, EmittedType]
- ): TrackStateDStream[K, V, StateType, EmittedType] = {
- new TrackStateDStreamImpl[K, V, StateType, EmittedType](
+ def mapWithState[StateType: ClassTag, MappedType: ClassTag](
+ spec: StateSpec[K, V, StateType, MappedType]
+ ): MapWithStateDStream[K, V, StateType, MappedType] = {
+ new MapWithStateDStreamImpl[K, V, StateType, MappedType](
self,
- spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]]
+ spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
)
}
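
Editor's note: an end-to-end sketch of the renamed API, a stateful running word count. The socket source, SparkConf `conf`, checkpoint path and batch interval are assumptions for illustration, not part of this patch.

{{{
import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

val ssc = new StreamingContext(conf, Seconds(1))
ssc.checkpoint("/tmp/checkpoint")   // required: the internal state stream checkpoints

val pairs = ssc.socketTextStream("localhost", 9999)
  .flatMap(_.split(" "))
  .map(word => (word, 1))

// Carry a running count per word across batches.
def countFunc(key: String, value: Option[Int], state: State[Int]): (String, Int) = {
  val sum = value.getOrElse(0) + state.getOption().getOrElse(0)
  state.update(sum)
  (key, sum)
}

val counts = pairs.mapWithState(StateSpec.function(countFunc _))
counts.print()
ssc.start()
ssc.awaitTermination()
}}}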
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
index 30aafcf146..ed95171f73 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala
@@ -29,60 +29,60 @@ import org.apache.spark.util.Utils
import org.apache.spark._
/**
- * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
- * sequence of records returned by the tracking function of `trackStateByKey`.
+ * Record storing the keyed state of a [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a
+ * sequence of records returned by the mapping function of `mapWithState`.
*/
-private[streaming] case class TrackStateRDDRecord[K, S, E](
- var stateMap: StateMap[K, S], var emittedRecords: Seq[E])
+private[streaming] case class MapWithStateRDDRecord[K, S, E](
+ var stateMap: StateMap[K, S], var mappedData: Seq[E])
-private[streaming] object TrackStateRDDRecord {
+private[streaming] object MapWithStateRDDRecord {
def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
- prevRecord: Option[TrackStateRDDRecord[K, S, E]],
+ prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
dataIterator: Iterator[(K, V)],
- updateFunction: (Time, K, Option[V], State[S]) => Option[E],
+ mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
batchTime: Time,
timeoutThresholdTime: Option[Long],
removeTimedoutData: Boolean
- ): TrackStateRDDRecord[K, S, E] = {
+ ): MapWithStateRDDRecord[K, S, E] = {
// Create a new state map by cloning the previous one (if it exists) or by creating an empty one
    val newStateMap = prevRecord.map { _.stateMap.copy() }.getOrElse { new EmptyStateMap[K, S]() }
- val emittedRecords = new ArrayBuffer[E]
+ val mappedData = new ArrayBuffer[E]
val wrappedState = new StateImpl[S]()
- // Call the tracking function on each record in the data iterator, and accordingly
- // update the states touched, and collect the data returned by the tracking function
+ // Call the mapping function on each record in the data iterator, and accordingly
+ // update the states touched, and collect the data returned by the mapping function
dataIterator.foreach { case (key, value) =>
wrappedState.wrap(newStateMap.get(key))
- val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
+ val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
if (wrappedState.isRemoved) {
newStateMap.remove(key)
} else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
}
- emittedRecords ++= emittedRecord
+ mappedData ++= returned
}
- // Get the timed out state records, call the tracking function on each and collect the
+ // Get the timed out state records, call the mapping function on each and collect the
// data returned
if (removeTimedoutData && timeoutThresholdTime.isDefined) {
newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
wrappedState.wrapTiminoutState(state)
- val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
- emittedRecords ++= emittedRecord
+ val returned = mappingFunction(batchTime, key, None, wrappedState)
+ mappedData ++= returned
newStateMap.remove(key)
}
}
- TrackStateRDDRecord(newStateMap, emittedRecords)
+ MapWithStateRDDRecord(newStateMap, mappedData)
}
}
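
Editor's note: a simplified, hypothetical model of what updateRecordWithData above does per partition: one pass over the new data mutates a copy of the state map and collects the mapping function's output. The timeout pass and batch-time bookkeeping are omitted here.

{{{
def updateModel[K, S, V, E](
    prevState: Map[K, S],
    data: Seq[(K, V)],
    f: (K, Option[V], Option[S]) => (Option[S], Option[E])): (Map[K, S], Seq[E]) = {
  var stateMap = prevState                     // stands in for StateMap.copy()
  val mappedData = data.flatMap { case (k, v) =>
    val (newState, out) = f(k, Some(v), stateMap.get(k))
    newState match {
      case Some(s) => stateMap += (k -> s)     // models State.update()
      case None    => stateMap -= k            // models State.remove()
    }
    out
  }
  (stateMap, mappedData)
}
}}}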
/**
- * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
+ * Partition of the [[MapWithStateRDD]], which depends on corresponding partitions of prev state
* RDD, and a partitioned keyed-data RDD
*/
-private[streaming] class TrackStateRDDPartition(
+private[streaming] class MapWithStateRDDPartition(
idx: Int,
@transient private var prevStateRDD: RDD[_],
@transient private var partitionedDataRDD: RDD[_]) extends Partition {
@@ -104,27 +104,28 @@ private[streaming] class TrackStateRDDPartition(
/**
- * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records.
- * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a
- * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking
- * function of `trackStateByKey`.
- * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created
+ * RDD storing the keyed states of the `mapWithState` operation and corresponding mapped data.
+ * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a
+ * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping
+ * function of `mapWithState`.
+ * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD
+ * will be created
 * @param partitionedDataRDD The partitioned data RDD which is used to update the previous StateMaps
* in the `prevStateRDD` to create `this` RDD
- * @param trackingFunction The function that will be used to update state and return new data
+ * @param mappingFunction The function that will be used to update state and return new data
 * @param batchTime The time of the batch to which this RDD belongs. Used to update the state
 * @param timeoutThresholdTime The threshold time that determines which keys have timed out
*/
-private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
- private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
+private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+ private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]],
private var partitionedDataRDD: RDD[(K, V)],
- trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
+ mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
batchTime: Time,
timeoutThresholdTime: Option[Long]
- ) extends RDD[TrackStateRDDRecord[K, S, E]](
+ ) extends RDD[MapWithStateRDDRecord[K, S, E]](
partitionedDataRDD.sparkContext,
List(
- new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
+ new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD),
new OneToOneDependency(partitionedDataRDD))
) {
@@ -141,19 +142,19 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
}
override def compute(
- partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {
+ partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
- val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
+ val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
val prevStateRDDIterator = prevStateRDD.iterator(
stateRDDPartition.previousSessionRDDPartition, context)
val dataIterator = partitionedDataRDD.iterator(
stateRDDPartition.partitionedDataRDDPartition, context)
val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
- val newRecord = TrackStateRDDRecord.updateRecordWithData(
+ val newRecord = MapWithStateRDDRecord.updateRecordWithData(
prevRecord,
dataIterator,
- trackingFunction,
+ mappingFunction,
batchTime,
timeoutThresholdTime,
removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
@@ -163,7 +164,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
override protected def getPartitions: Array[Partition] = {
Array.tabulate(prevStateRDD.partitions.length) { i =>
- new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
+ new MapWithStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
}
override def clearDependencies(): Unit = {
@@ -177,52 +178,46 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
}
}
-private[streaming] object TrackStateRDD {
+private[streaming] object MapWithStateRDD {
def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
pairRDD: RDD[(K, S)],
partitioner: Partitioner,
- updateTime: Time): TrackStateRDD[K, V, S, E] = {
+ updateTime: Time): MapWithStateRDD[K, V, S, E] = {
- val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
+    val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
- Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+ Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
- new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+ new MapWithStateRDD[K, V, S, E](
+ stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
}
def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
rdd: RDD[(K, S, Long)],
partitioner: Partitioner,
- updateTime: Time): TrackStateRDD[K, V, S, E] = {
+ updateTime: Time): MapWithStateRDD[K, V, S, E] = {
val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
- val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
+ val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, (state, updateTime)) =>
stateMap.put(key, state, updateTime)
}
- Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+ Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
- new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
- }
-}
-
-private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
- parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) {
- override protected def getPartitions: Array[Partition] = parent.partitions
- override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
- parent.compute(partition, context).flatMap { _.emittedRecords }
+ new MapWithStateRDD[K, V, S, E](
+ stateRDD, emptyDataRDD, noOpFunc, updateTime, None)
}
}
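
Editor's note: to illustrate what createFromPairRDD above achieves, a conceptual sketch only. MapWithStateRDD is private[streaming], so this would compile only inside that package; `sc` is an assumed SparkContext.

{{{
// An initial (key, state) RDD becomes the first state RDD; the no-op
// mapping function means no mapped data is produced for the seeding step.
val initial: RDD[(String, Int)] = sc.parallelize(Seq(("a", 1), ("b", 2)))
val seeded = MapWithStateRDD.createFromPairRDD[String, Int, Int, String](
  initial, new HashPartitioner(4), Time(0L))
assert(seeded.partitions.length == 4)
}}}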