diff options
3 files changed, 18 insertions, 18 deletions
diff --git a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala index 5de57eb2fd..ce1f4ad0a0 100644 --- a/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/spark/streaming/PairDStreamFunctions.scala @@ -156,30 +156,30 @@ extends Serializable { // // def updateStateByKey[S <: AnyRef : ClassManifest]( - updateFunc: (Seq[V], S) => S + updateFunc: (Seq[V], Option[S]) => Option[S] ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner()) } def updateStateByKey[S <: AnyRef : ClassManifest]( - updateFunc: (Seq[V], S) => S, + updateFunc: (Seq[V], Option[S]) => Option[S], numPartitions: Int ): DStream[(K, S)] = { updateStateByKey(updateFunc, defaultPartitioner(numPartitions)) } def updateStateByKey[S <: AnyRef : ClassManifest]( - updateFunc: (Seq[V], S) => S, + updateFunc: (Seq[V], Option[S]) => Option[S], partitioner: Partitioner ): DStream[(K, S)] = { - val func = (iterator: Iterator[(K, Seq[V], S)]) => { - iterator.map(tuple => (tuple._1, updateFunc(tuple._2, tuple._3))) + val newUpdateFunc = (iterator: Iterator[(K, Seq[V], Option[S])]) => { + iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) } - updateStateByKey(func, partitioner, true) + updateStateByKey(newUpdateFunc, partitioner, true) } def updateStateByKey[S <: AnyRef : ClassManifest]( - updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], + updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean ): DStream[(K, S)] = { diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index d223f25dfc..3ba8fb45fb 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -8,14 +8,17 @@ import spark.SparkContext._ import spark.storage.StorageLevel -class StateRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: Iterator[T] => Iterator[U], rememberPartitioner: Boolean) - extends MapPartitionsRDD[U, T](prev, f) { +class StateRDD[U: ClassManifest, T: ClassManifest]( + prev: RDD[T], + f: Iterator[T] => Iterator[U], + rememberPartitioner: Boolean + ) extends MapPartitionsRDD[U, T](prev, f) { override val partitioner = if (rememberPartitioner) prev.partitioner else None } class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManifest]( - @transient parent: DStream[(K, V)], - updateFunc: (Iterator[(K, Seq[V], S)]) => Iterator[(K, S)], + parent: DStream[(K, V)], + updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)], partitioner: Partitioner, rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { @@ -82,7 +85,7 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife val updateFuncLocal = updateFunc val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { val i = iterator.map(t => { - (t._1, t._2._1, t._2._2.headOption.getOrElse(null.asInstanceOf[S])) + (t._1, t._2._1, t._2._2.headOption) }) updateFuncLocal(i) } @@ -108,7 +111,7 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife // and then apply the update function val updateFuncLocal = updateFunc val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { - updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, null.asInstanceOf[S]))) + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) } val groupedRDD = parentRDD.groupByKey(partitioner) diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala index db95c2cfaa..290a216797 100644 --- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala +++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala @@ -149,11 +149,8 @@ class DStreamBasicSuite extends DStreamSuiteBase { ) val updateStateOperation = (s: DStream[String]) => { - val updateFunc = (values: Seq[Int], state: RichInt) => { - var newState = 0 - if (values != null && values.size > 0) newState += values.reduce(_ + _) - if (state != null) newState += state.self - new RichInt(newState) + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) } s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self)) } |