author      Matei Zaharia <matei@eecs.berkeley.edu>    2013-01-05 20:54:08 -0500
committer   Matei Zaharia <matei@eecs.berkeley.edu>    2013-01-05 20:55:17 -0500
commit      86af64b0a6fde5a6418727a77b43bdfeda1b81cd (patch)
tree        e1d1c002bfb29de789de9be463dff273f41c43c3 /core/src
parent      7ab9f091408dad2c264cede965ab284c9d33deb9 (diff)
download    spark-86af64b0a6fde5a6418727a77b43bdfeda1b81cd.tar.gz
            spark-86af64b0a6fde5a6418727a77b43bdfeda1b81cd.tar.bz2
            spark-86af64b0a6fde5a6418727a77b43bdfeda1b81cd.zip
Fix Accumulators in Java, and add a test for them
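This commit makes accumulators practical to use from Java: `Accumulable` gains `add` and `merge` as plain-named aliases for the symbolic `+=` and `++=` operators (which compile to mangled JVM names like `$plus$eq` and are awkward to call from Java), `JavaSparkContext.intAccumulator` and `doubleAccumulator` now return accumulators typed over boxed `java.lang.Integer`/`java.lang.Double` so they fit Java generics, and an `accumulable` entry point is exposed to Java. A minimal driver-side sketch of the resulting usage; only the API names come from the diff, while the `AccumulatorExample` class, the local master string, and the variable names are illustrative:

import java.util.Arrays;

import spark.Accumulator;
import spark.api.java.JavaRDD;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.VoidFunction;

public class AccumulatorExample {
  public static void main(String[] args) {
    // Illustrative setup; any JavaSparkContext works here.
    JavaSparkContext sc = new JavaSparkContext("local", "AccumulatorExample");

    // With this commit, intAccumulator returns Accumulator<java.lang.Integer>,
    // so the handle composes with Java generics directly.
    final Accumulator<Integer> sum = sc.intAccumulator(0);

    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
    rdd.foreach(new VoidFunction<Integer>() {
      public void call(Integer x) {
        sum.add(x);  // add() is the new Java-friendly alias for Scala's +=
      }
    });

    // Reading the value is only allowed on the master (driver).
    System.out.println(sum.value());  // prints 15
  }
}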
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/spark/Accumulators.scala                18
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala                 7
-rw-r--r--  core/src/main/scala/spark/api/java/JavaSparkContext.scala   23
-rw-r--r--  core/src/test/scala/spark/JavaAPISuite.java                  44
4 files changed, 79 insertions, 13 deletions
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index bacd0ace37..6280f25391 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -39,14 +39,28 @@ class Accumulable[R, T] (
   def += (term: T) { value_ = param.addAccumulator(value_, term) }
 
   /**
+   * Add more data to this accumulator / accumulable
+   * @param term the data to add
+   */
+  def add(term: T) { value_ = param.addAccumulator(value_, term) }
+
+  /**
    * Merge two accumulable objects together
-   * 
+   *
    * Normally, a user will not want to use this version, but will instead call `+=`.
-   * @param term the other Accumulable that will get merged with this
+   * @param term the other `R` that will get merged with this
    */
   def ++= (term: R) { value_ = param.addInPlace(value_, term)}
 
   /**
+   * Merge two accumulable objects together
+   *
+   * Normally, a user will not want to use this version, but will instead call `add`.
+   * @param term the other `R` that will get merged with this
+   */
+  def merge(term: R) { value_ = param.addInPlace(value_, term)}
+
+  /**
    * Access the accumulator's current value; only allowed on master.
    */
   def value = {
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 4fd81bc63b..bbf8272eb3 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -382,11 +382,12 @@ class SparkContext(
     new Accumulator(initialValue, param)
 
   /**
-   * Create an [[spark.Accumulable]] shared variable, with a `+=` method
+   * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
+   * Only the master can access the accumuable's `value`.
    * @tparam T accumulator type
    * @tparam R type that can be added to the accumulator
    */
-  def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
+  def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
     new Accumulable(initialValue, param)
 
   /**
@@ -404,7 +405,7 @@ class SparkContext(
    * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
    * reading it in distributed functions. The variable will be sent to each cluster only once.
    */
-  def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal)
+  def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
 
   /**
    * Add a file to be downloaded into the working directory of this Spark job on every node.
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index b7725313c4..bf9ad7a200 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat
 import org.apache.hadoop.mapred.JobConf
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
 
-import spark.{Accumulator, AccumulatorParam, RDD, SparkContext}
+import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext}
 import spark.SparkContext.IntAccumulatorParam
 import spark.SparkContext.DoubleAccumulatorParam
 import spark.broadcast.Broadcast
@@ -265,26 +265,33 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
 
   /**
    * Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
-   * to using the `+=` method. Only the master can access the accumulator's `value`.
+   * to using the `add` method. Only the master can access the accumulator's `value`.
    */
-  def intAccumulator(initialValue: Int): Accumulator[Int] =
-    sc.accumulator(initialValue)(IntAccumulatorParam)
+  def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
+    sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
 
   /**
    * Create an [[spark.Accumulator]] double variable, which tasks can "add" values
-   * to using the `+=` method. Only the master can access the accumulator's `value`.
+   * to using the `add` method. Only the master can access the accumulator's `value`.
    */
-  def doubleAccumulator(initialValue: Double): Accumulator[Double] =
-    sc.accumulator(initialValue)(DoubleAccumulatorParam)
+  def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
+    sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
 
   /**
    * Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
-   * to using the `+=` method. Only the master can access the accumulator's `value`.
+   * to using the `add` method. Only the master can access the accumulator's `value`.
    */
   def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
     sc.accumulator(initialValue)(accumulatorParam)
 
   /**
+   * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can
+   * "add" values with `add`. Only the master can access the accumuable's `value`.
+   */
+  def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
+    sc.accumulable(initialValue)(param)
+
+  /**
    * Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
    * reading it in distributed functions. The variable will be sent to each cluster only once.
    */
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 33d5fc2d89..b99e790093 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -581,4 +581,48 @@ public class JavaAPISuite implements Serializable {
     JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
     zipped.count();
   }
+
+  @Test
+  public void accumulators() {
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+
+    final Accumulator<Integer> intAccum = sc.intAccumulator(10);
+    rdd.foreach(new VoidFunction<Integer>() {
+      public void call(Integer x) {
+        intAccum.add(x);
+      }
+    });
+    Assert.assertEquals((Integer) 25, intAccum.value());
+
+    final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
+    rdd.foreach(new VoidFunction<Integer>() {
+      public void call(Integer x) {
+        doubleAccum.add((double) x);
+      }
+    });
+    Assert.assertEquals((Double) 25.0, doubleAccum.value());
+
+    // Try a custom accumulator type
+    AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() {
+      public Float addInPlace(Float r, Float t) {
+        return r + t;
+      }
+
+      public Float addAccumulator(Float r, Float t) {
+        return r + t;
+      }
+
+      public Float zero(Float initialValue) {
+        return 0.0f;
+      }
+    };
+
+    final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam);
+    rdd.foreach(new VoidFunction<Integer>() {
+      public void call(Integer x) {
+        floatAccum.add((float) x);
+      }
+    });
+    Assert.assertEquals((Float) 25.0f, floatAccum.value());
+  }
 }
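The new test exercises `intAccumulator`, `doubleAccumulator`, and a custom `AccumulatorParam`, but not the Java-facing `accumulable` method added above. A sketch of how that entry point could be used from Java, under the signatures in this diff (`AccumulableParam` takes the accumulated type first, then the element type); the `AccumulableExample` class, the list-building param, and the `evens` variable are hypothetical, not part of the commit:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import spark.Accumulable;
import spark.AccumulableParam;
import spark.api.java.JavaSparkContext;
import spark.api.java.function.VoidFunction;

public class AccumulableExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "AccumulableExample");

    // Accumulates Integer elements into a List<Integer>. addAccumulator folds
    // one element into a partial result on a task; addInPlace merges two
    // partial results; zero gives the identity for a fresh copy.
    AccumulableParam<List<Integer>, Integer> listParam =
        new AccumulableParam<List<Integer>, Integer>() {
          public List<Integer> addAccumulator(List<Integer> acc, Integer elem) {
            acc.add(elem);
            return acc;
          }
          public List<Integer> addInPlace(List<Integer> a, List<Integer> b) {
            a.addAll(b);
            return a;
          }
          public List<Integer> zero(List<Integer> initialValue) {
            return new ArrayList<Integer>();
          }
        };

    final Accumulable<List<Integer>, Integer> evens =
        sc.accumulable((List<Integer>) new ArrayList<Integer>(), listParam);

    sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)).foreach(new VoidFunction<Integer>() {
      public void call(Integer x) {
        if (x % 2 == 0) {
          evens.add(x);  // add() is the alias for += introduced by this commit
        }
      }
    });

    System.out.println(evens.value());  // e.g. [2, 4]; element order is not guaranteed
  }
}

Collecting into a list is the kind of case `Accumulable` exists for: the accumulated type (`List<Integer>`) differs from the type of each added element (`Integer`), which a plain `Accumulator` cannot express.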