author     Denny <dennybritz@gmail.com>  2012-09-11 16:57:17 -0700
committer  Denny <dennybritz@gmail.com>  2012-09-11 16:57:17 -0700
commit     5e4076e3f2eb6b0206119c5d67ac6ee405cee1ad (patch)
tree       eeeadcd957958ad0210d86d0e8defde534ab28eb /core
parent     77873d2c8eda58278e136f01f03e154cba40ee79 (diff)
parent     943df48348662d1ca17091dd403c5365e27924a8 (diff)
Merge branch 'dev' into feature/fileserver
Conflicts:
core/src/main/scala/spark/SparkContext.scala
Diffstat (limited to 'core')
 core/src/main/scala/spark/Accumulators.scala               | 24
 core/src/main/scala/spark/SparkContext.scala               | 11
 core/src/main/scala/spark/storage/BlockManagerWorker.scala |  2
 core/src/test/scala/spark/AccumulatorSuite.scala           | 27
 4 files changed, 61 insertions, 3 deletions
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index d764ffc29d..c157cc8feb 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -3,6 +3,7 @@ package spark
 import java.io._
 
 import scala.collection.mutable.Map
+import scala.collection.generic.Growable
 
 /**
  * A datatype that can be accumulated, i.e. has a commutative and associative +.
@@ -92,6 +93,29 @@ trait AccumulableParam[R, T] extends Serializable {
   def zero(initialValue: R): R
 }
 
+class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
+  extends AccumulableParam[R,T] {
+
+  def addAccumulator(growable: R, elem: T) : R = {
+    growable += elem
+    growable
+  }
+
+  def addInPlace(t1: R, t2: R) : R = {
+    t1 ++= t2
+    t1
+  }
+
+  def zero(initialValue: R): R = {
+    // We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
+    // Instead we'll serialize it to a buffer and load it back.
+    val ser = (new spark.JavaSerializer).newInstance()
+    val copy = ser.deserialize[R](ser.serialize(initialValue))
+    copy.clear()   // In case it contained stuff
+    copy
+  }
+}
+
 /**
  * A simpler value of [[spark.Accumulable]] where the result type being accumulated is the same
  * as the types of elements being merged.
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 2bd07f10d4..758c42fa61 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -8,6 +8,7 @@ import akka.actor.Actor
 import akka.actor.Actor._
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.generic.Growable
 
 import org.apache.hadoop.fs.{FileUtil, Path}
 import org.apache.hadoop.conf.Configuration
@@ -315,6 +316,16 @@ class SparkContext(
   def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
     new Accumulable(initialValue, param)
 
+  /**
+   * Create an accumulator from a "mutable collection" type.
+   *
+   * Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
+   * standard mutable collections. So you can use this with mutable Map, Set, etc.
+   */
+  def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
+    val param = new GrowableAccumulableParam[R,T]
+    new Accumulable(initialValue, param)
+  }
 
   // Keep around a weak hash map of values to Cached versions?
   def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
diff --git a/core/src/main/scala/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
index e317ad3642..0eaa558f44 100644
--- a/core/src/main/scala/spark/storage/BlockManagerWorker.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerWorker.scala
@@ -34,7 +34,7 @@ class BlockManagerWorker(val blockManager: BlockManager) extends Logging {
       /*logDebug("Processed block messages")*/
       return Some(new BlockMessageArray(responseMessages).toBufferMessage)
     } catch {
-      case e: Exception => logError("Exception handling buffer message: " + e.getMessage)
+      case e: Exception => logError("Exception handling buffer message", e)
       return None
     }
   }
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index d55969c261..71df5941e5 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -56,7 +56,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
     }
   }
 
-
   implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
     def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
       t1 ++= t2
@@ -71,7 +70,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
     }
   }
 
-
   test ("value not readable in tasks") {
     import SetAccum._
     val maxI = 1000
@@ -89,4 +87,29 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
     }
   }
 
+  test ("collection accumulators") {
+    val maxI = 1000
+    for (nThreads <- List(1, 10)) {
+      // test single & multi-threaded
+      val sc = new SparkContext("local[" + nThreads + "]", "test")
+      val setAcc = sc.accumulableCollection(mutable.HashSet[Int]())
+      val bufferAcc = sc.accumulableCollection(mutable.ArrayBuffer[Int]())
+      val mapAcc = sc.accumulableCollection(mutable.HashMap[Int,String]())
+      val d = sc.parallelize((1 to maxI) ++ (1 to maxI))
+      d.foreach {
+        x => {setAcc += x; bufferAcc += x; mapAcc += (x -> x.toString)}
+      }
+
+      // Note that this is typed correctly -- no casts necessary
+      setAcc.value.size should be (maxI)
+      bufferAcc.value.size should be (2 * maxI)
+      mapAcc.value.size should be (maxI)
+      for (i <- 1 to maxI) {
+        setAcc.value should contain(i)
+        bufferAcc.value should contain(i)
+        mapAcc.value should contain (i -> i.toString)
+      }
+      sc.stop()
+    }
+  }
 }
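
For readers skimming the diff, here is a minimal driver-side sketch of the accumulableCollection API this patch adds, mirroring the "collection accumulators" test above. It assumes this branch's spark package is on the classpath and a local master; the object name, app name, and thread count are made up for illustration, while accumulableCollection, +=, and .value come from the patch itself.

import scala.collection.mutable

import spark.SparkContext

// Hypothetical example object; only the accumulableCollection API is from the patch.
object CollectionAccumulatorExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "collection-accumulator-example")

    // Statically typed as Accumulable[mutable.HashSet[Int], Int]; no casts on .value.
    val evens = sc.accumulableCollection(mutable.HashSet[Int]())

    sc.parallelize(1 to 100).foreach { x =>
      if (x % 2 == 0) evens += x   // per-task adds via GrowableAccumulableParam.addAccumulator
    }

    // Per-task collections are merged with ++= (addInPlace); the merged
    // result is only readable here on the driver.
    println("evens seen: " + evens.value.size)   // prints 50

    sc.stop()
  }
}

The zero() trick in GrowableAccumulableParam (serialize initialValue, deserialize, clear) is what gives each task an empty copy of the collection to start from, without requiring R to be Cloneable.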