diff options
author | Shixiong Zhu <shixiong@databricks.com> | 2015-11-12 17:48:43 -0800 |
---|---|---|
committer | Tathagata Das <tathagata.das1565@gmail.com> | 2015-11-12 17:48:43 -0800 |
commit | 0f1d00a905614bb5eebf260566dbcb831158d445 (patch) | |
tree | 5b46386a0c742cd035549fa26c08da296010e86d /examples/src | |
parent | 41bbd2300472501d69ed46f0407d5ed7cbede4a8 (diff) | |
download | spark-0f1d00a905614bb5eebf260566dbcb831158d445.tar.gz spark-0f1d00a905614bb5eebf260566dbcb831158d445.tar.bz2 spark-0f1d00a905614bb5eebf260566dbcb831158d445.zip |
[SPARK-11663][STREAMING] Add Java API for trackStateByKey
TODO
- [x] Add Java API
- [x] Add API tests
- [x] Add a function test
Author: Shixiong Zhu <shixiong@databricks.com>
Closes #9636 from zsxwing/java-track.
Diffstat (limited to 'examples/src')
2 files changed, 22 insertions, 25 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index 99b63a2590..c400e4237a 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -26,18 +26,15 @@ import scala.Tuple2; import com.google.common.base.Optional; import com.google.common.collect.Lists; -import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.StorageLevels; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.Durations; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaPairDStream; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.State; +import org.apache.spark.streaming.StateSpec; +import org.apache.spark.streaming.Time; +import org.apache.spark.streaming.api.java.*; /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every @@ -63,25 +60,12 @@ public class JavaStatefulNetworkWordCount { StreamingExamples.setStreamingLogLevels(); - // Update the cumulative count function - final Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction = - new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() { - @Override - public Optional<Integer> call(List<Integer> values, Optional<Integer> state) { - Integer newSum = state.or(0); - for (Integer value : values) { - newSum += value; - } - return Optional.of(newSum); - } - }; - // Create the context with a 1 second batch size SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount"); JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); ssc.checkpoint("."); - // Initial RDD input to updateStateByKey + // Initial RDD input to trackStateByKey @SuppressWarnings("unchecked") List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1), new Tuple2<String, Integer>("world", 1)); @@ -105,9 +89,22 @@ public class JavaStatefulNetworkWordCount { } }); + // Update the cumulative count function + final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>> trackStateFunc = + new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>>() { + + @Override + public Optional<Tuple2<String, Integer>> call(Time time, String word, Optional<Integer> one, State<Integer> state) { + int sum = one.or(0) + (state.exists() ? state.get() : 0); + Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum); + state.update(sum); + return Optional.of(output); + } + }; + // This will give a Dstream made of state (which is the cumulative count of the words) - JavaPairDStream<String, Integer> stateDstream = wordsDstream.updateStateByKey(updateFunction, - new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD); + JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream = + wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD)); stateDstream.print(); ssc.start(); diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index be2ae0b473..a4f847f118 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -49,7 +49,7 @@ object StatefulNetworkWordCount { val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") - // Initial RDD input to updateStateByKey + // Initial RDD input to trackStateByKey val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1))) // Create a ReceiverInputDStream on target ip:port and count the |