author    Matei Zaharia <matei@eecs.berkeley.edu>  2012-09-11 16:00:06 -0700
committer Matei Zaharia <matei@eecs.berkeley.edu>  2012-09-11 16:00:06 -0700
commit    6d7f907e73e9702c0dbd0e41e4a52022c0b81d3d (patch)
tree      560c61abb885f2fe77ddcc49f6b6ffdf0bfe02ac /core
parent    995982b3c9fdd4b031ccca4dfe76b4951ce1fcff (diff)
Manually merge pull request #175 by Imran Rashid
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/Accumulators.scala    | 24
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala    | 11
-rw-r--r--  core/src/test/scala/spark/AccumulatorSuite.scala | 27
3 files changed, 60 insertions(+), 2 deletions(-)
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index d764ffc29d..c157cc8feb 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -3,6 +3,7 @@ package spark
import java.io._
import scala.collection.mutable.Map
+import scala.collection.generic.Growable
/**
* A datatype that can be accumulated, i.e. has a commutative and associative +.
@@ -92,6 +93,29 @@ trait AccumulableParam[R, T] extends Serializable {
def zero(initialValue: R): R
}
+class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
+ extends AccumulableParam[R,T] {
+
+ def addAccumulator(growable: R, elem: T) : R = {
+ growable += elem
+ growable
+ }
+
+ def addInPlace(t1: R, t2: R) : R = {
+ t1 ++= t2
+ t1
+ }
+
+ def zero(initialValue: R): R = {
+ // We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
+ // Instead we'll serialize it to a buffer and load it back.
+ val ser = (new spark.JavaSerializer).newInstance()
+ val copy = ser.deserialize[R](ser.serialize(initialValue))
+ copy.clear() // In case it contained stuff
+ copy
+ }
+}
+
/**
* A simpler version of [[spark.Accumulable]] where the result type being accumulated is the same
* as the types of elements being merged.
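The zero() implementation above sidesteps the fact that the view bound cannot also require R to be Cloneable: it deep-copies initialValue through a serialization round trip and then clears the copy. A minimal standalone sketch of the same trick, using plain java.io object streams instead of spark.JavaSerializer (the helper name is hypothetical, not part of this commit):

    import java.io._

    // Hypothetical helper: deep-copy a Serializable value by writing it to a
    // byte buffer and reading it back, mirroring what zero() does above.
    def cloneBySerialization[A <: Serializable](value: A): A = {
      val buffer = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(buffer)
      out.writeObject(value)
      out.close()
      val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
      in.readObject().asInstanceOf[A]
    }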
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 5d0f2950d6..0dec44979f 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -7,6 +7,7 @@ import akka.actor.Actor
import akka.actor.Actor._
import scala.collection.mutable.ArrayBuffer
+import scala.collection.generic.Growable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
@@ -307,6 +308,16 @@ class SparkContext(
def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
new Accumulable(initialValue, param)
+ /**
+ * Create an accumulator from a "mutable collection" type.
+ *
+ * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
+ * standard mutable collections. So you can use this with mutable Map, Set, etc.
+ */
+ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
+ val param = new GrowableAccumulableParam[R,T]
+ new Accumulable(initialValue, param)
+ }
// Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
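A short usage sketch for the new accumulableCollection method (assumes a running SparkContext named sc; this mirrors the test added below):

    import scala.collection.mutable

    // mutable.HashSet satisfies Growable, TraversableOnce and Serializable,
    // so the view bound on accumulableCollection accepts it directly.
    val seen = sc.accumulableCollection(mutable.HashSet[Int]())
    sc.parallelize(1 to 100).foreach(i => seen += i)
    println(seen.value.size)  // 100; value is only readable on the driver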
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index d55969c261..71df5941e5 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -56,7 +56,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
}
}
-
implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
t1 ++= t2
@@ -71,7 +70,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
}
}
-
test ("value not readable in tasks") {
import SetAccum._
val maxI = 1000
@@ -89,4 +87,29 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
}
}
+ test ("collection accumulators") {
+ val maxI = 1000
+ for (nThreads <- List(1, 10)) {
+ // test single & multi-threaded
+ val sc = new SparkContext("local[" + nThreads + "]", "test")
+ val setAcc = sc.accumulableCollection(mutable.HashSet[Int]())
+ val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]())
+ val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]())
+ val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
+ d.foreach {
+ x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
+ }
+
+ // Note that this is typed correctly -- no casts necessary
+ setAcc.value.size should be (maxI)
+ bufferAcc.value.size should be (2 * maxI)
+ mapAcc.value.size should be (maxI)
+ for (i <- 1 to maxI) {
+ setAcc.value should contain(i)
+ bufferAcc.value should contain(i)
+ mapAcc.value should contain (i -> i.toString)
+ }
+ sc.stop()
+ }
+ }
}
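As the comment in the test points out, value preserves the concrete collection type, so no casts are needed when reading the result back on the driver. A minimal illustration (assuming the same sc and imports as the sketch above):

    val names = sc.accumulableCollection(mutable.ArrayBuffer[String]())
    sc.parallelize(Seq("a", "b", "c")).foreach(s => names += s)
    // Statically typed as mutable.ArrayBuffer[String] -- no asInstanceOf needed.
    val result: mutable.ArrayBuffer[String] = names.value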