diff options
Diffstat (limited to 'streaming')
-rw-r--r-- | streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala | 29 | ||||
-rw-r--r-- | streaming/src/test/scala/spark/streaming/JavaAPISuite.java | 33 |
2 files changed, 61 insertions, 1 deletions
diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
index 0cccb083c5..49a0f27b5b 100644
--- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala
@@ -15,6 +15,7 @@ import org.apache.hadoop.conf.Configuration
 import spark.api.java.JavaPairRDD
 import spark.storage.StorageLevel
 import java.lang
+import com.google.common.base.Optional
 
 class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
     implicit val kManifiest: ClassManifest[K],
@@ -419,7 +420,33 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
     dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions)
   }
 
-  // TODO: Update State
+  /**
+   * Create a new "state" DStream where the state for each key is updated by applying
+   * the given function on the previous state of the key and the new values of each key.
+   * Hash partitioning is used to generate the RDDs with Spark's default number of partitions.
+   * @param updateFunc State update function. If `this` function returns None, then
+   *                   corresponding state key-value pair will be eliminated.
+   * @tparam S State type
+   */
+  def updateStateByKey[S](updateFunc: JFunction2[JList[V], Optional[S], Optional[S]])
+  : JavaPairDStream[K, S] = {
+    implicit val cm: ClassManifest[S] =
+      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[S]]
+
+    def scalaFunc(values: Seq[V], state: Option[S]): Option[S] = {
+      val list: JList[V] = values
+      val scalaState: Optional[S] = state match {
+        case Some(s) => Optional.of(s)
+        case _ => Optional.absent()
+      }
+      val result: Optional[S] = updateFunc.apply(list, scalaState)
+      result.isPresent match {
+        case true => Some(result.get())
+        case _ => None
+      }
+    }
+    dstream.updateStateByKey(scalaFunc _)
+  }
 
   def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = {
     implicit val cm: ClassManifest[U] =
diff --git a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java
index 7475b9536b..d95ab485f8 100644
--- a/streaming/src/test/scala/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/scala/spark/streaming/JavaAPISuite.java
@@ -1,5 +1,6 @@
 package spark.streaming;
 
+import com.google.common.base.Optional;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
@@ -552,6 +553,38 @@ public class JavaAPISuite implements Serializable {
   }
 
   @Test
+  public void testUpdateStateByKey() {
+    List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
+
+    List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
+        Arrays.asList(new Tuple2<String, Integer>("california", 4),
+                      new Tuple2<String, Integer>("new york", 5)),
+        Arrays.asList(new Tuple2<String, Integer>("california", 14),
+                      new Tuple2<String, Integer>("new york", 9)),
+        Arrays.asList(new Tuple2<String, Integer>("california", 10),
+                      new Tuple2<String, Integer>("new york", 4)));
+
+    JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(sc, inputData, 1);
+    JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
+
+    JavaPairDStream<String, Integer> updated = pairStream.updateStateByKey(
+        new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>(){
+          @Override
+          public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+            int out = 0;
+            for (Integer v: values) {
+              out = out + v;
+            }
+            return Optional.of(out);
+          }
+        });
+    JavaTestUtils.attachTestOutputStream(updated);
+    List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(sc, 3, 3);
+
+    Assert.assertEquals(expected, result);
+  }
+
+  @Test
   public void testReduceByKeyAndWindowWithInverse() {
     List<List<Tuple2<String, Integer>>> inputData = stringIntKVStream;
 