From 6d7f907e73e9702c0dbd0e41e4a52022c0b81d3d Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Tue, 11 Sep 2012 16:00:06 -0700
Subject: Manually merge pull request #175 by Imran Rashid

---
 core/src/main/scala/spark/Accumulators.scala     | 24 +++++++++++++++++++++
 core/src/main/scala/spark/SparkContext.scala     | 11 ++++++++++
 core/src/test/scala/spark/AccumulatorSuite.scala | 27 ++++++++++++++++++++++--
 3 files changed, 60 insertions(+), 2 deletions(-)

(limited to 'core')

diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index d764ffc29d..c157cc8feb 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -3,6 +3,7 @@ package spark
 import java.io._
 
 import scala.collection.mutable.Map
+import scala.collection.generic.Growable
 
 /**
  * A datatype that can be accumulated, i.e. has an commutative and associative +.
@@ -92,6 +93,29 @@ trait AccumulableParam[R, T] extends Serializable {
   def zero(initialValue: R): R
 }
 
+class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
+  extends AccumulableParam[R,T] {
+
+  def addAccumulator(growable: R, elem: T) : R = {
+    growable += elem
+    growable
+  }
+
+  def addInPlace(t1: R, t2: R) : R = {
+    t1 ++= t2
+    t1
+  }
+
+  def zero(initialValue: R): R = {
+    // We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
+    // Instead we'll serialize it to a buffer and load it back.
+    val ser = (new spark.JavaSerializer).newInstance()
+    val copy = ser.deserialize[R](ser.serialize(initialValue))
+    copy.clear()   // In case it contained stuff
+    copy
+  }
+}
+
 /**
  * A simpler value of [[spark.Accumulable]] where the result type being accumulated is the same
  * as the types of elements being merged.
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 5d0f2950d6..0dec44979f 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -7,6 +7,7 @@ import akka.actor.Actor
 import akka.actor.Actor._
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.generic.Growable
 
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.conf.Configuration
@@ -307,6 +308,16 @@ class SparkContext(
   def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
     new Accumulable(initialValue, param)
 
+  /**
+   * Create an accumulator from a "mutable collection" type.
+   *
+   * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
+   * standard mutable collections. So you can use this with mutable Map, Set, etc.
+   */
+  def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
+    val param = new GrowableAccumulableParam[R,T]
+    new Accumulable(initialValue, param)
+  }
 
   // Keep around a weak hash map of values to Cached versions?
   def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index d55969c261..71df5941e5 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -56,7 +56,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
     }
   }
 
-
   implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
     def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
       t1 ++= t2
@@ -71,7 +70,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
     }
   }
 
-
   test ("value not readable in tasks") {
     import SetAccum._
     val maxI = 1000
@@ -89,4 +87,29 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
       }
     }
   }
+  test ("collection accumulators") {
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) {
+      // test single & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val setAcc = sc.accumulableCollection(mutable.HashSet[Int]())
+      val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]())
+      val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]())
+      val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
+      d.foreach {
+        x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
+      }
+
+      // Note that this is typed correctly -- no casts necessary
+      setAcc.value.size should be (maxI)
+      bufferAcc.value.size should be (2 * maxI)
+      mapAcc.value.size should be (maxI)
+      for (i <- 1 to maxI) {
+        setAcc.value should contain(i)
+        bufferAcc.value should contain(i)
+        mapAcc.value should contain (i -> i.toString)
+      }
+      sc.stop()
+    }
+  }
 }
-- 
cgit v1.2.3
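
A minimal usage sketch of the accumulableCollection API added by this patch, for readers who want to see it outside the test suite. The master URL, application name, object name, and choice of collection below are illustrative assumptions, not anything specified in the patch:

import scala.collection.mutable
import spark.SparkContext

object AccumulableCollectionExample {
  def main(args: Array[String]) {
    // Hypothetical local master and app name, purely for illustration.
    val sc = new SparkContext("local[2]", "accumulableCollection example")

    // Typed as Accumulable[mutable.HashSet[Int], Int]: tasks can only add
    // elements with +=, and the merged value is read back on the driver.
    val evens = sc.accumulableCollection(mutable.HashSet[Int]())

    sc.parallelize(1 to 100).foreach { x =>
      if (x % 2 == 0) evens += x
    }

    println("distinct even values seen: " + evens.value.size) // expected: 50
    sc.stop()
  }
}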