about | summary | refs | log | tree | commit | diff
path: root/examples
diff options
context:
space:
mode:
author: Shixiong Zhu <shixiong@databricks.com> 2015-11-12 17:48:43 -0800
committer: Tathagata Das <tathagata.das1565@gmail.com> 2015-11-12 17:48:43 -0800
commit: 0f1d00a905614bb5eebf260566dbcb831158d445 (patch)
tree: 5b46386a0c742cd035549fa26c08da296010e86d /examples
parent: 41bbd2300472501d69ed46f0407d5ed7cbede4a8 (diff)
download: spark-0f1d00a905614bb5eebf260566dbcb831158d445.tar.gz
          spark-0f1d00a905614bb5eebf260566dbcb831158d445.tar.bz2
          spark-0f1d00a905614bb5eebf260566dbcb831158d445.zip
[SPARK-11663][STREAMING] Add Java API for trackStateByKey
TODO
- [x] Add Java API
- [x] Add API tests
- [x] Add a function test

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #9636 from zsxwing/java-track.
Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java  45
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala  2
2 files changed, 22 insertions, 25 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index 99b63a2590..c400e4237a 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -26,18 +26,15 @@ import scala.Tuple2;
import com.google.common.base.Optional;
import com.google.common.collect.Lists;
-import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.*;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.StorageLevels;
-import org.apache.spark.api.java.function.FlatMapFunction;
-import org.apache.spark.api.java.function.Function2;
-import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.streaming.Durations;
-import org.apache.spark.streaming.api.java.JavaDStream;
-import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
-import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.apache.spark.streaming.State;
+import org.apache.spark.streaming.StateSpec;
+import org.apache.spark.streaming.Time;
+import org.apache.spark.streaming.api.java.*;
/**
* Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
@@ -63,25 +60,12 @@ public class JavaStatefulNetworkWordCount {
StreamingExamples.setStreamingLogLevels();
- // Update the cumulative count function
- final Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction =
- new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
- @Override
- public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
- Integer newSum = state.or(0);
- for (Integer value : values) {
- newSum += value;
- }
- return Optional.of(newSum);
- }
- };
-
// Create the context with a 1 second batch size
SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount");
JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
ssc.checkpoint(".");
- // Initial RDD input to updateStateByKey
+ // Initial RDD input to trackStateByKey
@SuppressWarnings("unchecked")
List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
new Tuple2<String, Integer>("world", 1));
@@ -105,9 +89,22 @@ public class JavaStatefulNetworkWordCount {
}
});
+ // Update the cumulative count function
+ final Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>> trackStateFunc =
+ new Function4<Time, String, Optional<Integer>, State<Integer>, Optional<Tuple2<String, Integer>>>() {
+
+ @Override
+ public Optional<Tuple2<String, Integer>> call(Time time, String word, Optional<Integer> one, State<Integer> state) {
+ int sum = one.or(0) + (state.exists() ? state.get() : 0);
+ Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
+ state.update(sum);
+ return Optional.of(output);
+ }
+ };
+
// This will give a Dstream made of state (which is the cumulative count of the words)
- JavaPairDStream<String, Integer> stateDstream = wordsDstream.updateStateByKey(updateFunction,
- new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD);
+ JavaTrackStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
+ wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD));
stateDstream.print();
ssc.start();
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index be2ae0b473..a4f847f118 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -49,7 +49,7 @@ object StatefulNetworkWordCount {
val ssc = new StreamingContext(sparkConf, Seconds(1))
ssc.checkpoint(".")
- // Initial RDD input to updateStateByKey
+ // Initial RDD input to trackStateByKey
val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))
// Create a ReceiverInputDStream on target ip:port and count the