From 8e6cbbc6c7434b53c63e19a1c9c2dca1f24de654 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 16 Jan 2013 13:50:02 -0800 Subject: Adding other updateState functions --- .../spark/streaming/api/java/JavaPairDStream.scala | 62 +++++++++++++++++----- 1 file changed, 49 insertions(+), 13 deletions(-) (limited to 'streaming/src') diff --git a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala index 1c5b864ff0..8c76d8c1d8 100644 --- a/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/spark/streaming/api/java/JavaPairDStream.scala @@ -420,6 +420,23 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( dstream.countByKeyAndWindow(windowDuration, slideDuration, numPartitions) } + private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): + (Seq[V], Option[S]) => Option[S] = { + val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { + val list: JList[V] = values + val scalaState: Optional[S] = state match { + case Some(s) => Optional.of(s) + case _ => Optional.absent() + } + val result: Optional[S] = in.apply(list, scalaState) + result.isPresent match { + case true => Some(result.get()) + case _ => None + } + } + scalaFunc + } + /** * Create a new "state" DStream where the state for each key is updated by applying * the given function on the previous state of the key and the new values of each key. @@ -432,20 +449,39 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( : JavaPairDStream[K, S] = { implicit val cm: ClassManifest[S] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[S]] + dstream.updateStateByKey(convertUpdateStateFunction(updateFunc)) + } - val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { - val list: JList[V] = values - val scalaState: Optional[S] = state match { - case Some(s) => Optional.of(s) - case _ => Optional.absent() - } - val result: Optional[S] = updateFunc.apply(list, scalaState) - result.isPresent match { - case true => Some(result.get()) - case _ => None - } - } - dstream.updateStateByKey(scalaFunc) + /** + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of each key. + * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param numPartitions Number of partitions of each RDD in the new DStream. + * @tparam S State type + */ + def updateStateByKey[S: ClassManifest]( + updateFunc: JFunction2[JList[V], Optional[S], Optional[S]], + numPartitions: Int) + : JavaPairDStream[K, S] = { + dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), numPartitions) + } + + /** + * Create a new "state" DStream where the state for each key is updated by applying + * the given function on the previous state of the key and the new values of the key. + * [[spark.Partitioner]] is used to control the partitioning of each RDD. + * @param updateFunc State update function. If `this` function returns None, then + * corresponding state key-value pair will be eliminated. + * @param partitioner Partitioner for controlling the partitioning of each RDD in the new DStream. + * @tparam S State type + */ + def updateStateByKey[S: ClassManifest]( + updateFunc: JFunction2[JList[V], Optional[S], Optional[S]], + partitioner: Partitioner + ): JavaPairDStream[K, S] = { + dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner) } def mapValues[U](f: JFunction[V, U]): JavaPairDStream[K, U] = { -- cgit v1.2.3