 examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala |  13
 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala         |  19
 streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala     |  49
 streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala             |  70
 streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java                       |  53
 streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala             |  74
 6 files changed, 240 insertions(+), 38 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index a4d159bf38..514252b89e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -18,12 +18,13 @@
package org.apache.spark.examples.streaming
import org.apache.spark.SparkConf
+import org.apache.spark.HashPartitioner
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
/**
* Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
- * second.
+ * second, starting from an initial value for each word count.
* Usage: StatefulNetworkWordCount <hostname> <port>
* <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
* data.
@@ -51,11 +52,18 @@ object StatefulNetworkWordCount {
Some(currentCount + previousCount)
}
+ val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
+ iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+ }
+
val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
// Create the context with a 1 second batch size
val ssc = new StreamingContext(sparkConf, Seconds(1))
ssc.checkpoint(".")
+ // Initial RDD input to updateStateByKey
+ val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
+
// Create a NetworkInputDStream on target ip:port and count the
// words in input stream of \n delimited text (e.g. generated by 'nc')
val lines = ssc.socketTextStream(args(0), args(1).toInt)
@@ -64,7 +72,8 @@ object StatefulNetworkWordCount {
// Update the cumulative count using updateStateByKey
// This will give a DStream made of state (which is the cumulative count of the words)
- val stateDstream = wordDstream.updateStateByKey[Int](updateFunc)
+ val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
+ new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)
stateDstream.print()
ssc.start()
ssc.awaitTermination()
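For illustration, here is a minimal Scala sketch (not part of the patch) of the adaptation the example performs: a per-key (Seq[V], Option[S]) => Option[S] function is wrapped into the Iterator-based form expected by the updateStateByKey overload that takes a Partitioner and an initial RDD. The String/Int types and the name updateFunc mirror the example and are assumptions.

    // Per-key update: add the batch's counts for a key to its previous (or seeded) count.
    val updateFunc = (values: Seq[Int], state: Option[Int]) =>
      Some(values.sum + state.getOrElse(0))

    // Wrap it into the iterator-based signature; returning None here would drop the key.
    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) =>
      iterator.flatMap { case (word, counts, prev) => updateFunc(counts, prev).map(s => (word, s)) }

    // With initialRDD = List(("hello", 1), ("world", 1)), a first batch containing "hello"
    // once is reported as ("hello", 2): the seeded value acts as the previous state.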
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index 59d4423086..bb44b906d7 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -492,6 +492,25 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner)
}
+ /**
+ * Return a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of the key.
+ * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+ * @param updateFunc State update function. If `this` function returns None, then the
+ * corresponding state key-value pair will be eliminated.
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+ * DStream.
+ * @param initialRDD initial state value of each key.
+ * @tparam S State type
+ */
+ def updateStateByKey[S](
+ updateFunc: JFunction2[JList[V], Optional[S], Optional[S]],
+ partitioner: Partitioner,
+ initialRDD: JavaPairRDD[K, S]
+ ): JavaPairDStream[K, S] = {
+ implicit val cm: ClassTag[S] = fakeClassTag
+ dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner, initialRDD)
+ }
/**
* Return a new DStream by applying a map function to the value of each key-value pairs in
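For reference, a hedged sketch of the Scala call this Java overload delegates to; pairDStream (a DStream[(String, Int)]) and initialStateRDD (an RDD[(String, Int)]) are assumed to exist in a streaming application with checkpointing enabled.

    // Running counts seeded from initialStateRDD; keys absent from it start with state None.
    val stateCounts = pairDStream.updateStateByKey(
      (values: Seq[Int], state: Option[Int]) => Some(values.sum + state.getOrElse(0)),
      new org.apache.spark.HashPartitioner(2),  // controls partitioning of each state RDD
      initialStateRDD)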
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
index 9467595d30..b39f47f04a 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala
@@ -413,7 +413,54 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
partitioner: Partitioner,
rememberPartitioner: Boolean
): DStream[(K, S)] = {
- new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner)
+ new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
+ }
+
+ /**
+ * Return a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of the key.
+ * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+ * @param updateFunc State update function. If `this` function returns None, then the
+ * corresponding state key-value pair will be eliminated.
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+ * DStream.
+ * @param initialRDD initial state value of each key.
+ * @tparam S State type
+ */
+ def updateStateByKey[S: ClassTag](
+ updateFunc: (Seq[V], Option[S]) => Option[S],
+ partitioner: Partitioner,
+ initialRDD: RDD[(K, S)]
+ ): DStream[(K, S)] = {
+ val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => {
+ iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+ }
+ updateStateByKey(newUpdateFunc, partitioner, true, initialRDD)
+ }
+
+ /**
+ * Return a new "state" DStream where the state for each key is updated by applying
+ * the given function on the previous state of the key and the new values of each key.
+ * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
+ * @param updateFunc State update function. If `this` function returns None, then the
+ * corresponding state key-value pair will be eliminated. Note that
+ * this function may generate a tuple with a different key
+ * than the input key. It is up to the developer to decide whether to
+ * remember the partitioner despite the key being changed.
+ * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
+ * DStream
+ * @param rememberPartitioner Whether to remember the partitioner object in the generated RDDs.
+ * @param initialRDD initial state value of each key.
+ * @tparam S State type
+ */
+ def updateStateByKey[S: ClassTag](
+ updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
+ partitioner: Partitioner,
+ rememberPartitioner: Boolean,
+ initialRDD: RDD[(K, S)]
+ ): DStream[(K, S)] = {
+ new StateDStream(self, ssc.sc.clean(updateFunc), partitioner,
+ rememberPartitioner, Some(initialRDD))
}
/**
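As the scaladoc above notes, an update function that returns None removes the key from the state. A short sketch (assumed names, applicable to any of the overloads above) of an update function that evicts keys once their running total drops to zero or below:

    // Sums the batch's (possibly negative) deltas into the running total; None drops the key.
    val evictingUpdate = (deltas: Seq[Int], total: Option[Int]) => {
      val next = deltas.sum + total.getOrElse(0)
      if (next <= 0) None else Some(next)
    }
    // e.g. pairs.updateStateByKey(evictingUpdate, partitioner, initialRDD)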
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
index 7e22268767..ebb04dd35b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala
@@ -30,7 +30,8 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
parent: DStream[(K, V)],
updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
partitioner: Partitioner,
- preservePartitioning: Boolean
+ preservePartitioning: Boolean,
+ initialRDD: Option[RDD[(K, S)]]
) extends DStream[(K, S)](parent.ssc) {
super.persist(StorageLevel.MEMORY_ONLY_SER)
@@ -41,6 +42,25 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
override val mustCheckpoint = true
+ private[this] def computeUsingPreviousRDD(
+ parentRDD: RDD[(K, V)], prevStateRDD: RDD[(K, S)]) = {
+ // Define the function for the mapPartition operation on cogrouped RDD;
+ // first map the cogrouped tuple to tuples of required type,
+ // and then apply the update function
+ val updateFuncLocal = updateFunc
+ val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
+ val i = iterator.map(t => {
+ val itr = t._2._2.iterator
+ val headOption = if (itr.hasNext) Some(itr.next()) else None
+ (t._1, t._2._1.toSeq, headOption)
+ })
+ updateFuncLocal(i)
+ }
+ val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
+ val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
+ Some(stateRDD)
+ }
+
override def compute(validTime: Time): Option[RDD[(K, S)]] = {
// Try to get the previous state RDD
@@ -51,25 +71,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
// Try to get the parent RDD
parent.getOrCompute(validTime) match {
case Some(parentRDD) => { // If parent RDD exists, then compute as usual
-
- // Define the function for the mapPartition operation on cogrouped RDD;
- // first map the cogrouped tuple to tuples of required type,
- // and then apply the update function
- val updateFuncLocal = updateFunc
- val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
- val i = iterator.map(t => {
- val itr = t._2._2.iterator
- val headOption = itr.hasNext match {
- case true => Some(itr.next())
- case false => None
- }
- (t._1, t._2._1.toSeq, headOption)
- })
- updateFuncLocal(i)
- }
- val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
- val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
- Some(stateRDD)
+ computeUsingPreviousRDD(parentRDD, prevStateRDD)
}
case None => { // If parent RDD does not exist
@@ -90,19 +92,25 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
// Try to get the parent RDD
parent.getOrCompute(validTime) match {
case Some(parentRDD) => { // If parent RDD exists, then compute as usual
+ initialRDD match {
+ case None => {
+ // Define the function for the mapPartition operation on grouped RDD;
+ // first map the grouped tuple to tuples of required type,
+ // and then apply the update function
+ val updateFuncLocal = updateFunc
+ val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
+ updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
+ }
- // Define the function for the mapPartition operation on grouped RDD;
- // first map the grouped tuple to tuples of required type,
- // and then apply the update function
- val updateFuncLocal = updateFunc
- val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => {
- updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None)))
+ val groupedRDD = parentRDD.groupByKey(partitioner)
+ val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
+ // logDebug("Generating state RDD for time " + validTime + " (first)")
+ Some(sessionRDD)
+ }
+ case Some(initialStateRDD) => {
+ computeUsingPreviousRDD(parentRDD, initialStateRDD)
+ }
}
-
- val groupedRDD = parentRDD.groupByKey(partitioner)
- val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning)
- // logDebug("Generating state RDD for time " + validTime + " (first)")
- Some(sessionRDD)
}
case None => { // If parent RDD does not exist, then nothing to do!
// logDebug("Not generating state RDD (no previous state, no parent)")
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 4efeb8dfbe..ce645fccba 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -806,15 +806,17 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
* Performs an order-invariant comparison of lists representing two RDD streams. This allows
* us to account for ordering variation within individual RDDs, which occurs during windowing.
*/
- public static <T extends Comparable<T>> void assertOrderInvariantEquals(
+ public static <T> void assertOrderInvariantEquals(
List<List<T>> expected, List<List<T>> actual) {
+ List<Set<T>> expectedSets = new ArrayList<Set<T>>();
for (List<T> list: expected) {
- Collections.sort(list);
+ expectedSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
}
+ List<Set<T>> actualSets = new ArrayList<Set<T>>();
for (List<T> list: actual) {
- Collections.sort(list);
+ actualSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
}
- Assert.assertEquals(expected, actual);
+ Assert.assertEquals(expectedSets, actualSets);
}
@@ -1241,6 +1243,49 @@ public class JavaAPISuite extends LocalJavaStreamingContext implements Serializa
@SuppressWarnings("unchecked")
@Test
+ public void testUpdateStateByKeyWithInitial() {
+ List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+ List<Tuple2<String, Integer>> initial = Arrays.asList(
+ new Tuple2<String, Integer>("california", 1),
+ new Tuple2<String, Integer>("new york", 2));
+
+ JavaRDD<Tuple2<String, Integer>> tmpRDD = ssc.sparkContext().parallelize(initial);
+ JavaPairRDD<String, Integer> initialRDD = JavaPairRDD.fromJavaRDD(tmpRDD);
+
+ List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
+ Arrays.asList(new Tuple2<String, Integer>("california", 5),
+ new Tuple2<String, Integer>("new york", 7)),
+ Arrays.asList(new Tuple2<String, Integer>("california", 15),
+ new Tuple2<String, Integer>("new york", 11)),
+ Arrays.asList(new Tuple2<String, Integer>("california", 15),
+ new Tuple2<String, Integer>("new york", 11)));
+
+ JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
+ JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+ JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
+ new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
+ @Override
+ public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+ int out = 0;
+ if (state.isPresent()) {
+ out = out + state.get();
+ }
+ for (Integer v: values) {
+ out = out + v;
+ }
+ return Optional.of(out);
+ }
+ }, new HashPartitioner(1), initialRDD);
+ JavaTestUtils.attachTestOutputStream(updated);
+ List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);
+
+ assertOrderInvariantEquals(expected, result);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
public void testReduceByKeyAndWindowWithInverse() {
List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
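The revised assertOrderInvariantEquals compares each batch as a set instead of sorting it, so the element type no longer needs to be Comparable (the new initial-value test compares Tuple2 results, which are not). A hedged Scala sketch of the same idea:

    // Compare two streams of batches while ignoring ordering within each batch.
    def assertOrderInvariant[T](expected: Seq[Seq[T]], actual: Seq[Seq[T]]): Unit =
      assert(expected.map(_.toSet) == actual.map(_.toSet),
        s"expected $expected, got $actual")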
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index dbab685dc3..30a359677c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.{BlockRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.{DStream, WindowedDStream}
+import org.apache.spark.HashPartitioner
class BasicOperationsSuite extends TestSuiteBase {
test("map") {
@@ -350,6 +351,79 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(inputData, updateStateOperation, outputData, true)
}
+ test("updateStateByKey - simple with initial value RDD") {
+ val initial = Seq(("a", 1), ("c", 2))
+
+ val inputData =
+ Seq(
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val outputData =
+ Seq(
+ Seq(("a", 2), ("c", 2)),
+ Seq(("a", 3), ("b", 1), ("c", 2)),
+ Seq(("a", 4), ("b", 2), ("c", 3)),
+ Seq(("a", 5), ("b", 3), ("c", 3)),
+ Seq(("a", 6), ("b", 3), ("c", 3)),
+ Seq(("a", 6), ("b", 3), ("c", 3))
+ )
+
+ val updateStateOperation = (s: DStream[String]) => {
+ val initialRDD = s.context.sparkContext.makeRDD(initial)
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ Some(values.sum + state.getOrElse(0))
+ }
+ s.map(x => (x, 1)).updateStateByKey[Int](updateFunc,
+ new HashPartitioner(numInputPartitions), initialRDD)
+ }
+
+ testOperation(inputData, updateStateOperation, outputData, true)
+ }
+
+ test("updateStateByKey - with initial value RDD") {
+ val initial = Seq(("a", 1), ("c", 2))
+
+ val inputData =
+ Seq(
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val outputData =
+ Seq(
+ Seq(("a", 2), ("c", 2)),
+ Seq(("a", 3), ("b", 1), ("c", 2)),
+ Seq(("a", 4), ("b", 2), ("c", 3)),
+ Seq(("a", 5), ("b", 3), ("c", 3)),
+ Seq(("a", 6), ("b", 3), ("c", 3)),
+ Seq(("a", 6), ("b", 3), ("c", 3))
+ )
+
+ val updateStateOperation = (s: DStream[String]) => {
+ val initialRDD = s.context.sparkContext.makeRDD(initial)
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ Some(values.sum + state.getOrElse(0))
+ }
+ val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
+ iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
+ }
+ s.map(x => (x, 1)).updateStateByKey[Int](newUpdateFunc,
+ new HashPartitioner(numInputPartitions), true, initialRDD)
+ }
+
+ testOperation(inputData, updateStateOperation, outputData, true)
+ }
+
test("updateStateByKey - object lifecycle") {
val inputData =
Seq(