path: root/core/src
author     Matei Zaharia <matei@eecs.berkeley.edu>    2013-01-05 20:54:08 -0500
committer  Matei Zaharia <matei@eecs.berkeley.edu>    2013-01-05 20:55:17 -0500
commit     86af64b0a6fde5a6418727a77b43bdfeda1b81cd (patch)
tree       e1d1c002bfb29de789de9be463dff273f41c43c3 /core/src
parent     7ab9f091408dad2c264cede965ab284c9d33deb9 (diff)
Fix Accumulators in Java, and add a test for them
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/spark/Accumulators.scala               | 18
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala               |  7
-rw-r--r--  core/src/main/scala/spark/api/java/JavaSparkContext.scala  | 23
-rw-r--r--  core/src/test/scala/spark/JavaAPISuite.java                | 44
4 files changed, 79 insertions(+), 13 deletions(-)
diff --git a/core/src/main/scala/spark/Accumulators.scala b/core/src/main/scala/spark/Accumulators.scala
index bacd0ace37..6280f25391 100644
--- a/core/src/main/scala/spark/Accumulators.scala
+++ b/core/src/main/scala/spark/Accumulators.scala
@@ -39,14 +39,28 @@ class Accumulable[R, T] (
def += (term: T) { value_ = param.addAccumulator(value_, term) }
/**
+ * Add more data to this accumulator / accumulable
+ * @param term the data to add
+ */
+ def add(term: T) { value_ = param.addAccumulator(value_, term) }
+
+ /**
* Merge two accumulable objects together
- *
+ *
* Normally, a user will not want to use this version, but will instead call `+=`.
- * @param term the other Accumulable that will get merged with this
+ * @param term the other `R` that will get merged with this
*/
def ++= (term: R) { value_ = param.addInPlace(value_, term)}
/**
+ * Merge two accumulable objects together
+ *
+ * Normally, a user will not want to use this version, but will instead call `add`.
+ * @param term the other `R` that will get merged with this
+ */
+ def merge(term: R) { value_ = param.addInPlace(value_, term) }
+
+ /**
* Access the accumulator's current value; only allowed on master.
*/
def value = {
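
The new `add` and `merge` methods above exist because Scala's symbolic operators
`+=` and `++=` have no usable names from Java. A minimal sketch of the intended
Java call pattern (assuming local mode and a JavaSparkContext named sc; this
snippet is illustrative, not part of the commit):

    import java.util.Arrays;
    import spark.Accumulator;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.api.java.function.VoidFunction;

    JavaSparkContext sc = new JavaSparkContext("local", "accumulator-demo");
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3));

    final Accumulator<Integer> counter = sc.intAccumulator(0);
    rdd.foreach(new VoidFunction<Integer>() {
      public void call(Integer x) {
        counter.add(x);  // Java spelling of Scala's `+=`
      }
    });
    // counter.merge(...) is the Java spelling of `++=`; like `++=`, user code
    // rarely needs to call it directly.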
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 4fd81bc63b..bbf8272eb3 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -382,11 +382,12 @@ class SparkContext(
new Accumulator(initialValue, param)
/**
- * Create an [[spark.Accumulable]] shared variable, with a `+=` method
+ * Create an [[spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
+ * Only the master can access the accumulable's `value`.
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
- def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
+ def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
new Accumulable(initialValue, param)
/**
@@ -404,7 +405,7 @@ class SparkContext(
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each node only once.
*/
- def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal)
+ def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
/**
* Add a file to be downloaded into the working directory of this Spark job on every node.
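
As the accumulable scaladoc above says, only the master may read the `value`;
tasks may only add to it. A hedged sketch of the resulting usage pattern from
Java (the commented-out worker-side read is expected to fail in this version of
the code, which only permits reading the value on the master):

    final Accumulator<Integer> sum = sc.intAccumulator(0);
    rdd.foreach(new VoidFunction<Integer>() {
      public void call(Integer x) {
        sum.add(x);      // workers may only add
        // sum.value();  // disallowed inside a task; only the master reads the value
      }
    });
    System.out.println(sum.value());  // safe: read on the master after the action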
diff --git a/core/src/main/scala/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
index b7725313c4..bf9ad7a200 100644
--- a/core/src/main/scala/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/spark/api/java/JavaSparkContext.scala
@@ -10,7 +10,7 @@ import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
-import spark.{Accumulator, AccumulatorParam, RDD, SparkContext}
+import spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, RDD, SparkContext}
import spark.SparkContext.IntAccumulatorParam
import spark.SparkContext.DoubleAccumulatorParam
import spark.broadcast.Broadcast
@@ -265,26 +265,33 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/**
* Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `add` method. Only the master can access the accumulator's `value`.
*/
- def intAccumulator(initialValue: Int): Accumulator[Int] =
- sc.accumulator(initialValue)(IntAccumulatorParam)
+ def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
+ sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
/**
* Create an [[spark.Accumulator]] double variable, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `add` method. Only the master can access the accumulator's `value`.
*/
- def doubleAccumulator(initialValue: Double): Accumulator[Double] =
- sc.accumulator(initialValue)(DoubleAccumulatorParam)
+ def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
+ sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
/**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
- * to using the `+=` method. Only the master can access the accumulator's `value`.
+ * to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
sc.accumulator(initialValue)(accumulatorParam)
/**
+ * Create an [[spark.Accumulable]] shared variable of the given type, to which tasks can
+ * "add" values with `add`. Only the master can access the accumuable's `value`.
+ */
+ def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
+ sc.accumulable(initialValue)(param)
+
+ /**
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each node only once.
*/
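
The new accumulable() method takes an explicit AccumulableParam[T, R], where T
is the accumulated type and R is the type of values tasks add. (Note also that
intAccumulator and doubleAccumulator now return boxed
Accumulator[java.lang.Integer] / Accumulator[java.lang.Double], so the results
work directly with Java generics.) A sketch of a list-building accumulable from
Java; the listParam value here is illustrative, not part of the commit:

    import java.util.ArrayList;
    import java.util.List;
    import spark.Accumulable;
    import spark.AccumulableParam;

    AccumulableParam<List<Integer>, Integer> listParam =
        new AccumulableParam<List<Integer>, Integer>() {
      public List<Integer> addAccumulator(List<Integer> acc, Integer elem) {
        acc.add(elem);          // fold one element into a partial result
        return acc;
      }
      public List<Integer> addInPlace(List<Integer> a, List<Integer> b) {
        a.addAll(b);            // merge two partial results
        return a;
      }
      public List<Integer> zero(List<Integer> initialValue) {
        return new ArrayList<Integer>();  // fresh identity value
      }
    };

    final Accumulable<List<Integer>, Integer> seen =
        sc.accumulable(new ArrayList<Integer>(), listParam);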
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 33d5fc2d89..b99e790093 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -581,4 +581,48 @@ public class JavaAPISuite implements Serializable {
JavaPairRDD<Integer, Double> zipped = rdd.zip(doubles);
zipped.count();
}
+
+ @Test
+ public void accumulators() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+
+ final Accumulator<Integer> intAccum = sc.intAccumulator(10);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ intAccum.add(x);
+ }
+ });
+ Assert.assertEquals((Integer) 25, intAccum.value());
+
+ final Accumulator<Double> doubleAccum = sc.doubleAccumulator(10.0);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ doubleAccum.add((double) x);
+ }
+ });
+ Assert.assertEquals((Double) 25.0, doubleAccum.value());
+
+ // Try a custom accumulator type
+ AccumulatorParam<Float> floatAccumulatorParam = new AccumulatorParam<Float>() {
+ public Float addInPlace(Float r, Float t) {
+ return r + t;
+ }
+
+ public Float addAccumulator(Float r, Float t) {
+ return r + t;
+ }
+
+ public Float zero(Float initialValue) {
+ return 0.0f;
+ }
+ };
+
+ final Accumulator<Float> floatAccum = sc.accumulator((Float) 10.0f, floatAccumulatorParam);
+ rdd.foreach(new VoidFunction<Integer>() {
+ public void call(Integer x) {
+ floatAccum.add((float) x);
+ }
+ });
+ Assert.assertEquals((Float) 25.0f, floatAccum.value());
+ }
}
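
The suite above covers intAccumulator, doubleAccumulator, and a custom
AccumulatorParam, but not the new accumulable() method. A companion test might
look like the following sketch (not part of the commit; it reuses the
hypothetical listParam from above and assumes local mode, where the element
count is deterministic):

      @Test
      public void accumulables() {
        JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3));
        final Accumulable<List<Integer>, Integer> seen =
            sc.accumulable(new ArrayList<Integer>(), listParam);
        rdd.foreach(new VoidFunction<Integer>() {
          public void call(Integer x) {
            seen.add(x);
          }
        });
        Assert.assertEquals(3, seen.value().size());
      }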