aboutsummaryrefslogtreecommitdiff
path: root/streaming/src
diff options
context:
space:
mode:
authorAaditya Ramesh <aramesh@conviva.com>2016-11-15 13:01:01 -0800
committerShixiong Zhu <shixiong@databricks.com>2016-11-15 13:01:01 -0800
commit6f9e598ccf92f6272bbfb56ac56d3101387131b9 (patch)
treec8af7879458ea0cb0c022e6b11b437ff874dbfcf /streaming/src
parent745ab8bc50da89c42b297de9dcb833e5f2074481 (diff)
downloadspark-6f9e598ccf92f6272bbfb56ac56d3101387131b9.tar.gz
spark-6f9e598ccf92f6272bbfb56ac56d3101387131b9.tar.bz2
spark-6f9e598ccf92f6272bbfb56ac56d3101387131b9.zip
[SPARK-13027][STREAMING] Added batch time as a parameter to updateStateByKey
Added RDD batch time as an input parameter to the update function in updateStateByKey. Author: Aaditya Ramesh <aramesh@conviva.com> Closes #11122 from aramesh117/SPARK-13027.
Diffstat (limited to 'streaming/src')
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala40
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala28
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala66
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala12
4 files changed, 126 insertions, 20 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 2f2a6d13dd..ac739411fd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -453,9 +453,12 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
def updateStateByKey[S: ClassTag](
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
- rememberPartitioner: Boolean
- ): DStream[(K, S)] = ssc.withScope {
- new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
+ rememberPartitioner: Boolean): DStream[(K, S)] = ssc.withScope {
+ val cleanedFunc = ssc.sc.clean(updateFunc)
+ val newUpdateFunc = (_: Time, it: Iterator[(K, Seq[V], Option[S])]) => {
+ cleanedFunc(it)
+ }
+ new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, None)
}
/**
@@ -499,10 +502,33 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
rememberPartitioner: Boolean,
- initialRDD: RDD[(K, S)]
- ): DStream[(K, S)] = ssc.withScope {
- new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
- rememberPartitioner, Some(initialRDD))
+ initialRDD: RDD[(K, S)]): DStream[(K, S)] = ssc.withScope {
+ val cleanedFunc = ssc.sc.clean(updateFunc)
+ val newUpdateFunc = (_: Time, it: Iterator[(K, Seq[V], Option[S])]) => {
+ cleanedFunc(it)
+ }
+ new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, Some(initialRDD))
+ }
+
+ /**
+ * Return a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of the key.
+ * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+ * @param updateFunc State update function. If `this` function returns None, then
+ * corresponding state key-value pair will be eliminated.
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+ * DStream.
+ * @tparam S State type
+ */
+ def updateStateByKey[S: ClassTag](updateFunc: (Time, K, Seq[V], Option[S]) => Option[S],
+ partitioner: Partitioner,
+ rememberPartitioner: Boolean,
+ initialRDD: Option[RDD[(K, S)]] = None): DStream[(K, S)] = ssc.withScope {
+ val cleanedFunc = ssc.sc.clean(updateFunc)
+ val newUpdateFunc = (time: Time, iterator: Iterator[(K, Seq[V], Option[S])]) => {
+ iterator.flatMap(t => cleanedFunc(time, t._1, t._2, t._3).map(s => (t._1, s)))
+ }
+ new StateDStream(self, newUpdateFunc, partitioner, rememberPartitioner, initialRDD)
}
/**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index 8efb09a8ce..5bf1dabf08 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -27,7 +27,7 @@ import org.apache.spark.streaming.{Duration, Time}
private[streaming]
class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
parent: DStream[(K, V)],
- updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+ updateFunc: (Time, Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
preservePartitioning: Boolean,
initialRDD: Option[RDD[(K, S)]]
@@ -41,8 +41,10 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
override val mustCheckpoint = true
- private [this] def computeUsingPreviousRDD (
- parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
+ private [this] def computeUsingPreviousRDD(
+ batchTime: Time,
+ parentRDD: RDD[(K, V)],
+ prevStateRDD: RDD[(K, S)]) = {
// Define the function for the mapPartition operation on cogrouped RDD;
// first map the cogrouped tuple to tuples of required type,
// and then apply the update function
@@ -53,7 +55,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
val headOption = if (itr.hasNext) Some(itr.next()) else None
(t._1, t._2._1.toSeq, headOption)
}
- updateFuncLocal(i)
+ updateFuncLocal(batchTime, i)
}
val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
@@ -68,15 +70,14 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
case Some(prevStateRDD) => // If previous state RDD exists
// Try to get the parent RDD
parent.getOrCompute(validTime) match {
- case Some(parentRDD) => // If parent RDD exists, then compute as usual
- computeUsingPreviousRDD(parentRDD, prevStateRDD)
- case None => // If parent RDD does not exist
-
+ case Some(parentRDD) => // If parent RDD exists, then compute as usual
+ computeUsingPreviousRDD (validTime, parentRDD, prevStateRDD)
+ case None => // If parent RDD does not exist
// Re-apply the update function to the old state RDD
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, S)]) => {
val i = iterator.map(t => (t._1, Seq[V](), Option(t._2)))
- updateFuncLocal(i)
+ updateFuncLocal(validTime, i)
}
val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning)
Some(stateRDD)
@@ -93,15 +94,16 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
// and then apply the update function
val updateFuncLocal = updateFunc
val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
- updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
+ updateFuncLocal (validTime,
+ iterator.map (tuple => (tuple._1, tuple._2.toSeq, None)))
}
val groupedRDD = parentRDD.groupByKey(partitioner)
val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
// logDebug("Generating state RDD for time " + validTime + " (first)")
- Some(sessionRDD)
- case Some(initialStateRDD) =>
- computeUsingPreviousRDD(parentRDD, initialStateRDD)
+ Some (sessionRDD)
+ case Some (initialStateRDD) =>
+ computeUsingPreviousRDD(validTime, parentRDD, initialStateRDD)
}
case None => // If parent RDD does not exist, then nothing to do!
// logDebug("Not generating state RDD (no previous state, no parent)")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index cfcbdc7c38..4e702bbb92 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -471,6 +471,72 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(inputData, updateStateOperation, outputData, true)
}
+ test("updateStateByKey - testing time stamps as input") {
+ type StreamingState = Long
+ val initial: Seq[(String, StreamingState)] = Seq(("a", 0L), ("c", 0L))
+
+ val inputData =
+ Seq(
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ // a -> 1000, 3000, 6000, 10000, 15000, 15000
+ // b -> 0, 2000, 5000, 9000, 9000, 9000
+ // c -> 1000, 1000, 3000, 3000, 3000, 3000
+
+ val outputData: Seq[Seq[(String, StreamingState)]] = Seq(
+ Seq(
+ ("a", 1000L),
+ ("c", 0L)), // t = 1000
+ Seq(
+ ("a", 3000L),
+ ("b", 2000L),
+ ("c", 0L)), // t = 2000
+ Seq(
+ ("a", 6000L),
+ ("b", 5000L),
+ ("c", 3000L)), // t = 3000
+ Seq(
+ ("a", 10000L),
+ ("b", 9000L),
+ ("c", 3000L)), // t = 4000
+ Seq(
+ ("a", 15000L),
+ ("b", 9000L),
+ ("c", 3000L)), // t = 5000
+ Seq(
+ ("a", 15000L),
+ ("b", 9000L),
+ ("c", 3000L)) // t = 6000
+ )
+
+ val updateStateOperation = (s: DStream[String]) => {
+ val initialRDD = s.context.sparkContext.makeRDD(initial)
+ val updateFunc = (time: Time,
+ key: String,
+ values: Seq[Int],
+ state: Option[StreamingState]) => {
+ // Update only if we receive values for this key during the batch.
+ if (values.nonEmpty) {
+ Option(time.milliseconds + state.getOrElse(0L))
+ } else {
+ Option(state.getOrElse(0L))
+ }
+ }
+ s.map(x => (x, 1)).updateStateByKey[StreamingState](updateFunc = updateFunc,
+ partitioner = new HashPartitioner (numInputPartitions), rememberPartitioner = false,
+ initialRDD = Option(initialRDD))
+ }
+
+ testOperation(input = inputData, operation = updateStateOperation,
+ expectedOutput = outputData, useSet = true)
+ }
+
test("updateStateByKey - with initial value RDD") {
val initial = Seq(("a", 1), ("c", 2))
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
index 1fc34f569f..2ab600ab81 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala
@@ -164,6 +164,10 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
private def testUpdateStateByKey(ds: DStream[(Int, Int)]): Unit = {
val updateF1 = (_: Seq[Int], _: Option[Int]) => { return; Some(1) }
val updateF2 = (_: Iterator[(Int, Seq[Int], Option[Int])]) => { return; Seq((1, 1)).toIterator }
+ val updateF3 = (_: Time, _: Int, _: Seq[Int], _: Option[Int]) => {
+ return
+ Option(1)
+ }
val initialRDD = ds.ssc.sparkContext.emptyRDD[Int].map { i => (i, i) }
expectCorrectException { ds.updateStateByKey(updateF1) }
expectCorrectException { ds.updateStateByKey(updateF1, 5) }
@@ -177,6 +181,14 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll {
expectCorrectException {
ds.updateStateByKey(updateF2, new HashPartitioner(5), true, initialRDD)
}
+ expectCorrectException {
+ ds.updateStateByKey(
+ updateFunc = updateF3,
+ partitioner = new HashPartitioner(5),
+ rememberPartitioner = true,
+ initialRDD = Option(initialRDD)
+ )
+ }
}
private def testMapValues(ds: DStream[(Int, Int)]): Unit = expectCorrectException {
ds.mapValues { _ => return; 1 }