diff options
Diffstat (limited to 'streaming/src/main')
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/State.scala | 20 | ||||
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala | 160 | ||||
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala) | 20 | ||||
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala | 50 | ||||
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala) | 61 | ||||
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala | 41 | ||||
-rw-r--r-- | streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala (renamed from streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala) | 99 |
7 files changed, 230 insertions, 221 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 604e64fc61..b47bdda2c2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -23,14 +23,14 @@ import org.apache.spark.annotation.Experimental /** * :: Experimental :: - * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of - * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a - * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * Abstract class for getting and updating the state in mapping function used in the `mapWithState` + * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) + * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). * * Scala example of using `State`: * {{{ - * // A tracking function that maintains an integer state and return a String - * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = { + * // A mapping function that maintains an integer state and returns a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { * // Check if state exists * if (state.exists) { * val existingState = state.get // Get the existing state @@ -52,12 +52,12 @@ import org.apache.spark.annotation.Experimental * * Java example of using `State`: * {{{ - * // A tracking function that maintains an integer state and return a String - * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc = - * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() { + * // A mapping function that maintains an integer state and returns a String + * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction = + * new Function3<String, Optional<Integer>, State<Integer>, String>() { * * @Override - * public Optional<String> call(Optional<Integer> one, State<Integer> state) { + * public String call(String key, Optional<Integer> value, State<Integer> state) { * if (state.exists()) { * int existingState = state.get(); // Get the existing state * boolean shouldRemove = ...; // Decide whether to remove the state @@ -75,6 +75,8 @@ import org.apache.spark.annotation.Experimental * } * }; * }}} + * + * @tparam S Class of the state */ @Experimental sealed abstract class State[S] { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index bea5b9df20..9f6f95223f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming import com.google.common.base.Optional import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} -import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4} +import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4} import org.apache.spark.rdd.RDD import org.apache.spark.util.ClosureCleaner import org.apache.spark.{HashPartitioner, Partitioner} @@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Partitioner} /** * :: Experimental :: * Abstract class representing all the specifications of the DStream transformation - * `trackStateByKey` operation of a + * `mapWithState` operation of a * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). * Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or @@ -37,50 +37,63 @@ import org.apache.spark.{HashPartitioner, Partitioner} * * Example in Scala: * {{{ - * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { - * ... + * // A mapping function that maintains an integer state and return a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string * } * - * val spec = StateSpec.function(trackingFunction).numPartitions(10) + * val spec = StateSpec.function(mappingFunction).numPartitions(10) * - * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) + * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec) * }}} * * Example in Java: * {{{ - * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec = - * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction) - * .numPartition(10); + * // A mapping function that maintains an integer state and return a string + * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction = + * new Function3<String, Optional<Integer>, State<Integer>, String>() { + * @Override + * public Optional<String> call(Optional<Integer> value, State<Integer> state) { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * }; * - * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream = - * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec); + * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream = + * keyValueDStream.mapWithState(StateSpec.function(mappingFunc)); * }}} + * + * @tparam KeyType Class of the state key + * @tparam ValueType Class of the state value + * @tparam StateType Class of the state data + * @tparam MappedType Class of the mapped elements */ @Experimental -sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable { +sealed abstract class StateSpec[KeyType, ValueType, StateType, MappedType] extends Serializable { - /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ + /** Set the RDD containing the initial states that will be used by `mapWithState` */ def initialState(rdd: RDD[(KeyType, StateType)]): this.type - /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ + /** Set the RDD containing the initial states that will be used by `mapWithState` */ def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type /** - * Set the number of partitions by which the state RDDs generated by `trackStateByKey` + * Set the number of partitions by which the state RDDs generated by `mapWithState` * will be partitioned. Hash partitioning will be used. */ def numPartitions(numPartitions: Int): this.type /** - * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be + * Set the partitioner by which the state RDDs generated by `mapWithState` will be * be partitioned. */ def partitioner(partitioner: Partitioner): this.type /** * Set the duration after which the state of an idle key will be removed. A key and its state is - * considered idle if it has not received any data for at least the given duration. The state - * tracking function will be called one final time on the idle states that are going to be + * considered idle if it has not received any data for at least the given duration. The + * mapping function will be called one final time on the idle states that are going to be * removed; [[org.apache.spark.streaming.State State.isTimingOut()]] set * to `true` in that call. */ @@ -91,115 +104,124 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte /** * :: Experimental :: * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]] - * that is used for specifying the parameters of the DStream transformation `trackStateByKey` + * that is used for specifying the parameters of the DStream transformation `mapWithState` * that is used for specifying the parameters of the DStream transformation - * `trackStateByKey` operation of a + * `mapWithState` operation of a * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). * * Example in Scala: * {{{ - * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { - * ... + * // A mapping function that maintains an integer state and return a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string * } * - * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType]( - * StateSpec.function(trackingFunction).numPartitions(10)) + * val spec = StateSpec.function(mappingFunction).numPartitions(10) + * + * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec) * }}} * * Example in Java: * {{{ - * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec = - * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction) - * .numPartition(10); + * // A mapping function that maintains an integer state and return a string + * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction = + * new Function3<String, Optional<Integer>, State<Integer>, String>() { + * @Override + * public Optional<String> call(Optional<Integer> value, State<Integer> state) { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } + * }; * - * JavaTrackStateDStream<KeyType, ValueType, StateType, EmittedType> emittedRecordDStream = - * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec); - * }}} + * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream = + * keyValueDStream.mapWithState(StateSpec.function(mappingFunc)); + *}}} */ @Experimental object StateSpec { /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications - * of the `trackStateByKey` operation on a + * of the `mapWithState` operation on a * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. * - * @param trackingFunction The function applied on every data item to manage the associated state - * and generate the emitted data + * @param mappingFunction The function applied on every data item to manage the associated state + * and generate the mapped data * @tparam KeyType Class of the keys * @tparam ValueType Class of the values * @tparam StateType Class of the states data - * @tparam EmittedType Class of the emitted data + * @tparam MappedType Class of the mapped data */ - def function[KeyType, ValueType, StateType, EmittedType]( - trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType] - ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { - ClosureCleaner.clean(trackingFunction, checkSerializable = true) - new StateSpecImpl(trackingFunction) + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[MappedType] + ): StateSpec[KeyType, ValueType, StateType, MappedType] = { + ClosureCleaner.clean(mappingFunction, checkSerializable = true) + new StateSpecImpl(mappingFunction) } /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications - * of the `trackStateByKey` operation on a + * of the `mapWithState` operation on a * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. * - * @param trackingFunction The function applied on every data item to manage the associated state - * and generate the emitted data + * @param mappingFunction The function applied on every data item to manage the associated state + * and generate the mapped data * @tparam ValueType Class of the values * @tparam StateType Class of the states data - * @tparam EmittedType Class of the emitted data + * @tparam MappedType Class of the mapped data */ - def function[KeyType, ValueType, StateType, EmittedType]( - trackingFunction: (Option[ValueType], State[StateType]) => EmittedType - ): StateSpec[KeyType, ValueType, StateType, EmittedType] = { - ClosureCleaner.clean(trackingFunction, checkSerializable = true) + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: (KeyType, Option[ValueType], State[StateType]) => MappedType + ): StateSpec[KeyType, ValueType, StateType, MappedType] = { + ClosureCleaner.clean(mappingFunction, checkSerializable = true) val wrappedFunction = - (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => { - Some(trackingFunction(value, state)) + (time: Time, key: KeyType, value: Option[ValueType], state: State[StateType]) => { + Some(mappingFunction(key, value, state)) } new StateSpecImpl(wrappedFunction) } /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all - * the specifications of the `trackStateByKey` operation on a + * the specifications of the `mapWithState` operation on a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. * - * @param javaTrackingFunction The function applied on every data item to manage the associated - * state and generate the emitted data + * @param mappingFunction The function applied on every data item to manage the associated + * state and generate the mapped data * @tparam KeyType Class of the keys * @tparam ValueType Class of the values * @tparam StateType Class of the states data - * @tparam EmittedType Class of the emitted data + * @tparam MappedType Class of the mapped data */ - def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction: - JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]): - StateSpec[KeyType, ValueType, StateType, EmittedType] = { - val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { - val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s) + def function[KeyType, ValueType, StateType, MappedType](mappingFunction: + JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[MappedType]]): + StateSpec[KeyType, ValueType, StateType, MappedType] = { + val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { + val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s) Option(t.orNull) } - StateSpec.function(trackingFunc) + StateSpec.function(wrappedFunc) } /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications - * of the `trackStateByKey` operation on a + * of the `mapWithState` operation on a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. * - * @param javaTrackingFunction The function applied on every data item to manage the associated - * state and generate the emitted data + * @param mappingFunction The function applied on every data item to manage the associated + * state and generate the mapped data * @tparam ValueType Class of the values * @tparam StateType Class of the states data - * @tparam EmittedType Class of the emitted data + * @tparam MappedType Class of the mapped data */ - def function[KeyType, ValueType, StateType, EmittedType]( - javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]): - StateSpec[KeyType, ValueType, StateType, EmittedType] = { - val trackingFunc = (v: Option[ValueType], s: State[StateType]) => { - javaTrackingFunction.call(Optional.fromNullable(v.get), s) + def function[KeyType, ValueType, StateType, MappedType]( + mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]): + StateSpec[KeyType, ValueType, StateType, MappedType] = { + val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => { + mappingFunction.call(k, Optional.fromNullable(v.get), s) } - StateSpec.function(trackingFunc) + StateSpec.function(wrappedFunc) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala index f459930d06..16c0d6fff8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaMapWithStateDStream.scala @@ -19,23 +19,23 @@ package org.apache.spark.streaming.api.java import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.streaming.dstream.TrackStateDStream +import org.apache.spark.streaming.dstream.MapWithStateDStream /** * :: Experimental :: - * [[JavaDStream]] representing the stream of records emitted by the tracking function in the - * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the + * DStream representing the stream of data generated by `mapWithState` operation on a + * [[JavaPairDStream]]. Additionally, it also gives access to the * stream of state snapshots, that is, the state data of all keys after a batch has updated them. * - * @tparam KeyType Class of the state key - * @tparam ValueType Class of the state value - * @tparam StateType Class of the state - * @tparam EmittedType Class of the emitted records + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the state data + * @tparam MappedType Class of the mapped data */ @Experimental -class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType]( - dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType]) - extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) { +class JavaMapWithStateDStream[KeyType, ValueType, StateType, MappedType] private[streaming]( + dstream: MapWithStateDStream[KeyType, ValueType, StateType, MappedType]) + extends JavaDStream[MappedType](dstream)(JavaSparkContext.fakeClassTag) { def stateSnapshots(): JavaPairDStream[KeyType, StateType] = new JavaPairDStream(dstream.stateSnapshots())( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 70e32b383e..42ddd63f0f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -430,42 +430,36 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( /** * :: Experimental :: - * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream - * with a continuously updated per-key state. The user-provided state tracking function is - * applied on each keyed data item along with its corresponding state. The function can choose to - * update/remove the state and return a transformed data, which forms the - * [[JavaTrackStateDStream]]. + * Return a [[JavaMapWithStateDStream]] by applying a function to every key-value element of + * `this` stream, while maintaining some state data for each unique key. The mapping function + * and other specification (e.g. partitioners, timeouts, initial state data, etc.) of this + * transformation can be specified using [[StateSpec]] class. The state data is accessible in + * as a parameter of type [[State]] in the mapping function. * - * The specifications of this transformation is made through the - * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there - * are a number of optional parameters - initial state data, number of partitions, timeouts, etc. - * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details. - * - * Example of using `trackStateByKey`: + * Example of using `mapWithState`: * {{{ - * // A tracking function that maintains an integer state and return a String - * Function2<Optional<Integer>, State<Integer>, Optional<String>> trackStateFunc = - * new Function2<Optional<Integer>, State<Integer>, Optional<String>>() { - * - * @Override - * public Optional<String> call(Optional<Integer> one, State<Integer> state) { - * // Check if state exists, accordingly update/remove state and return transformed data - * } + * // A mapping function that maintains an integer state and return a string + * Function3<String, Optional<Integer>, State<Integer>, String> mappingFunction = + * new Function3<String, Optional<Integer>, State<Integer>, String>() { + * @Override + * public Optional<String> call(Optional<Integer> value, State<Integer> state) { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string + * } * }; * - * JavaTrackStateDStream<Integer, Integer, Integer, String> trackStateDStream = - * keyValueDStream.<Integer, String>trackStateByKey( - * StateSpec.function(trackStateFunc).numPartitions(10)); - * }}} + * JavaMapWithStateDStream<String, Integer, Integer, String> mapWithStateDStream = + * keyValueDStream.mapWithState(StateSpec.function(mappingFunc)); + *}}} * * @param spec Specification of this transformation - * @tparam StateType Class type of the state - * @tparam EmittedType Class type of the tranformed data return by the tracking function + * @tparam StateType Class type of the state data + * @tparam MappedType Class type of the mapped data */ @Experimental - def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]): - JavaTrackStateDStream[K, V, StateType, EmittedType] = { - new JavaTrackStateDStream(dstream.trackStateByKey(spec)( + def mapWithState[StateType, MappedType](spec: StateSpec[K, V, StateType, MappedType]): + JavaMapWithStateDStream[K, V, StateType, MappedType] = { + new JavaMapWithStateDStream(dstream.mapWithState(spec)( JavaSparkContext.fakeClassTag, JavaSparkContext.fakeClassTag)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala index ea6213420e..706465d4e2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala @@ -24,53 +24,52 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ -import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} -import org.apache.spark.streaming.dstream.InternalTrackStateDStream._ +import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord} +import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._ /** * :: Experimental :: - * DStream representing the stream of records emitted by the tracking function in the - * `trackStateByKey` operation on a + * DStream representing the stream of data generated by `mapWithState` operation on a * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. * Additionally, it also gives access to the stream of state snapshots, that is, the state data of * all keys after a batch has updated them. * - * @tparam KeyType Class of the state key - * @tparam ValueType Class of the state value + * @tparam KeyType Class of the key + * @tparam ValueType Class of the value * @tparam StateType Class of the state data - * @tparam EmittedType Class of the emitted records + * @tparam MappedType Class of the mapped data */ @Experimental -sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag]( - ssc: StreamingContext) extends DStream[EmittedType](ssc) { +sealed abstract class MapWithStateDStream[KeyType, ValueType, StateType, MappedType: ClassTag]( + ssc: StreamingContext) extends DStream[MappedType](ssc) { /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */ def stateSnapshots(): DStream[(KeyType, StateType)] } -/** Internal implementation of the [[TrackStateDStream]] */ -private[streaming] class TrackStateDStreamImpl[ - KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag]( +/** Internal implementation of the [[MapWithStateDStream]] */ +private[streaming] class MapWithStateDStreamImpl[ + KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag]( dataStream: DStream[(KeyType, ValueType)], - spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType]) - extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) { + spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType]) + extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) { private val internalStream = - new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec) + new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec) override def slideDuration: Duration = internalStream.slideDuration override def dependencies: List[DStream[_]] = List(internalStream) - override def compute(validTime: Time): Option[RDD[EmittedType]] = { - internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } } + override def compute(validTime: Time): Option[RDD[MappedType]] = { + internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } } } /** * Forward the checkpoint interval to the internal DStream that computes the state maps. This * to make sure that this DStream does not get checkpointed, only the internal stream. */ - override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = { + override def checkpoint(checkpointInterval: Duration): DStream[MappedType] = { internalStream.checkpoint(checkpointInterval) this } @@ -87,32 +86,32 @@ private[streaming] class TrackStateDStreamImpl[ def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass - def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass + def mappedClass: Class[_] = implicitly[ClassTag[MappedType]].runtimeClass } /** * A DStream that allows per-key state to be maintains, and arbitrary records to be generated - * based on updates to the state. This is the main DStream that implements the `trackStateByKey` + * based on updates to the state. This is the main DStream that implements the `mapWithState` * operation on DStreams. * * @param parent Parent (key, value) stream that is the source - * @param spec Specifications of the trackStateByKey operation + * @param spec Specifications of the mapWithState operation * @tparam K Key type * @tparam V Value type * @tparam S Type of the state maintained - * @tparam E Type of the emitted data + * @tparam E Type of the mapped data */ private[streaming] -class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( +class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E]) - extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) { + extends DStream[MapWithStateRDDRecord[K, S, E]](parent.context) { persist(StorageLevel.MEMORY_ONLY) private val partitioner = spec.getPartitioner().getOrElse( new HashPartitioner(ssc.sc.defaultParallelism)) - private val trackingFunction = spec.getFunction() + private val mappingFunction = spec.getFunction() override def slideDuration: Duration = parent.slideDuration @@ -130,7 +129,7 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT } /** Method that generates a RDD for the given time */ - override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = { + override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = { // Get the previous state or create a new empty state RDD val prevStateRDD = getOrCompute(validTime - slideDuration) match { case Some(rdd) => @@ -138,13 +137,13 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT // If the RDD is not partitioned the right way, let us repartition it using the // partition index as the key. This is to ensure that state RDD is always partitioned // before creating another state RDD using it - TrackStateRDD.createFromRDD[K, V, S, E]( + MapWithStateRDD.createFromRDD[K, V, S, E]( rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime) } else { rdd } case None => - TrackStateRDD.createFromPairRDD[K, V, S, E]( + MapWithStateRDD.createFromPairRDD[K, V, S, E]( spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)), partitioner, validTime @@ -161,11 +160,11 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT val timeoutThresholdTime = spec.getTimeoutInterval().map { interval => (validTime - interval).milliseconds } - Some(new TrackStateRDD( - prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)) + Some(new MapWithStateRDD( + prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime)) } } -private[streaming] object InternalTrackStateDStream { +private[streaming] object InternalMapWithStateDStream { private val DEFAULT_CHECKPOINT_DURATION_MULTIPLIER = 10 } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 2762309134..a64a1fe93f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -352,39 +352,36 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)]) /** * :: Experimental :: - * Return a new DStream of data generated by combining the key-value data in `this` stream - * with a continuously updated per-key state. The user-provided state tracking function is - * applied on each keyed data item along with its corresponding state. The function can choose to - * update/remove the state and return a transformed data, which forms the - * [[org.apache.spark.streaming.dstream.TrackStateDStream]]. + * Return a [[MapWithStateDStream]] by applying a function to every key-value element of + * `this` stream, while maintaining some state data for each unique key. The mapping function + * and other specification (e.g. partitioners, timeouts, initial state data, etc.) of this + * transformation can be specified using [[StateSpec]] class. The state data is accessible in + * as a parameter of type [[State]] in the mapping function. * - * The specifications of this transformation is made through the - * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there - * are a number of optional parameters - initial state data, number of partitions, timeouts, etc. - * See the [[org.apache.spark.streaming.StateSpec StateSpec spec docs]] for more details. - * - * Example of using `trackStateByKey`: + * Example of using `mapWithState`: * {{{ - * def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = { - * // Check if state exists, accordingly update/remove state and return transformed data + * // A mapping function that maintains an integer state and return a String + * def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = { + * // Use state.exists(), state.get(), state.update() and state.remove() + * // to manage state, and return the necessary string * } * - * val spec = StateSpec.function(trackingFunction).numPartitions(10) + * val spec = StateSpec.function(mappingFunction).numPartitions(10) * - * val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec) + * val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec) * }}} * * @param spec Specification of this transformation - * @tparam StateType Class type of the state - * @tparam EmittedType Class type of the tranformed data return by the tracking function + * @tparam StateType Class type of the state data + * @tparam MappedType Class type of the mapped data */ @Experimental - def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag]( - spec: StateSpec[K, V, StateType, EmittedType] - ): TrackStateDStream[K, V, StateType, EmittedType] = { - new TrackStateDStreamImpl[K, V, StateType, EmittedType]( + def mapWithState[StateType: ClassTag, MappedType: ClassTag]( + spec: StateSpec[K, V, StateType, MappedType] + ): MapWithStateDStream[K, V, StateType, MappedType] = { + new MapWithStateDStreamImpl[K, V, StateType, MappedType]( self, - spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]] + spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]] ) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index 30aafcf146..ed95171f73 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -29,60 +29,60 @@ import org.apache.spark.util.Utils import org.apache.spark._ /** - * Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a - * sequence of records returned by the tracking function of `trackStateByKey`. + * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a + * sequence of records returned by the mapping function of `mapWithState`. */ -private[streaming] case class TrackStateRDDRecord[K, S, E]( - var stateMap: StateMap[K, S], var emittedRecords: Seq[E]) +private[streaming] case class MapWithStateRDDRecord[K, S, E]( + var stateMap: StateMap[K, S], var mappedData: Seq[E]) -private[streaming] object TrackStateRDDRecord { +private[streaming] object MapWithStateRDDRecord { def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( - prevRecord: Option[TrackStateRDDRecord[K, S, E]], + prevRecord: Option[MapWithStateRDDRecord[K, S, E]], dataIterator: Iterator[(K, V)], - updateFunction: (Time, K, Option[V], State[S]) => Option[E], + mappingFunction: (Time, K, Option[V], State[S]) => Option[E], batchTime: Time, timeoutThresholdTime: Option[Long], removeTimedoutData: Boolean - ): TrackStateRDDRecord[K, S, E] = { + ): MapWithStateRDDRecord[K, S, E] = { // Create a new state map by cloning the previous one (if it exists) or by creating an empty one val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() } - val emittedRecords = new ArrayBuffer[E] + val mappedData = new ArrayBuffer[E] val wrappedState = new StateImpl[S]() - // Call the tracking function on each record in the data iterator, and accordingly - // update the states touched, and collect the data returned by the tracking function + // Call the mapping function on each record in the data iterator, and accordingly + // update the states touched, and collect the data returned by the mapping function dataIterator.foreach { case (key, value) => wrappedState.wrap(newStateMap.get(key)) - val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState) + val returned = mappingFunction(batchTime, key, Some(value), wrappedState) if (wrappedState.isRemoved) { newStateMap.remove(key) } else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) { newStateMap.put(key, wrappedState.get(), batchTime.milliseconds) } - emittedRecords ++= emittedRecord + mappedData ++= returned } - // Get the timed out state records, call the tracking function on each and collect the + // Get the timed out state records, call the mapping function on each and collect the // data returned if (removeTimedoutData && timeoutThresholdTime.isDefined) { newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) => wrappedState.wrapTiminoutState(state) - val emittedRecord = updateFunction(batchTime, key, None, wrappedState) - emittedRecords ++= emittedRecord + val returned = mappingFunction(batchTime, key, None, wrappedState) + mappedData ++= returned newStateMap.remove(key) } } - TrackStateRDDRecord(newStateMap, emittedRecords) + MapWithStateRDDRecord(newStateMap, mappedData) } } /** - * Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state + * Partition of the [[MapWithStateRDD]], which depends on corresponding partitions of prev state * RDD, and a partitioned keyed-data RDD */ -private[streaming] class TrackStateRDDPartition( +private[streaming] class MapWithStateRDDPartition( idx: Int, @transient private var prevStateRDD: RDD[_], @transient private var partitionedDataRDD: RDD[_]) extends Partition { @@ -104,27 +104,28 @@ private[streaming] class TrackStateRDDPartition( /** - * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records. - * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a - * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking - * function of `trackStateByKey`. - * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created + * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data. + * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a + * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping + * function of `mapWithState`. + * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD + * will be created * @param partitionedDataRDD The partitioned data RDD which is used update the previous StateMaps * in the `prevStateRDD` to create `this` RDD - * @param trackingFunction The function that will be used to update state and return new data + * @param mappingFunction The function that will be used to update state and return new data * @param batchTime The time of the batch to which this RDD belongs to. Use to update * @param timeoutThresholdTime The time to indicate which keys are timeout */ -private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( - private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]], +private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( + private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]], private var partitionedDataRDD: RDD[(K, V)], - trackingFunction: (Time, K, Option[V], State[S]) => Option[E], + mappingFunction: (Time, K, Option[V], State[S]) => Option[E], batchTime: Time, timeoutThresholdTime: Option[Long] - ) extends RDD[TrackStateRDDRecord[K, S, E]]( + ) extends RDD[MapWithStateRDDRecord[K, S, E]]( partitionedDataRDD.sparkContext, List( - new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD), + new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD), new OneToOneDependency(partitionedDataRDD)) ) { @@ -141,19 +142,19 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: } override def compute( - partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = { + partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = { - val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition] + val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition] val prevStateRDDIterator = prevStateRDD.iterator( stateRDDPartition.previousSessionRDDPartition, context) val dataIterator = partitionedDataRDD.iterator( stateRDDPartition.partitionedDataRDDPartition, context) val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None - val newRecord = TrackStateRDDRecord.updateRecordWithData( + val newRecord = MapWithStateRDDRecord.updateRecordWithData( prevRecord, dataIterator, - trackingFunction, + mappingFunction, batchTime, timeoutThresholdTime, removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled @@ -163,7 +164,7 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: override protected def getPartitions: Array[Partition] = { Array.tabulate(prevStateRDD.partitions.length) { i => - new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} + new MapWithStateRDDPartition(i, prevStateRDD, partitionedDataRDD)} } override def clearDependencies(): Unit = { @@ -177,52 +178,46 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: } } -private[streaming] object TrackStateRDD { +private[streaming] object MapWithStateRDD { def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( pairRDD: RDD[(K, S)], partitioner: Partitioner, - updateTime: Time): TrackStateRDD[K, V, S, E] = { + updateTime: Time): MapWithStateRDD[K, V, S, E] = { - val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => + val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) } - Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) + Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E])) }, preservesPartitioning = true) val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None - new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) + new MapWithStateRDD[K, V, S, E]( + stateRDD, emptyDataRDD, noOpFunc, updateTime, None) } def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag]( rdd: RDD[(K, S, Long)], partitioner: Partitioner, - updateTime: Time): TrackStateRDD[K, V, S, E] = { + updateTime: Time): MapWithStateRDD[K, V, S, E] = { val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) } - val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator => + val stateRDD = pairRDD.partitionBy(partitioner).mapPartitions({ iterator => val stateMap = StateMap.create[K, S](SparkEnv.get.conf) iterator.foreach { case (key, (state, updateTime)) => stateMap.put(key, state, updateTime) } - Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E])) + Iterator(MapWithStateRDDRecord(stateMap, Seq.empty[E])) }, preservesPartitioning = true) val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner) val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None - new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None) - } -} - -private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( - parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) { - override protected def getPartitions: Array[Partition] = parent.partitions - override def compute(partition: Partition, context: TaskContext): Iterator[T] = { - parent.compute(partition, context).flatMap { _.emittedRecords } + new MapWithStateRDD[K, V, S, E]( + stateRDD, emptyDataRDD, noOpFunc, updateTime, None) } } |