aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala25
1 files changed, 10 insertions, 15 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
index 02ba1c2eed..be2ae0b473 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala
@@ -44,18 +44,6 @@ object StatefulNetworkWordCount {
StreamingExamples.setStreamingLogLevels()
- val updateFunc = (values: Seq[Int], state: Option[Int]) => {
- val currentCount = values.sum
-
- val previousCount = state.getOrElse(0)
-
- Some(currentCount + previousCount)
- }
-
- val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => {
- iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s)))
- }
-
val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount")
// Create the context with a 1 second batch size
val ssc = new StreamingContext(sparkConf, Seconds(1))
@@ -71,9 +59,16 @@ object StatefulNetworkWordCount {
val wordDstream = words.map(x => (x, 1))
// Update the cumulative count using updateStateByKey
- // This will give a Dstream made of state (which is the cumulative count of the words)
- val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc,
- new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD)
+ // This will give a DStream made of state (which is the cumulative count of the words)
+ val trackStateFunc = (batchTime: Time, word: String, one: Option[Int], state: State[Int]) => {
+ val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
+ val output = (word, sum)
+ state.update(sum)
+ Some(output)
+ }
+
+ val stateDstream = wordDstream.trackStateByKey(
+ StateSpec.function(trackStateFunc).initialState(initialRDD))
stateDstream.print()
ssc.start()
ssc.awaitTermination()