Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/State.scala                       193
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala                   212
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala  46
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala   142
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala           188
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala               337
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala               314
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala        494
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala      193
9 files changed, 2115 insertions, 4 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
new file mode 100644
index 0000000000..7dd1b72f80
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.language.implicitConversions
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of
+ * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ *
+ * Scala example of using `State`:
+ * {{{
+ * // A tracking function that maintains an integer state and returns a String
+ * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = {
+ * // Check if state exists
+ * if (state.exists) {
+ * val existingState = state.get // Get the existing state
+ * val shouldRemove = ... // Decide whether to remove the state
+ * if (shouldRemove) {
+ * state.remove() // Remove the state
+ * } else {
+ * val newState = ...
+ * state.update(newState) // Set the new state
+ * }
+ * } else {
+ * val initialState = ...
+ * state.update(initialState) // Set the initial state
+ * }
+ * ... // return something
+ * }
+ *
+ * }}}
+ *
+ * Java example:
+ * {{{
+ * TODO(@zsxwing)
+ * }}}
+ */
+@Experimental
+sealed abstract class State[S] {
+
+ /** Whether the state already exists */
+ def exists(): Boolean
+
+ /**
+ * Get the state if it exists, otherwise it will throw `java.util.NoSuchElementException`.
+ * Check with `exists()` whether the state exists or not before calling `get()`.
+ *
+ * @throws java.util.NoSuchElementException If the state does not exist.
+ */
+ def get(): S
+
+ /**
+ * Update the state with a new value.
+ *
+ * The state cannot be updated if it has already been removed (that is, `remove()` has already
+ * been called) or if it is going to be removed due to timeout (that is, `isTimingOut()` is `true`).
+ *
+ * @throws java.lang.IllegalArgumentException If the state has already been removed, or is
+ * going to be removed
+ */
+ def update(newState: S): Unit
+
+ /**
+ * Remove the state if it exists.
+ *
+ * The state cannot be removed if it has already been removed (that is, `remove()` has already
+ * been called) or if it is going to be removed due to timeout (that is, `isTimingOut()` is `true`).
+ */
+ def remove(): Unit
+
+ /**
+ * Whether the state is timing out and going to be removed by the system after the current batch.
+ * This timeout can occur if a timeout duration has been specified in the
+ * [[org.apache.spark.streaming.StateSpec StateSpec]] and the key has not received any new data
+ * for that timeout duration.
+ */
+ def isTimingOut(): Boolean
+
+ /**
+ * Get the state as an [[scala.Option]]. It will be `Some(state)` if it exists, otherwise `None`.
+ */
+ @inline final def getOption(): Option[S] = if (exists) Some(get()) else None
+
+ @inline final override def toString(): String = {
+ getOption.map { _.toString }.getOrElse("<state not set>")
+ }
+}
+
+/** Internal implementation of the [[State]] interface */
+private[streaming] class StateImpl[S] extends State[S] {
+
+ private var state: S = null.asInstanceOf[S]
+ private var defined: Boolean = false
+ private var timingOut: Boolean = false
+ private var updated: Boolean = false
+ private var removed: Boolean = false
+
+ // ========= Public API =========
+ override def exists(): Boolean = {
+ defined
+ }
+
+ override def get(): S = {
+ if (defined) {
+ state
+ } else {
+ throw new NoSuchElementException("State is not set")
+ }
+ }
+
+ override def update(newState: S): Unit = {
+ require(!removed, "Cannot update the state after it has been removed")
+ require(!timingOut, "Cannot update the state that is timing out")
+ state = newState
+ defined = true
+ updated = true
+ }
+
+ override def isTimingOut(): Boolean = {
+ timingOut
+ }
+
+ override def remove(): Unit = {
+ require(!timingOut, "Cannot remove the state that is timing out")
+ require(!removed, "Cannot remove the state that has already been removed")
+ defined = false
+ updated = false
+ removed = true
+ }
+
+ // ========= Internal API =========
+
+ /** Whether the state has been marked for removal */
+ def isRemoved(): Boolean = {
+ removed
+ }
+
+ /** Whether the state has been updated */
+ def isUpdated(): Boolean = {
+ updated
+ }
+
+ /**
+ * Update the internal data and flags in `this` to the given state option.
+ * This method allows `this` object to be reused across many state records.
+ */
+ def wrap(optionalState: Option[S]): Unit = {
+ optionalState match {
+ case Some(newState) =>
+ this.state = newState
+ defined = true
+
+ case None =>
+ this.state = null.asInstanceOf[S]
+ defined = false
+ }
+ timingOut = false
+ removed = false
+ updated = false
+ }
+
+ /**
+ * Update the internal data and flags in `this` to the given state that is going to be timed out.
+ * This method allows `this` object to be reused across many state records.
+ */
+ def wrapTimingOutState(newState: S): Unit = {
+ this.state = newState
+ defined = true
+ timingOut = true
+ removed = false
+ updated = false
+ }
+}
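A hedged sketch (not part of this patch) of a complete tracking function written against the public State API defined above. The key type, the running-count logic, and the -1 sentinel are assumptions for illustration; the function has the four-argument shape accepted by StateSpec.function further down in this patch.
{{{
import org.apache.spark.streaming.{State, Time}

def trackStateFunc(
    time: Time, key: String, data: Option[Int], state: State[Int]): Option[String] = {
  if (state.isTimingOut()) {
    // Final call for an idle key: the state will be removed; update()/remove() must not be called
    Some(s"$key timed out with count ${state.getOption.getOrElse(0)}")
  } else if (data.exists(_ == -1)) {
    // A sentinel value of -1 (purely an assumption for this sketch) clears the state
    if (state.exists()) state.remove()
    None
  } else {
    val newSum = state.getOption.getOrElse(0) + data.getOrElse(0)
    state.update(newSum)
    Some(s"$key -> $newSum")
  }
}
}}}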
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
new file mode 100644
index 0000000000..c9fe35e74c
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.ClosureCleaner
+import org.apache.spark.{HashPartitioner, Partitioner}
+
+
+/**
+ * :: Experimental ::
+ * Abstract class representing all the specifications of the DStream transformation
+ * `trackStateByKey` operation of a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * Use the [[org.apache.spark.streaming.StateSpec StateSpec.function()]] factory methods to
+ * create instances of this class.
+ *
+ * Example in Scala:
+ * {{{
+ * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
+ * ...
+ * }
+ *
+ * val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *
+ * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ * }}}
+ *
+ * Example in Java:
+ * {{{
+ * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
+ * .numPartitions(10);
+ *
+ * JavaDStream<EmittedDataType> emittedRecordDStream =
+ * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ * }}}
+ */
+@Experimental
+sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable {
+
+ /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+ def initialState(rdd: RDD[(KeyType, StateType)]): this.type
+
+ /** Set the RDD containing the initial states that will be used by `trackStateByKey` */
+ def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type
+
+ /**
+ * Set the number of partitions by which the state RDDs generated by `trackStateByKey`
+ * will be partitioned. Hash partitioning will be used.
+ */
+ def numPartitions(numPartitions: Int): this.type
+
+ /**
+ * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be
+ * partitioned.
+ */
+ def partitioner(partitioner: Partitioner): this.type
+
+ /**
+ * Set the duration after which the state of an idle key will be removed. A key and its state are
+ * considered idle if the key has not received any data for at least the given duration. The
+ * state tracking function will be called one final time on each idle state that is going to be
+ * removed, with [[org.apache.spark.streaming.State State.isTimingOut()]] set
+ * to `true` in that call.
+ */
+ def timeout(idleDuration: Duration): this.type
+}
+
+
+/**
+ * :: Experimental ::
+ * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]]
+ * that is used for specifying the parameters of the DStream transformation
+ * `trackStateByKey` operation of a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ *
+ * Example in Scala:
+ * {{{
+ * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = {
+ * ...
+ * }
+ *
+ * val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *
+ * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec)
+ * }}}
+ *
+ * Example in Java:
+ * {{{
+ * StateSpec<KeyType, ValueType, StateType, EmittedDataType> spec =
+ * StateSpec.<KeyType, ValueType, StateType, EmittedDataType>function(trackingFunction)
+ * .numPartitions(10);
+ *
+ * JavaDStream<EmittedDataType> emittedRecordDStream =
+ * javaPairDStream.<StateType, EmittedDataType>trackStateByKey(spec);
+ * }}}
+ */
+@Experimental
+object StateSpec {
+ /**
+ * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
+ * of the `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * @param trackingFunction The function applied on every data item to manage the associated state
+ * and generate the emitted data
+ * @tparam KeyType Class of the keys
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the states data
+ * @tparam EmittedType Class of the emitted data
+ */
+ def function[KeyType, ValueType, StateType, EmittedType](
+ trackingFunction: (Time, KeyType, Option[ValueType], State[StateType]) => Option[EmittedType]
+ ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+ ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+ new StateSpecImpl(trackingFunction)
+ }
+
+ /**
+ * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications
+ * of the `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a
+ * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
+ * @param trackingFunction The function applied on every data item to manage the associated state
+ * and generate the emitted data
+ * @tparam KeyType Class of the keys
+ * @tparam ValueType Class of the values
+ * @tparam StateType Class of the states data
+ * @tparam EmittedType Class of the emitted data
+ */
+ def function[KeyType, ValueType, StateType, EmittedType](
+ trackingFunction: (Option[ValueType], State[StateType]) => EmittedType
+ ): StateSpec[KeyType, ValueType, StateType, EmittedType] = {
+ ClosureCleaner.clean(trackingFunction, checkSerializable = true)
+ val wrappedFunction =
+ (time: Time, key: Any, value: Option[ValueType], state: State[StateType]) => {
+ Some(trackingFunction(value, state))
+ }
+ new StateSpecImpl(wrappedFunction)
+ }
+}
+
+
+/** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. */
+private[streaming]
+case class StateSpecImpl[K, V, S, T](
+ function: (Time, K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] {
+
+ require(function != null)
+
+ @volatile private var partitioner: Partitioner = null
+ @volatile private var initialStateRDD: RDD[(K, S)] = null
+ @volatile private var timeoutInterval: Duration = null
+
+ override def initialState(rdd: RDD[(K, S)]): this.type = {
+ this.initialStateRDD = rdd
+ this
+ }
+
+ override def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = {
+ this.initialStateRDD = javaPairRDD.rdd
+ this
+ }
+
+
+ override def numPartitions(numPartitions: Int): this.type = {
+ this.partitioner(new HashPartitioner(numPartitions))
+ this
+ }
+
+ override def partitioner(partitioner: Partitioner): this.type = {
+ this.partitioner = partitioner
+ this
+ }
+
+ override def timeout(interval: Duration): this.type = {
+ this.timeoutInterval = interval
+ this
+ }
+
+ // ================= Private Methods =================
+
+ private[streaming] def getFunction(): (Time, K, Option[V], State[S]) => Option[T] = function
+
+ private[streaming] def getInitialStateRDD(): Option[RDD[(K, S)]] = Option(initialStateRDD)
+
+ private[streaming] def getPartitioner(): Option[Partitioner] = Option(partitioner)
+
+ private[streaming] def getTimeoutInterval(): Option[Duration] = Option(timeoutInterval)
+}
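A hedged usage sketch (not part of this patch) combining the builder methods defined above. It assumes an active StreamingContext `ssc` is in scope; the seed states, partition count, and 30-second idle timeout are likewise assumptions.
{{{
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Seconds, State, StateSpec, Time}

// Running count per key, emitted as (key, count); uses the four-argument tracking function.
def runningCount(
    time: Time, key: String, value: Option[Int], state: State[Int]): Option[(String, Int)] = {
  if (state.isTimingOut()) {
    None                                        // idle key, its state is being removed
  } else {
    val sum = state.getOption.getOrElse(0) + value.getOrElse(0)
    state.update(sum)
    Some((key, sum))
  }
}

// Seed states; in a real job these would come from storage (values here are made up).
val initialRDD: RDD[(String, Int)] = ssc.sparkContext.parallelize(Seq(("a", 5), ("b", 10)))

val spec = StateSpec.function(runningCount _)
  .initialState(initialRDD)                     // start from previously known counts
  .numPartitions(10)                            // hash partition the state RDDs
  .timeout(Seconds(30))                         // drop state for keys idle longer than 30 seconds
}}}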
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 71bec96d46..fb691eed27 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -24,19 +24,19 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.{JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
-import org.apache.spark.{HashPartitioner, Partitioner}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.{Duration, Time}
import org.apache.spark.streaming.StreamingContext.rddToFileName
+import org.apache.spark.streaming._
import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf}
+import org.apache.spark.{HashPartitioner, Partitioner}
/**
* Extra functions available on DStream of (key, value) pairs through an implicit conversion.
*/
class PairDStreamFunctions[K, V](self: DStream[(K, V)])
(implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K])
- extends Serializable
-{
+ extends Serializable {
private[streaming] def ssc = self.ssc
private[streaming] def sparkContext = self.context.sparkContext
@@ -351,6 +351,44 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
}
/**
+ * :: Experimental ::
+ * Return a new DStream of data generated by combining the key-value data in `this` stream
+ * with a continuously updated per-key state. The user-provided state tracking function is
+ * applied on each keyed data item along with its corresponding state. The function can choose to
+ * update/remove the state and return transformed data, which forms the
+ * [[org.apache.spark.streaming.dstream.TrackStateDStream]].
+ *
+ * The specifications of this transformation are made through the
+ * [[org.apache.spark.streaming.StateSpec StateSpec]] class. Besides the tracking function, there
+ * are a number of optional parameters - initial state data, number of partitions, timeouts, etc.
+ * See the [[org.apache.spark.streaming.StateSpec StateSpec]] docs for more details.
+ *
+ * Example of using `trackStateByKey`:
+ * {{{
+ * def trackingFunction(data: Option[Int], wrappedState: State[Int]): String = {
+ * // Check if state exists, accordingly update/remove state and return transformed data
+ * }
+ *
+ * val spec = StateSpec.function(trackingFunction).numPartitions(10)
+ *
+ * val trackStateDStream = keyValueDStream.trackStateByKey[Int, String](spec)
+ * }}}
+ *
+ * @param spec Specification of this transformation
+ * @tparam StateType Class type of the state
+ * @tparam EmittedType Class type of the transformed data returned by the tracking function
+ */
+ @Experimental
+ def trackStateByKey[StateType: ClassTag, EmittedType: ClassTag](
+ spec: StateSpec[K, V, StateType, EmittedType]
+ ): TrackStateDStream[K, V, StateType, EmittedType] = {
+ new TrackStateDStreamImpl[K, V, StateType, EmittedType](
+ self,
+ spec.asInstanceOf[StateSpecImpl[K, V, StateType, EmittedType]]
+ )
+ }
+
+ /**
* Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
* Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
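A hedged end-to-end sketch (not part of this patch) of wiring `trackStateByKey` into a streaming word count. The socket source, host/port, and checkpoint directory are assumptions; explicit type parameters are given on StateSpec.function so the key type does not have to be inferred.
{{{
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, State, StateSpec, StreamingContext}

val sparkConf = new SparkConf().setMaster("local[2]").setAppName("TrackStateWordCount")
val ssc = new StreamingContext(sparkConf, Seconds(1))
ssc.checkpoint("checkpoint")     // required: the internal track-state DStream must checkpoint

val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" ")).map((_, 1))

// Running count per word, using the (Option[V], State[S]) => E variant of the tracking function.
val spec = StateSpec.function[String, Int, Int, Int] { (one: Option[Int], state: State[Int]) =>
  val newCount = state.getOption.getOrElse(0) + one.getOrElse(0)
  state.update(newCount)
  newCount
}

val counts = words.trackStateByKey(spec)
counts.print()                   // records emitted by the tracking function each batch
counts.stateSnapshots().print()  // snapshot of (word, count) state after each batch

ssc.start()
ssc.awaitTermination()
}}}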
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
new file mode 100644
index 0000000000..58d89c93bc
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.dstream
+
+import scala.reflect.ClassTag
+
+import org.apache.spark._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}
+
+/**
+ * :: Experimental ::
+ * DStream representing the stream of records emitted by the tracking function in the
+ * `trackStateByKey` operation on a
+ * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]].
+ * Additionally, it gives access to the stream of state snapshots, that is, the state data of
+ * all keys after a batch has updated them.
+ *
+ * @tparam KeyType Class of the state key
+ * @tparam ValueType Class of the values in the data stream
+ * @tparam StateType Class of the state data
+ * @tparam EmittedType Class of the emitted records
+ */
+@Experimental
+sealed abstract class TrackStateDStream[KeyType, ValueType, StateType, EmittedType: ClassTag](
+ ssc: StreamingContext) extends DStream[EmittedType](ssc) {
+
+ /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
+ def stateSnapshots(): DStream[(KeyType, StateType)]
+}
+
+/** Internal implementation of the [[TrackStateDStream]] */
+private[streaming] class TrackStateDStreamImpl[
+ KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, EmittedType: ClassTag](
+ dataStream: DStream[(KeyType, ValueType)],
+ spec: StateSpecImpl[KeyType, ValueType, StateType, EmittedType])
+ extends TrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream.context) {
+
+ private val internalStream =
+ new InternalTrackStateDStream[KeyType, ValueType, StateType, EmittedType](dataStream, spec)
+
+ override def slideDuration: Duration = internalStream.slideDuration
+
+ override def dependencies: List[DStream[_]] = List(internalStream)
+
+ override def compute(validTime: Time): Option[RDD[EmittedType]] = {
+ internalStream.getOrCompute(validTime).map { _.flatMap[EmittedType] { _.emittedRecords } }
+ }
+
+ /**
+ * Forward the checkpoint interval to the internal DStream that computes the state maps. This is
+ * to make sure that this DStream does not get checkpointed; only the internal stream is.
+ */
+ override def checkpoint(checkpointInterval: Duration): DStream[EmittedType] = {
+ internalStream.checkpoint(checkpointInterval)
+ this
+ }
+
+ /** Return a pair DStream where each RDD is the snapshot of the state of all the keys. */
+ def stateSnapshots(): DStream[(KeyType, StateType)] = {
+ internalStream.flatMap {
+ _.stateMap.getAll().map { case (k, s, _) => (k, s) }.toTraversable }
+ }
+
+ def keyClass: Class[_] = implicitly[ClassTag[KeyType]].runtimeClass
+
+ def valueClass: Class[_] = implicitly[ClassTag[ValueType]].runtimeClass
+
+ def stateClass: Class[_] = implicitly[ClassTag[StateType]].runtimeClass
+
+ def emittedClass: Class[_] = implicitly[ClassTag[EmittedType]].runtimeClass
+}
+
+/**
+ * A DStream that allows per-key state to be maintained, and arbitrary records to be generated
+ * based on updates to the state. This is the main DStream that implements the `trackStateByKey`
+ * operation on DStreams.
+ *
+ * @param parent Parent (key, value) stream that is the source
+ * @param spec Specifications of the trackStateByKey operation
+ * @tparam K Key type
+ * @tparam V Value type
+ * @tparam S Type of the state maintained
+ * @tparam E Type of the emitted data
+ */
+private[streaming]
+class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+ parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E])
+ extends DStream[TrackStateRDDRecord[K, S, E]](parent.context) {
+
+ persist(StorageLevel.MEMORY_ONLY)
+
+ private val partitioner = spec.getPartitioner().getOrElse(
+ new HashPartitioner(ssc.sc.defaultParallelism))
+
+ private val trackingFunction = spec.getFunction()
+
+ override def slideDuration: Duration = parent.slideDuration
+
+ override def dependencies: List[DStream[_]] = List(parent)
+
+ /** Enable automatic checkpointing */
+ override val mustCheckpoint = true
+
+ /** Method that generates an RDD for the given time */
+ override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
+ // Get the previous state or create a new empty state RDD
+ val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
+ TrackStateRDD.createFromPairRDD[K, V, S, E](
+ spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
+ partitioner, validTime
+ )
+ }
+
+ // Compute the new state RDD with previous state RDD and partitioned data RDD
+ parent.getOrCompute(validTime).map { dataRDD =>
+ val partitionedDataRDD = dataRDD.partitionBy(partitioner)
+ val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
+ (validTime - interval).milliseconds
+ }
+ new TrackStateRDD(
+ prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)
+ }
+ }
+}
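An illustrative calculation (values are made up) of the timeout threshold computed in compute() above: for a batch at `validTime` and an idle duration from StateSpec.timeout(), any state whose last update time is strictly older than the threshold is eligible for removal on batches that do a full scan.
{{{
import org.apache.spark.streaming.{Milliseconds, Time}

val validTime = Time(100000)                 // batch time in ms
val idleDuration = Milliseconds(30000)       // value passed to StateSpec.timeout(...)
val timeoutThresholdTime = (validTime - idleDuration).milliseconds  // = 70000

// On batches where the TrackStateRDD is checkpointed (doFullScan), every state whose last
// update time is strictly older than timeoutThresholdTime is timed out (see StateMap.getByTime).
}}}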
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
new file mode 100644
index 0000000000..ed7cea26d0
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.rdd
+
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import org.apache.spark.rdd.{MapPartitionsRDD, RDD}
+import org.apache.spark.streaming.{Time, StateImpl, State}
+import org.apache.spark.streaming.util.{EmptyStateMap, StateMap}
+import org.apache.spark.util.Utils
+import org.apache.spark._
+
+/**
+ * Record storing the keyed state of a [[TrackStateRDD]] partition. Each record contains a
+ * [[StateMap]] and a sequence of records returned by the tracking function of `trackStateByKey`.
+ */
+private[streaming] case class TrackStateRDDRecord[K, S, T](
+ var stateMap: StateMap[K, S], var emittedRecords: Seq[T])
+
+/**
+ * Partition of the [[TrackStateRDD]], which depends on the corresponding partitions of the
+ * previous state RDD and of a partitioned keyed-data RDD
+ */
+private[streaming] class TrackStateRDDPartition(
+ idx: Int,
+ @transient private var prevStateRDD: RDD[_],
+ @transient private var partitionedDataRDD: RDD[_]) extends Partition {
+
+ private[rdd] var previousSessionRDDPartition: Partition = null
+ private[rdd] var partitionedDataRDDPartition: Partition = null
+
+ override def index: Int = idx
+ override def hashCode(): Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
+ // Update the reference to parent split at the time of task serialization
+ previousSessionRDDPartition = prevStateRDD.partitions(index)
+ partitionedDataRDDPartition = partitionedDataRDD.partitions(index)
+ oos.defaultWriteObject()
+ }
+}
+
+
+/**
+ * RDD storing the keyed-state of `trackStateByKey` and corresponding emitted records.
+ * Each partition of this RDD has a single record of type [[TrackStateRDDRecord]]. This contains a
+ * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the tracking
+ * function of `trackStateByKey`.
+ * @param prevStateRDD The previous TrackStateRDD on whose StateMap data `this` RDD will be created
+ * @param partitionedDataRDD The partitioned data RDD which is used to update the previous
+ *                           StateMaps in the `prevStateRDD` to create `this` RDD
+ * @param trackingFunction The function that will be used to update state and return new data
+ * @param batchTime The time of the batch to which this RDD belongs. Used as the update time of
+ *                  any state modified in this batch
+ * @param timeoutThresholdTime The update-time threshold (in milliseconds) before which idle
+ *                             states are considered timed out during a full scan
+ */
+private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+ private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
+ private var partitionedDataRDD: RDD[(K, V)],
+ trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
+ batchTime: Time, timeoutThresholdTime: Option[Long]
+ ) extends RDD[TrackStateRDDRecord[K, S, T]](
+ partitionedDataRDD.sparkContext,
+ List(
+ new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
+ new OneToOneDependency(partitionedDataRDD))
+ ) {
+
+ @volatile private var doFullScan = false
+
+ require(prevStateRDD.partitioner.nonEmpty)
+ require(partitionedDataRDD.partitioner == prevStateRDD.partitioner)
+
+ override val partitioner = prevStateRDD.partitioner
+
+ override def checkpoint(): Unit = {
+ super.checkpoint()
+ doFullScan = true
+ }
+
+ override def compute(
+ partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = {
+
+ val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
+ val prevStateRDDIterator = prevStateRDD.iterator(
+ stateRDDPartition.previousSessionRDDPartition, context)
+ val dataIterator = partitionedDataRDD.iterator(
+ stateRDDPartition.partitionedDataRDDPartition, context)
+
+ // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
+ val newStateMap = if (prevStateRDDIterator.hasNext) {
+ prevStateRDDIterator.next().stateMap.copy()
+ } else {
+ new EmptyStateMap[K, S]()
+ }
+
+ val emittedRecords = new ArrayBuffer[T]
+ val wrappedState = new StateImpl[S]()
+
+ // Call the tracking function on each record in the data RDD partition, and accordingly
+ // update the states touched, and the data returned by the tracking function.
+ dataIterator.foreach { case (key, value) =>
+ wrappedState.wrap(newStateMap.get(key))
+ val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState)
+ if (wrappedState.isRemoved) {
+ newStateMap.remove(key)
+ } else if (wrappedState.isUpdated) {
+ newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
+ }
+ emittedRecords ++= emittedRecord
+ }
+
+ // If the RDD is expected to be doing a full scan of all the data in the StateMap,
+ // then use this opportunity to filter out those keys that have timed out.
+ // For each of them call the tracking function.
+ if (doFullScan && timeoutThresholdTime.isDefined) {
+ newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
+ wrappedState.wrapTimingOutState(state)
+ val emittedRecord = trackingFunction(batchTime, key, None, wrappedState)
+ emittedRecords ++= emittedRecord
+ newStateMap.remove(key)
+ }
+ }
+
+ Iterator(TrackStateRDDRecord(newStateMap, emittedRecords))
+ }
+
+ override protected def getPartitions: Array[Partition] = {
+ Array.tabulate(prevStateRDD.partitions.length) { i =>
+ new TrackStateRDDPartition(i, prevStateRDD, partitionedDataRDD)}
+ }
+
+ override def clearDependencies(): Unit = {
+ super.clearDependencies()
+ prevStateRDD = null
+ partitionedDataRDD = null
+ }
+
+ def setFullScan(): Unit = {
+ doFullScan = true
+ }
+}
+
+private[streaming] object TrackStateRDD {
+
+ def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+ pairRDD: RDD[(K, S)],
+ partitioner: Partitioner,
+ updateTime: Time): TrackStateRDD[K, V, S, T] = {
+
+ val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
+ val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
+ iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
+ Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
+ }, preservesPartitioning = true)
+
+ val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
+
+ val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
+
+ new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+ }
+}
+
+private[streaming] class EmittedRecordsRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+ parent: TrackStateRDD[K, V, S, T]) extends RDD[T](parent) {
+ override protected def getPartitions: Array[Partition] = parent.partitions
+ override def compute(partition: Partition, context: TaskContext): Iterator[T] = {
+ parent.compute(partition, context).flatMap { _.emittedRecords }
+ }
+}
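A simplified, hypothetical restatement of the per-record loop in TrackStateRDD.compute() above, isolated from the RDD machinery. It assumes the code lives under the org.apache.spark.streaming package, since StateImpl and StateMap are private[streaming].
{{{
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.streaming.{State, StateImpl, Time}
import org.apache.spark.streaming.util.StateMap

def updatePartition[K, V, S, T](
    newStateMap: StateMap[K, S],
    dataIterator: Iterator[(K, V)],
    trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
    batchTime: Time): Seq[T] = {
  val emittedRecords = new ArrayBuffer[T]
  val wrappedState = new StateImpl[S]()                 // one mutable wrapper reused per record
  dataIterator.foreach { case (key, value) =>
    wrappedState.wrap(newStateMap.get(key))             // load current state (if any) for the key
    val emitted = trackingFunction(batchTime, key, Some(value), wrappedState)
    if (wrappedState.isRemoved) {
      newStateMap.remove(key)                           // tracking function called remove()
    } else if (wrappedState.isUpdated) {
      newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
    }
    emittedRecords ++= emitted
  }
  emittedRecords
}
}}}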
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
new file mode 100644
index 0000000000..ed622ef7bf
--- /dev/null
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.util
+
+import java.io.{ObjectInputStream, ObjectOutputStream}
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._
+import org.apache.spark.util.collection.OpenHashMap
+
+/** Internal interface for defining the map that keeps track of the keyed states. */
+private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable {
+
+ /** Get the state for a key if it exists */
+ def get(key: K): Option[S]
+
+ /** Get all the keys and states whose updated time is older than the given threshold time */
+ def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)]
+
+ /** Get all the keys and states in this map. */
+ def getAll(): Iterator[(K, S, Long)]
+
+ /** Add or update state */
+ def put(key: K, state: S, updatedTime: Long): Unit
+
+ /** Remove a key */
+ def remove(key: K): Unit
+
+ /**
+ * Shallow copy `this` map to create a new state map.
+ * Updates to the new map should not mutate `this` map.
+ */
+ def copy(): StateMap[K, S]
+
+ def toDebugString(): String = toString()
+}
+
+/** Companion object for [[StateMap]], with utility methods */
+private[streaming] object StateMap {
+ def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S]
+
+ def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = {
+ val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold",
+ DELTA_CHAIN_LENGTH_THRESHOLD)
+ new OpenHashMapBasedStateMap[K, S](64, deltaChainThreshold)
+ }
+}
+
+/** Implementation of StateMap interface representing an empty map */
+private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] {
+ override def put(key: K, session: S, updateTime: Long): Unit = {
+ throw new NotImplementedError("put() should not be called on an EmptyStateMap")
+ }
+ override def get(key: K): Option[S] = None
+ override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = Iterator.empty
+ override def getAll(): Iterator[(K, S, Long)] = Iterator.empty
+ override def copy(): StateMap[K, S] = this
+ override def remove(key: K): Unit = { }
+ override def toDebugString(): String = ""
+}
+
+/** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */
+private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag](
+ @transient @volatile var parentStateMap: StateMap[K, S],
+ initialCapacity: Int = 64,
+ deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD
+ ) extends StateMap[K, S] { self =>
+
+ def this(initialCapacity: Int, deltaChainThreshold: Int) = this(
+ new EmptyStateMap[K, S],
+ initialCapacity = initialCapacity,
+ deltaChainThreshold = deltaChainThreshold)
+
+ def this(deltaChainThreshold: Int) = this(
+ initialCapacity = 64, deltaChainThreshold = deltaChainThreshold)
+
+ def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD)
+
+ @transient @volatile private var deltaMap =
+ new OpenHashMap[K, StateInfo[S]](initialCapacity)
+
+ /** Get the state data if it exists */
+ override def get(key: K): Option[S] = {
+ val stateInfo = deltaMap(key)
+ if (stateInfo != null) {
+ if (!stateInfo.deleted) {
+ Some(stateInfo.data)
+ } else {
+ None
+ }
+ } else {
+ parentStateMap.get(key)
+ }
+ }
+
+ /** Get all the keys and states whose updated time is older than the given threshold time */
+ override def getByTime(threshUpdatedTime: Long): Iterator[(K, S, Long)] = {
+ val oldStates = parentStateMap.getByTime(threshUpdatedTime).filter { case (key, value, _) =>
+ !deltaMap.contains(key)
+ }
+
+ val updatedStates = deltaMap.iterator.filter { case (_, stateInfo) =>
+ !stateInfo.deleted && stateInfo.updateTime < threshUpdatedTime
+ }.map { case (key, stateInfo) =>
+ (key, stateInfo.data, stateInfo.updateTime)
+ }
+ oldStates ++ updatedStates
+ }
+
+ /** Get all the keys and states in this map. */
+ override def getAll(): Iterator[(K, S, Long)] = {
+
+ val oldStates = parentStateMap.getAll().filter { case (key, _, _) =>
+ !deltaMap.contains(key)
+ }
+
+ val updatedStates = deltaMap.iterator.filter { ! _._2.deleted }.map { case (key, stateInfo) =>
+ (key, stateInfo.data, stateInfo.updateTime)
+ }
+ oldStates ++ updatedStates
+ }
+
+ /** Add or update state */
+ override def put(key: K, state: S, updateTime: Long): Unit = {
+ val stateInfo = deltaMap(key)
+ if (stateInfo != null) {
+ stateInfo.update(state, updateTime)
+ } else {
+ deltaMap.update(key, new StateInfo(state, updateTime))
+ }
+ }
+
+ /** Remove a state */
+ override def remove(key: K): Unit = {
+ val stateInfo = deltaMap(key)
+ if (stateInfo != null) {
+ stateInfo.markDeleted()
+ } else {
+ val newInfo = new StateInfo[S](deleted = true)
+ deltaMap.update(key, newInfo)
+ }
+ }
+
+ /**
+ * Shallow copy the map to create a new session store. Updates to the new map
+ * should not mutate `this` map.
+ */
+ override def copy(): StateMap[K, S] = {
+ new OpenHashMapBasedStateMap[K, S](this, deltaChainThreshold = deltaChainThreshold)
+ }
+
+ /** Whether the delta chain length is long enough that it should be compacted */
+ def shouldCompact: Boolean = {
+ deltaChainLength >= deltaChainThreshold
+ }
+
+ /** Length of the delta chains of this map */
+ def deltaChainLength: Int = parentStateMap match {
+ case map: OpenHashMapBasedStateMap[_, _] => map.deltaChainLength + 1
+ case _ => 0
+ }
+
+ /**
+ * Approximate number of keys in the map. This is an overestimation that is mainly used to
+ * reserve capacity in a new map at delta compaction time.
+ */
+ def approxSize: Int = deltaMap.size + {
+ parentStateMap match {
+ case s: OpenHashMapBasedStateMap[_, _] => s.approxSize
+ case _ => 0
+ }
+ }
+
+ /** Get all the data of this map as string formatted as a tree based on the delta depth */
+ override def toDebugString(): String = {
+ val tabs = if (deltaChainLength > 0) {
+ (" " * (deltaChainLength - 1)) + "+--- "
+ } else ""
+ parentStateMap.toDebugString() + "\n" + deltaMap.iterator.mkString(tabs, "\n" + tabs, "")
+ }
+
+ override def toString(): String = {
+ s"[${System.identityHashCode(this)}, ${System.identityHashCode(parentStateMap)}]"
+ }
+
+ /**
+ * Serialize the map data. Besides serialization, this method actually compacts the deltas
+ * (if needed) in a single pass over all the data in the map.
+ */
+ private def writeObject(outputStream: ObjectOutputStream): Unit = {
+ // Write all the non-transient fields, especially class tags, etc.
+ outputStream.defaultWriteObject()
+
+ // Write the data in the delta of this state map
+ outputStream.writeInt(deltaMap.size)
+ val deltaMapIterator = deltaMap.iterator
+ var deltaMapCount = 0
+ while (deltaMapIterator.hasNext) {
+ deltaMapCount += 1
+ val (key, stateInfo) = deltaMapIterator.next()
+ outputStream.writeObject(key)
+ outputStream.writeObject(stateInfo)
+ }
+ assert(deltaMapCount == deltaMap.size)
+
+ // Write the data in the parent state map while copying the data into a new parent map for
+ // compaction (if needed)
+ val doCompaction = shouldCompact
+ val newParentSessionStore = if (doCompaction) {
+ val initCapacity = if (approxSize > 0) approxSize else 64
+ new OpenHashMapBasedStateMap[K, S](initialCapacity = initCapacity, deltaChainThreshold)
+ } else { null }
+
+ val iterOfActiveSessions = parentStateMap.getAll()
+
+ var parentSessionCount = 0
+
+ // First write the approximate size of the data to be written, so that readObject can
+ // allocate an appropriately sized OpenHashMap.
+ outputStream.writeInt(approxSize)
+
+ while(iterOfActiveSessions.hasNext) {
+ parentSessionCount += 1
+
+ val (key, state, updateTime) = iterOfActiveSessions.next()
+ outputStream.writeObject(key)
+ outputStream.writeObject(state)
+ outputStream.writeLong(updateTime)
+
+ if (doCompaction) {
+ newParentSessionStore.deltaMap.update(
+ key, StateInfo(state, updateTime, deleted = false))
+ }
+ }
+
+ // Write the final limit marking object with the correct count of records written.
+ val limiterObj = new LimitMarker(parentSessionCount)
+ outputStream.writeObject(limiterObj)
+ if (doCompaction) {
+ parentStateMap = newParentSessionStore
+ }
+ }
+
+ /** Deserialize the map data. */
+ private def readObject(inputStream: ObjectInputStream): Unit = {
+
+ // Read the non-transient fields, especially class tags, etc.
+ inputStream.defaultReadObject()
+
+ // Read the data of the delta
+ val deltaMapSize = inputStream.readInt()
+ deltaMap = new OpenHashMap[K, StateInfo[S]]()
+ var deltaMapCount = 0
+ while (deltaMapCount < deltaMapSize) {
+ val key = inputStream.readObject().asInstanceOf[K]
+ val sessionInfo = inputStream.readObject().asInstanceOf[StateInfo[S]]
+ deltaMap.update(key, sessionInfo)
+ deltaMapCount += 1
+ }
+
+
+ // Read the data of the parent map. Keep reading records until the limit marker is reached.
+ // First read the approximate number of records to expect and allocate a properly sized
+ // OpenHashMap.
+ val parentSessionStoreSizeHint = inputStream.readInt()
+ val newParentSessionStore = new OpenHashMapBasedStateMap[K, S](
+ initialCapacity = parentSessionStoreSizeHint, deltaChainThreshold)
+
+ // Read the records until the limit marking object has been reached
+ var parentSessionLoopDone = false
+ while(!parentSessionLoopDone) {
+ val obj = inputStream.readObject()
+ if (obj.isInstanceOf[LimitMarker]) {
+ parentSessionLoopDone = true
+ val expectedCount = obj.asInstanceOf[LimitMarker].num
+ assert(expectedCount == newParentSessionStore.deltaMap.size)
+ } else {
+ val key = obj.asInstanceOf[K]
+ val state = inputStream.readObject().asInstanceOf[S]
+ val updateTime = inputStream.readLong()
+ newParentSessionStore.deltaMap.update(
+ key, StateInfo(state, updateTime, deleted = false))
+ }
+ }
+ parentStateMap = newParentSessionStore
+ }
+}
+
+/**
+ * Companion object of [[OpenHashMapBasedStateMap]] having associated helper
+ * classes and methods
+ */
+private[streaming] object OpenHashMapBasedStateMap {
+
+ /** Internal class to represent the state information */
+ case class StateInfo[S](
+ var data: S = null.asInstanceOf[S],
+ var updateTime: Long = -1,
+ var deleted: Boolean = false) {
+
+ def markDeleted(): Unit = {
+ deleted = true
+ }
+
+ def update(newData: S, newUpdateTime: Long): Unit = {
+ data = newData
+ updateTime = newUpdateTime
+ deleted = false
+ }
+ }
+
+ /**
+ * Internal class to represent a marker that demarcates the end of all state data in the
+ * serialized bytes.
+ */
+ class LimitMarker(val num: Int) extends Serializable
+
+ val DELTA_CHAIN_LENGTH_THRESHOLD = 20
+}
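A hedged sketch of the copy-on-write delta-chain behaviour implemented above: updates to a copy never mutate the parent map, and each copy() adds one delta layer until serialization compacts the chain. It assumes the code sits under the org.apache.spark.streaming package, since OpenHashMapBasedStateMap is private[streaming]; keys and times are made up.
{{{
import org.apache.spark.streaming.util.OpenHashMapBasedStateMap

val parent = new OpenHashMapBasedStateMap[String, Int]()
parent.put("a", 1, 10L)

val child = parent.copy().asInstanceOf[OpenHashMapBasedStateMap[String, Int]]
child.put("b", 2, 20L)
child.remove("a")                    // recorded as a deletion in the child's delta only

assert(parent.get("a") == Some(1))   // parent map is untouched
assert(child.get("a") == None)
assert(child.getAll().toSet == Set(("b", 2, 20L)))
assert(child.deltaChainLength == 1)  // one delta layer on top of the parent map
}}}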
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
new file mode 100644
index 0000000000..48d3b41b66
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import scala.collection.{immutable, mutable, Map}
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap}
+import org.apache.spark.util.Utils
+
+class StateMapSuite extends SparkFunSuite {
+
+ test("EmptyStateMap") {
+ val map = new EmptyStateMap[Int, Int]
+ intercept[scala.NotImplementedError] {
+ map.put(1, 1, 1)
+ }
+ assert(map.get(1) === None)
+ assert(map.getByTime(10000).isEmpty)
+ assert(map.getAll().isEmpty)
+ map.remove(1) // no exception
+ assert(map.copy().eq(map))
+ }
+
+ test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove") {
+ val map = new OpenHashMapBasedStateMap[Int, Int]()
+
+ map.put(1, 100, 10)
+ assert(map.get(1) === Some(100))
+ assert(map.get(2) === None)
+ assert(map.getByTime(11).toSet === Set((1, 100, 10)))
+ assert(map.getByTime(10).toSet === Set.empty)
+ assert(map.getByTime(9).toSet === Set.empty)
+ assert(map.getAll().toSet === Set((1, 100, 10)))
+
+ map.put(2, 200, 20)
+ assert(map.getByTime(21).toSet === Set((1, 100, 10), (2, 200, 20)))
+ assert(map.getByTime(11).toSet === Set((1, 100, 10)))
+ assert(map.getByTime(10).toSet === Set.empty)
+ assert(map.getByTime(9).toSet === Set.empty)
+ assert(map.getAll().toSet === Set((1, 100, 10), (2, 200, 20)))
+
+ map.remove(1)
+ assert(map.get(1) === None)
+ assert(map.getAll().toSet === Set((2, 200, 20)))
+ }
+
+ test("OpenHashMapBasedStateMap - put, get, getByTime, getAll, remove with copy") {
+ val parentMap = new OpenHashMapBasedStateMap[Int, Int]()
+ parentMap.put(1, 100, 1)
+ parentMap.put(2, 200, 2)
+ parentMap.remove(1)
+
+ // Create child map and make changes
+ val map = parentMap.copy()
+ assert(map.get(1) === None)
+ assert(map.get(2) === Some(200))
+ assert(map.getByTime(10).toSet === Set((2, 200, 2)))
+ assert(map.getByTime(2).toSet === Set.empty)
+ assert(map.getAll().toSet === Set((2, 200, 2)))
+
+ // Add new items
+ map.put(3, 300, 3)
+ assert(map.get(3) === Some(300))
+ map.put(4, 400, 4)
+ assert(map.get(4) === Some(400))
+ assert(map.getByTime(10).toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4)))
+ assert(map.getByTime(4).toSet === Set((2, 200, 2), (3, 300, 3)))
+ assert(map.getAll().toSet === Set((2, 200, 2), (3, 300, 3), (4, 400, 4)))
+ assert(parentMap.getAll().toSet === Set((2, 200, 2)))
+
+ // Remove items
+ map.remove(4)
+ assert(map.get(4) === None) // item added in this map, then removed in this map
+ map.remove(2)
+ assert(map.get(2) === None) // item removed in parent map, then added in this map
+ assert(map.getAll().toSet === Set((3, 300, 3)))
+ assert(parentMap.getAll().toSet === Set((2, 200, 2)))
+
+ // Update items
+ map.put(1, 1000, 100)
+ assert(map.get(1) === Some(1000)) // item removed in parent map, then added in this map
+ map.put(2, 2000, 200)
+ assert(map.get(2) === Some(2000)) // item added in parent map, then removed + added in this map
+ map.put(3, 3000, 300)
+ assert(map.get(3) === Some(3000)) // item added + updated in this map
+ map.put(4, 4000, 400)
+ assert(map.get(4) === Some(4000)) // item removed + updated in this map
+
+ assert(map.getAll().toSet ===
+ Set((1, 1000, 100), (2, 2000, 200), (3, 3000, 300), (4, 4000, 400)))
+ assert(parentMap.getAll().toSet === Set((2, 200, 2)))
+
+ map.remove(2) // remove item present in parent map, so that it's not visible in child map
+
+ // Create child map and see availability of items
+ val childMap = map.copy()
+ assert(childMap.getAll().toSet === map.getAll().toSet)
+ assert(childMap.get(1) === Some(1000)) // item removed in grandparent, but added in parent map
+ assert(childMap.get(2) === None) // item added in grandparent, but removed in parent map
+ assert(childMap.get(3) === Some(3000)) // item added and updated in parent map
+
+ childMap.put(2, 20000, 200)
+ assert(childMap.get(2) === Some(20000)) // item removed in parent map, then added in this map
+ }
+
+ test("OpenHashMapBasedStateMap - serializing and deserializing") {
+ val map1 = new OpenHashMapBasedStateMap[Int, Int]()
+ map1.put(1, 100, 1)
+ map1.put(2, 200, 2)
+
+ val map2 = map1.copy()
+ map2.put(3, 300, 3)
+ map2.put(4, 400, 4)
+
+ val map3 = map2.copy()
+ map3.put(3, 600, 3)
+ map3.remove(2)
+
+ // Do not test compaction
+ assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false)
+
+ val deser_map3 = Utils.deserialize[StateMap[Int, Int]](
+ Utils.serialize(map3), Thread.currentThread().getContextClassLoader)
+ assertMap(deser_map3, map3, 1, "Deserialized map not same as original map")
+ }
+
+ test("OpenHashMapBasedStateMap - serializing and deserializing with compaction") {
+ val targetDeltaLength = 10
+ val deltaChainThreshold = 5
+
+ var map = new OpenHashMapBasedStateMap[Int, Int](
+ deltaChainThreshold = deltaChainThreshold)
+
+ // Make large delta chain with length more than deltaChainThreshold
+ for(i <- 1 to targetDeltaLength) {
+ map.put(Random.nextInt(), Random.nextInt(), 1)
+ map = map.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]]
+ }
+ assert(map.deltaChainLength > deltaChainThreshold)
+ assert(map.shouldCompact === true)
+
+ val deser_map = Utils.deserialize[OpenHashMapBasedStateMap[Int, Int]](
+ Utils.serialize(map), Thread.currentThread().getContextClassLoader)
+ assert(deser_map.deltaChainLength < deltaChainThreshold)
+ assert(deser_map.shouldCompact === false)
+ assertMap(deser_map, map, 1, "Deserialized + compacted map not same as original map")
+ }
+
+ test("OpenHashMapBasedStateMap - all possible sequences of operations with copies ") {
+ /*
+ * This tests the map using all permutations of sequences operations, across multiple map
+ * copies as well as between copies. It is to ensure complete coverage, though it is
+ * kind of hard to debug this. It is set up as follows.
+ *
+ * - For any key, there can be 2 types of update ops on a state map - put or remove
+ *
+ * - These operations are done on a test map in "sets". After each set, the map is "copied"
+ * to create a new map, and the next set of operations are done on the new one. This tests
+ * whether the map data persists correctly across copies.
+ *
+ * - Within each set, there are a number of operations to test whether the map correctly
+ * updates and removes data without affecting the parent state map.
+ *
+ * - Overall this creates (numSets * numOpsPerSet) operations, each of which can be one of 2
+ * types of operations. This leads to a total of [2 ^ (numSets * numOpsPerSet)] different
+ * sequences of operations, which we will test with different keys.
+ *
+ * Example: With numSets = 2 and numOpsPerSet = 2, numTotalOps = 4. This means that
+ * 2 ^ 4 = 16 possible permutations need to be tested using 16 keys.
+ * _______________________________________________
+ * | | Set1 | Set2 |
+ * | |-----------------|-----------------|
+ * | | Op1 Op2 |c| Op3 Op4 |
+ * |---------|----------------|o|----------------|
+ * | key 0 | put put |p| put put |
+ * | key 1 | put put |y| put rem |
+ * | key 2 | put put | | rem put |
+ * | key 3 | put put |t| rem rem |
+ * | key 4 | put rem |h| put put |
+ * | key 5 | put rem |e| put rem |
+ * | key 6 | put rem | | rem put |
+ * | key 7 | put rem |s| rem rem |
+ * | key 8 | rem put |t| put put |
+ * | key 9 | rem put |a| put rem |
+ * | key 10 | rem put |t| rem put |
+ * | key 11 | rem put |e| rem rem |
+ * | key 12 | rem rem | | put put |
+ * | key 13 | rem rem |m| put rem |
+ * | key 14 | rem rem |a| rem put |
+ * | key 15 | rem rem |p| rem rem |
+ * |_________|________________|_|________________|
+ */
+
+ val numTypeMapOps = 2 // 0 = put a new value, 1 = remove value
+ val numSets = 3
+ val numOpsPerSet = 3 // to test seq of ops like update -> remove -> update in same set
+ val numTotalOps = numOpsPerSet * numSets
+ val numKeys = math.pow(numTypeMapOps, numTotalOps).toInt // to get all combinations of ops
+
+ val refMap = new mutable.HashMap[Int, (Int, Long)]()
+ var prevSetRefMap: immutable.Map[Int, (Int, Long)] = null
+
+ var stateMap: StateMap[Int, Int] = new OpenHashMapBasedStateMap[Int, Int]()
+ var prevSetStateMap: StateMap[Int, Int] = null
+
+ var time = 1L
+
+ for (setId <- 0 until numSets) {
+ for (opInSetId <- 0 until numOpsPerSet) {
+ val opId = setId * numOpsPerSet + opInSetId
+ for (keyId <- 0 until numKeys) {
+ time += 1
+ // Find the operation type that needs to be done
+ // This is similar to finding the nth bit value of a binary number
+ // E.g. nth bit from the right of any binary number B is [ B / (2 ^ (n - 1)) ] % 2
+ val opCode =
+ (keyId / math.pow(numTypeMapOps, numTotalOps - opId - 1).toInt) % numTypeMapOps
+ opCode match {
+ case 0 =>
+ val value = Random.nextInt()
+ stateMap.put(keyId, value, time)
+ refMap.put(keyId, (value, time))
+ case 1 =>
+ stateMap.remove(keyId)
+ refMap.remove(keyId)
+ }
+ }
+
+ // Test whether the current state map after all key updates is correct
+ assertMap(stateMap, refMap, time, "State map does not match reference map")
+
+ // Test whether the previous map before copy has not changed
+ if (prevSetStateMap != null && prevSetRefMap != null) {
+ assertMap(prevSetStateMap, prevSetRefMap, time,
+ "Parent state map somehow got modified, does not match corresponding reference map")
+ }
+ }
+
+ // Copy the map and remember the previous maps for future tests
+ prevSetStateMap = stateMap
+ prevSetRefMap = refMap.toMap
+ stateMap = stateMap.copy()
+
+ // Assert that the copied map has the same data
+ assertMap(stateMap, prevSetRefMap, time,
+ "State map does not match reference map after copying")
+ }
+ assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map")
+ }
+
+ // Assert whether all the data and operations on a state map matches that of a reference state map
+ private def assertMap(
+ mapToTest: StateMap[Int, Int],
+ refMapToTestWith: StateMap[Int, Int],
+ time: Long,
+ msg: String): Unit = {
+ withClue(msg) {
+ // Assert all the data is same as the reference map
+ assert(mapToTest.getAll().toSet === refMapToTestWith.getAll().toSet)
+
+ // Assert that get on every key returns the right value
+ for (keyId <- refMapToTestWith.getAll().map { _._1 }) {
+ assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId))
+ }
+
+ // Assert that every time threshold returns the correct data
+ for (t <- 0L to (time + 1)) {
+ assert(mapToTest.getByTime(t).toSet === refMapToTestWith.getByTime(t).toSet)
+ }
+ }
+ }
+
+ // Assert whether all the data and operations on a state map matches that of a reference map
+ private def assertMap(
+ mapToTest: StateMap[Int, Int],
+ refMapToTestWith: Map[Int, (Int, Long)],
+ time: Long,
+ msg: String): Unit = {
+ withClue(msg) {
+ // Assert all the data is same as the reference map
+ assert(mapToTest.getAll().toSet ===
+ refMapToTestWith.iterator.map { x => (x._1, x._2._1, x._2._2) }.toSet)
+
+ // Assert that get on every key returns the right value
+ for (keyId <- refMapToTestWith.keys) {
+ assert(mapToTest.get(keyId) === refMapToTestWith.get(keyId).map { _._1 })
+ }
+
+ // Assert that every time threshold returns the correct data
+ for (t <- 0L to (time + 1)) {
+ val expectedRecords =
+ refMapToTestWith.iterator.filter { _._2._2 < t }.map { x => (x._1, x._2._1, x._2._2) }
+ assert(mapToTest.getByTime(t).toSet === expectedRecords.toSet)
+ }
+ }
+ }
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
new file mode 100644
index 0000000000..e3072b4442
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
@@ -0,0 +1,494 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+import java.io.File
+
+import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.reflect.ClassTag
+
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+
+import org.apache.spark.streaming.dstream.{TrackStateDStream, TrackStateDStreamImpl}
+import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+
+class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
+
+ private var sc: SparkContext = null
+ private var ssc: StreamingContext = null
+ private var checkpointDir: File = null
+ private val batchDuration = Seconds(1)
+
+ before {
+ StreamingContext.getActive().foreach {
+ _.stop(stopSparkContext = false)
+ }
+ checkpointDir = Utils.createTempDir("checkpoint")
+
+ ssc = new StreamingContext(sc, batchDuration)
+ ssc.checkpoint(checkpointDir.toString)
+ }
+
+ after {
+ StreamingContext.getActive().foreach {
+ _.stop(stopSparkContext = false)
+ }
+ }
+
+ override def beforeAll(): Unit = {
+ val conf = new SparkConf().setMaster("local").setAppName("TrackStateByKeySuite")
+ conf.set("spark.streaming.clock", classOf[ManualClock].getName())
+ sc = new SparkContext(conf)
+ }
+
+ test("state - get, exists, update, remove, ") {
+ var state: StateImpl[Int] = null
+
+ def testState(
+ expectedData: Option[Int],
+ shouldBeUpdated: Boolean = false,
+ shouldBeRemoved: Boolean = false,
+ shouldBeTimingOut: Boolean = false
+ ): Unit = {
+ if (expectedData.isDefined) {
+ assert(state.exists)
+ assert(state.get() === expectedData.get)
+ assert(state.getOption() === expectedData)
+ assert(state.getOption.getOrElse(-1) === expectedData.get)
+ } else {
+ assert(!state.exists)
+ intercept[NoSuchElementException] {
+ state.get()
+ }
+ assert(state.getOption() === None)
+ assert(state.getOption.getOrElse(-1) === -1)
+ }
+
+ assert(state.isTimingOut() === shouldBeTimingOut)
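+ // A state that is timing out cannot be modified: both update() and remove() must fail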
+ if (shouldBeTimingOut) {
+ intercept[IllegalArgumentException] {
+ state.remove()
+ }
+ intercept[IllegalArgumentException] {
+ state.update(-1)
+ }
+ }
+
+ assert(state.isUpdated() === shouldBeUpdated)
+
+ assert(state.isRemoved() === shouldBeRemoved)
+ if (shouldBeRemoved) {
+ intercept[IllegalArgumentException] {
+ state.remove()
+ }
+ intercept[IllegalArgumentException] {
+ state.update(-1)
+ }
+ }
+ }
+
+ state = new StateImpl[Int]()
+ testState(None)
+
+ state.wrap(None)
+ testState(None)
+
+ state.wrap(Some(1))
+ testState(Some(1))
+
+ state.update(2)
+ testState(Some(2), shouldBeUpdated = true)
+
+ state = new StateImpl[Int]()
+ state.update(2)
+ testState(Some(2), shouldBeUpdated = true)
+
+ state.remove()
+ testState(None, shouldBeRemoved = true)
+
+ state.wrapTiminoutState(3)
+ testState(Some(3), shouldBeTimingOut = true)
+ }
+
+ test("trackStateByKey - basic operations with simple API") {
+ val inputData =
+ Seq(
+ Seq(),
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val outputData =
+ Seq(
+ Seq(),
+ Seq(1),
+ Seq(2, 1),
+ Seq(3, 2, 1),
+ Seq(4, 3),
+ Seq(5),
+ Seq()
+ )
+
+ val stateData =
+ Seq(
+ Seq(),
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 3), ("b", 2), ("c", 1)),
+ Seq(("a", 4), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1))
+ )
+
+ // state maintains running count, and updated count is returned
+ val trackStateFunc = (value: Option[Int], state: State[Int]) => {
+ val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+ state.update(sum)
+ sum
+ }
+
+ testOperation[String, Int, Int](
+ inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+ }
+
+ test("trackStateByKey - basic operations with advanced API") {
+ val inputData =
+ Seq(
+ Seq(),
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val outputData =
+ Seq(
+ Seq(),
+ Seq("aa"),
+ Seq("aa", "bb"),
+ Seq("aa", "bb", "cc"),
+ Seq("aa", "bb"),
+ Seq("aa"),
+ Seq()
+ )
+
+ val stateData =
+ Seq(
+ Seq(),
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 3), ("b", 2), ("c", 1)),
+ Seq(("a", 4), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1))
+ )
+
+ // state maintains running count, and the doubled key string is returned
+ val trackStateFunc = (batchTime: Time, key: String, value: Option[Int], state: State[Int]) => {
+ val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+ state.update(sum)
+ Some(key * 2)
+ }
+
+ testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+ }
+
+ test("trackStateByKey - type inferencing and class tags") {
+
+ // Simple track state function with value as Int, state as Double and emitted type as Long
+ val simpleFunc = (value: Option[Int], state: State[Double]) => {
+ 0L
+ }
+
+ // Advanced track state function with key as String, value as Int, state as Double and
+ // emitted type as Long
+ val advancedFunc = (time: Time, key: String, value: Option[Int], state: State[Double]) => {
+ Some(0L)
+ }
+
+ def testTypes(dstream: TrackStateDStream[_, _, _, _]): Unit = {
+ val dstreamImpl = dstream.asInstanceOf[TrackStateDStreamImpl[_, _, _, _]]
+ assert(dstreamImpl.keyClass === classOf[String])
+ assert(dstreamImpl.valueClass === classOf[Int])
+ assert(dstreamImpl.stateClass === classOf[Double])
+ assert(dstreamImpl.emittedClass === classOf[Long])
+ }
+
+ val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
+
+ // Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
+ val simpleFunctionStateStream1 = inputStream.trackStateByKey(
+ StateSpec.function(simpleFunc).numPartitions(1))
+ testTypes(simpleFunctionStateStream1)
+
+ // Separately defining StateSpec with simple function requires explicitly specifying types
+ val simpleFuncSpec = StateSpec.function[String, Int, Double, Long](simpleFunc)
+ val simpleFunctionStateStream2 = inputStream.trackStateByKey(simpleFuncSpec)
+ testTypes(simpleFunctionStateStream2)
+
+ // Separately defining StateSpec with advanced function implicitly gets the types
+ val advFuncSpec1 = StateSpec.function(advancedFunc)
+ val advFunctionStateStream1 = inputStream.trackStateByKey(advFuncSpec1)
+ testTypes(advFunctionStateStream1)
+
+ // Defining StateSpec inline with trackStateByKey and advanced func implicitly gets the types
+ val advFunctionStateStream2 = inputStream.trackStateByKey(
+ StateSpec.function(advancedFunc).numPartitions(1))
+ testTypes(advFunctionStateStream2)
+
+ // Separately defining StateSpec with advanced function and explicitly specified types
+ val advFuncSpec2 = StateSpec.function[String, Int, Double, Long](advancedFunc)
+ val advFunctionStateStream3 = inputStream.trackStateByKey[Double, Long](advFuncSpec2)
+ testTypes(advFunctionStateStream3)
+ }
+
+ test("trackStateByKey - states as emitted records") {
+ val inputData =
+ Seq(
+ Seq(),
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val outputData =
+ Seq(
+ Seq(),
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 3), ("b", 2), ("c", 1)),
+ Seq(("a", 4), ("b", 3)),
+ Seq(("a", 5)),
+ Seq()
+ )
+
+ val stateData =
+ Seq(
+ Seq(),
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 3), ("b", 2), ("c", 1)),
+ Seq(("a", 4), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1))
+ )
+
+ val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+ val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+ val output = (key, sum)
+ state.update(sum)
+ Some(output)
+ }
+
+ testOperation(inputData, StateSpec.function(trackStateFunc), outputData, stateData)
+ }
+
+ test("trackStateByKey - initial states, with nothing emitted") {
+
+ val initialState = Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0))
+
+ val inputData =
+ Seq(
+ Seq(),
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val outputData = Seq.fill(inputData.size)(Seq.empty[Int])
+
+ val stateData =
+ Seq(
+ Seq(("a", 5), ("b", 10), ("c", -20), ("d", 0)),
+ Seq(("a", 6), ("b", 10), ("c", -20), ("d", 0)),
+ Seq(("a", 7), ("b", 11), ("c", -20), ("d", 0)),
+ Seq(("a", 8), ("b", 12), ("c", -19), ("d", 0)),
+ Seq(("a", 9), ("b", 13), ("c", -19), ("d", 0)),
+ Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0)),
+ Seq(("a", 10), ("b", 13), ("c", -19), ("d", 0))
+ )
+
+ val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+ val sum = value.getOrElse(0) + state.getOption.getOrElse(0)
+ val output = (key, sum)
+ state.update(sum)
+ None.asInstanceOf[Option[Int]]
+ }
+
+ val trackStateSpec = StateSpec.function(trackStateFunc).initialState(sc.makeRDD(initialState))
+ testOperation(inputData, trackStateSpec, outputData, stateData)
+ }
+
+ test("trackStateByKey - state removing") {
+ val inputData =
+ Seq(
+ Seq(),
+ Seq("a"),
+ Seq("a", "b"), // a will be removed
+ Seq("a", "b", "c"), // b will be removed
+ Seq("a", "b", "c"), // a and c will be removed
+ Seq("a", "b"), // b will be removed
+ Seq("a"), // a will be removed
+ Seq()
+ )
+
+ // States that were removed
+ val outputData =
+ Seq(
+ Seq(),
+ Seq(),
+ Seq("a"),
+ Seq("b"),
+ Seq("a", "c"),
+ Seq("b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val stateData =
+ Seq(
+ Seq(),
+ Seq(("a", 1)),
+ Seq(("b", 1)),
+ Seq(("a", 1), ("c", 1)),
+ Seq(("b", 1)),
+ Seq(("a", 1)),
+ Seq(),
+ Seq()
+ )
+
+ val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+ if (state.exists) {
+ state.remove()
+ Some(key)
+ } else {
+ state.update(value.get)
+ None
+ }
+ }
+
+ testOperation(
+ inputData, StateSpec.function(trackStateFunc).numPartitions(1), outputData, stateData)
+ }
+
+ test("trackStateByKey - state timing out") {
+ val inputData =
+ Seq(
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq(), // c will time out
+ Seq(), // b will time out
+ Seq("a") // a will not time out
+ ) ++ Seq.fill(20)(Seq("a")) // a will continue to stay active
+
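+ // With a 3-second timeout and 1-second batches, a key that receives no data within the
+ // timeout is passed to the tracking function one last time with isTimingOut() = true and
+ // value None, after which its state is dropped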
+ val trackStateFunc = (time: Time, key: String, value: Option[Int], state: State[Int]) => {
+ if (value.isDefined) {
+ state.update(1)
+ }
+ if (state.isTimingOut) {
+ Some(key)
+ } else {
+ None
+ }
+ }
+
+ val (collectedOutputs, collectedStateSnapshots) = getOperationOutput(
+ inputData, StateSpec.function(trackStateFunc).timeout(Seconds(3)), 20)
+
+ // b and c should each be emitted exactly once, when they are marked as timing out
+ assert(collectedOutputs.flatten.sorted === Seq("b", "c"))
+
+ // States for a, b and c should all be defined at some point in time
+ assert(collectedStateSnapshots.exists {
+ _.toSet == Set(("a", 1), ("b", 1), ("c", 1))
+ })
+
+ // Finally state should be defined only for a
+ assert(collectedStateSnapshots.last.toSet === Set(("a", 1)))
+ }
+
+
+ private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
+ input: Seq[Seq[K]],
+ trackStateSpec: StateSpec[K, Int, S, T],
+ expectedOutputs: Seq[Seq[T]],
+ expectedStateSnapshots: Seq[Seq[(K, S)]]
+ ): Unit = {
+ require(expectedOutputs.size == expectedStateSnapshots.size)
+
+ val (collectedOutputs, collectedStateSnapshots) =
+ getOperationOutput(input, trackStateSpec, expectedOutputs.size)
+ assert(expectedOutputs, collectedOutputs, "outputs")
+ assert(expectedStateSnapshots, collectedStateSnapshots, "state snapshots")
+ }
+
+ private def getOperationOutput[K: ClassTag, S: ClassTag, T: ClassTag](
+ input: Seq[Seq[K]],
+ trackStateSpec: StateSpec[K, Int, S, T],
+ numBatches: Int
+ ): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
+
+ // Setup the stream computation
+ val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
+ val trackStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec)
+ val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
+ val outputStream = new TestOutputStream(trackStateStream, collectedOutputs)
+ val collectedStateSnapshots = new ArrayBuffer[Seq[(K, S)]] with SynchronizedBuffer[Seq[(K, S)]]
+ val stateSnapshotStream = new TestOutputStream(
+ trackStateStream.stateSnapshots(), collectedStateSnapshots)
+ outputStream.register()
+ stateSnapshotStream.register()
+
+ val batchCounter = new BatchCounter(ssc)
+ ssc.start()
+
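+ // Advance the manual clock by all the batch intervals at once so that numBatches batches
+ // are scheduled, then wait for them to complete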
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ clock.advance(batchDuration.milliseconds * numBatches)
+
+ batchCounter.waitUntilBatchesCompleted(numBatches, 10000)
+ (collectedOutputs, collectedStateSnapshots)
+ }
+
+ private def assert[U](expected: Seq[Seq[U]], collected: Seq[Seq[U]], typ: String) {
+ val debugString = "\nExpected:\n" + expected.mkString("\n") +
+ "\nCollected:\n" + collected.mkString("\n")
+ assert(expected.size === collected.size,
+ s"number of collected $typ (${collected.size}) different from expected (${expected.size})" +
+ debugString)
+ expected.zip(collected).foreach { case (e, c) =>
+ assert(e.toSet === c.toSet,
+ s"collected $typ is different from expected $debugString"
+ )
+ }
+ }
+}
+
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
new file mode 100644
index 0000000000..fc5f26607e
--- /dev/null
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.rdd
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.{Time, State}
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}
+
+class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
+
+ private var sc = new SparkContext(
+ new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite"))
+
+ override def afterAll(): Unit = {
+ sc.stop()
+ }
+
+ test("creation from pair RDD") {
+ val data = Seq((1, "1"), (2, "2"), (3, "3"))
+ val partitioner = new HashPartitioner(10)
+ val rdd = TrackStateRDD.createFromPairRDD[Int, Int, String, Int](
+ sc.parallelize(data), partitioner, Time(123))
+ assertRDD[Int, Int, String, Int](rdd, data.map { x => (x._1, x._2, 123) }.toSet, Set.empty)
+ assert(rdd.partitions.size === partitioner.numPartitions)
+
+ assert(rdd.partitioner === Some(partitioner))
+ }
+
+ test("states generated by TrackStateRDD") {
+ val initStates = Seq(("k1", 0), ("k2", 0))
+ val initTime = 123
+ val initStateWithTime = initStates.map { x => (x._1, x._2, initTime) }.toSet
+ val partitioner = new HashPartitioner(2)
+ val initStateRDD = TrackStateRDD.createFromPairRDD[String, Int, Int, Int](
+ sc.parallelize(initStates), partitioner, Time(initTime)).persist()
+ assertRDD(initStateRDD, initStateWithTime, Set.empty)
+
+ val updateTime = 345
+
+ /**
+ * Test that the test state RDD, when operated with new data,
+ * creates a new state RDD with expected states
+ */
+ def testStateUpdates(
+ testStateRDD: TrackStateRDD[String, Int, Int, Int],
+ testData: Seq[(String, Int)],
+ expectedStates: Set[(String, Int, Int)]): TrackStateRDD[String, Int, Int, Int] = {
+
+ // Persist the test TrackStateRDD so that it's not recomputed while doing the next operation.
+ // This makes sure that we only track which state keys are touched in the next operation.
+ testStateRDD.persist().count()
+
+ // To track which keys are being touched
+ TrackStateRDDSuite.touchedStateKeys.clear()
+
+ val trackingFunc = (time: Time, key: String, data: Option[Int], state: State[Int]) => {
+
+ // Track the key that has been touched
+ TrackStateRDDSuite.touchedStateKeys += key
+
+ // If the data is 0, do not do anything with the state
+ // else if the data is 1, increment the state if it exists, or set new state to 0
+ // else if the data is 2, remove the state if it exists
+ data match {
+ case Some(1) =>
+ if (state.exists()) { state.update(state.get + 1) }
+ else state.update(0)
+ case Some(2) =>
+ state.remove()
+ case _ =>
+ }
+ None.asInstanceOf[Option[Int]] // Do not return anything, not being tested
+ }
+ val newDataRDD = sc.makeRDD(testData).partitionBy(testStateRDD.partitioner.get)
+
+ // Assert that the new state RDD has expected state data
+ val newStateRDD = assertOperation(
+ testStateRDD, newDataRDD, trackingFunc, updateTime, expectedStates, Set.empty)
+
+ // Assert that the function was called only for the keys present in the data
+ assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size,
+ "More keys were touched than expected")
+ assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
+ "Keys not present in the data were unexpectedly touched")
+
+ // Assert that the test RDD's data has not changed
+ assertRDD(initStateRDD, initStateWithTime, Set.empty)
+ newStateRDD
+ }
+
+ // Test no-op, no state should change
+ testStateUpdates(initStateRDD, Seq(), initStateWithTime) // should not scan any state
+ testStateUpdates(
+ initStateRDD, Seq(("k1", 0)), initStateWithTime) // should not update existing state
+ testStateUpdates(
+ initStateRDD, Seq(("k3", 0)), initStateWithTime) // should not create new state
+
+ // Test creation of new state
+ val rdd1 = testStateUpdates(initStateRDD, Seq(("k3", 1)), // should create k3's state as 0
+ Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime)))
+
+ val rdd2 = testStateUpdates(rdd1, Seq(("k4", 1)), // should create k4's state as 0
+ Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime)))
+
+ // Test updating of state
+ val rdd3 = testStateUpdates(
+ initStateRDD, Seq(("k1", 1)), // should increment k1's state 0 -> 1
+ Set(("k1", 1, updateTime), ("k2", 0, initTime)))
+
+ val rdd4 = testStateUpdates(rdd3,
+ Seq(("x", 0), ("k2", 1), ("k2", 1), ("k3", 1)), // should update k2, 0 -> 2 and create k3, 0
+ Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 0, updateTime)))
+
+ val rdd5 = testStateUpdates(
+ rdd4, Seq(("k3", 1)), // should update k3's state 0 -> 2
+ Set(("k1", 1, updateTime), ("k2", 2, updateTime), ("k3", 1, updateTime)))
+
+ // Test removing of state
+ val rdd6 = testStateUpdates( // should remove k1's state
+ initStateRDD, Seq(("k1", 2)), Set(("k2", 0, initTime)))
+
+ val rdd7 = testStateUpdates( // should remove k2's state and create k3's state as 0
+ rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))
+
+ val rdd8 = testStateUpdates( // should remove k3's state, leaving no state behind
+ rdd7, Seq(("k3", 2)), Set())
+ }
+
+ /** Assert whether the `trackStateByKey` operation generates expected results */
+ private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+ testStateRDD: TrackStateRDD[K, V, S, T],
+ newDataRDD: RDD[(K, V)],
+ trackStateFunc: (Time, K, Option[V], State[S]) => Option[T],
+ currentTime: Long,
+ expectedStates: Set[(K, S, Int)],
+ expectedEmittedRecords: Set[T],
+ doFullScan: Boolean = false
+ ): TrackStateRDD[K, V, S, T] = {
+
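+ // Co-partition the new data with the state RDD so that state lookups stay within a partition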
+ val partitionedNewDataRDD = if (newDataRDD.partitioner != testStateRDD.partitioner) {
+ newDataRDD.partitionBy(testStateRDD.partitioner.get)
+ } else {
+ newDataRDD
+ }
+
+ val newStateRDD = new TrackStateRDD[K, V, S, T](
+ testStateRDD, partitionedNewDataRDD, trackStateFunc, Time(currentTime), None)
+ if (doFullScan) newStateRDD.setFullScan()
+
+ // Persist to make sure that it gets computed only once and we can track precisely how many
+ // state keys the computing touched
+ newStateRDD.persist()
+ assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
+ newStateRDD
+ }
+
+ /** Assert whether the [[TrackStateRDD]] has the expected state and emitted records */
+ private def assertRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+ trackStateRDD: TrackStateRDD[K, V, S, T],
+ expectedStates: Set[(K, S, Int)],
+ expectedEmittedRecords: Set[T]): Unit = {
+ val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
+ val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
+ assert(states === expectedStates, "states after track state operation were not as expected")
+ assert(emittedRecords === expectedEmittedRecords,
+ "emitted records after track state operation were not as expected")
+ }
+}
+
+object TrackStateRDDSuite {
+ private val touchedStateKeys = new ArrayBuffer[String]()
+}