 core/src/main/scala/org/apache/spark/Accumulator.scala             |  17 --
 core/src/main/scala/org/apache/spark/AccumulatorV2.scala           | 137 ++++++++---
 core/src/main/scala/org/apache/spark/SparkContext.scala            |  22 --
 core/src/test/scala/org/apache/spark/AccumulatorSuite.scala        |  17 +-
 core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala |  89 ++++++
 5 files changed, 181 insertions(+), 101 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala
index e52d36b7b5..23245043e2 100644
--- a/core/src/main/scala/org/apache/spark/Accumulator.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulator.scala
@@ -17,9 +17,6 @@
package org.apache.spark
-import org.apache.spark.storage.{BlockId, BlockStatus}
-
-
/**
* A simpler value of [[Accumulable]] where the result type being accumulated is the same
* as the types of elements being merged, i.e. variables that are only "added" to through an
@@ -117,18 +114,4 @@ object AccumulatorParam {
def addInPlace(t1: String, t2: String): String = t2
def zero(initialValue: String): String = ""
}
-
- // Note: this is expensive as it makes a copy of the list every time the caller adds an item.
- // A better way to use this is to first accumulate the values yourself then add them all at once.
- @deprecated("use AccumulatorV2", "2.0.0")
- private[spark] class ListAccumulatorParam[T] extends AccumulatorParam[Seq[T]] {
- def addInPlace(t1: Seq[T], t2: Seq[T]): Seq[T] = t1 ++ t2
- def zero(initialValue: Seq[T]): Seq[T] = Seq.empty[T]
- }
-
- // For the internal metric that records what blocks are updated in a particular task
- @deprecated("use AccumulatorV2", "2.0.0")
- private[spark] object UpdatedBlockStatusesAccumulatorParam
- extends ListAccumulatorParam[(BlockId, BlockStatus)]
-
}
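
The removed ListAccumulatorParam has a direct replacement in the V2 API. A minimal
migration sketch, assuming a live SparkContext sc and the listAccumulator helper
referenced in the SparkContext.scala hunk further down; note that element order in
the result is not guaranteed across partitions:

    // Before (deprecated): new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int])
    // After: register an AccumulatorV2-based list accumulator on the driver.
    val listAcc = sc.listAccumulator[Int]
    sc.parallelize(Seq(1, 2, 3, 4)).foreach(x => listAcc.add(x))
    // Once the action completes, listAcc.value holds all added elements.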
diff --git a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
index c65108a55e..a6c64fd680 100644
--- a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
@@ -257,23 +257,66 @@ private[spark] object AccumulatorContext {
}
+/**
+ * An [[AccumulatorV2 accumulator]] for computing the sum, count, and average of 64-bit integers.
+ *
+ * @since 2.0.0
+ */
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
private[this] var _sum = 0L
+ private[this] var _count = 0L
- override def isZero: Boolean = _sum == 0
+ /**
+ * Returns true if this accumulator has not had any values added to it, i.e. count is zero.
+ * @since 2.0.0
+ */
+ override def isZero: Boolean = _count == 0L
override def copyAndReset(): LongAccumulator = new LongAccumulator
- override def add(v: jl.Long): Unit = _sum += v
+ /**
+ * Adds v to the accumulator, i.e. increments sum by v and count by 1.
+ * @since 2.0.0
+ */
+ override def add(v: jl.Long): Unit = {
+ _sum += v
+ _count += 1
+ }
+
+ /**
+ * Adds v to the accumulator, i.e. increments sum by v and count by 1.
+ * @since 2.0.0
+ */
+ def add(v: Long): Unit = {
+ _sum += v
+ _count += 1
+ }
- def add(v: Long): Unit = _sum += v
+ /**
+ * Returns the number of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def count: Long = _count
+ /**
+ * Returns the sum of elements added to the accumulator.
+ * @since 2.0.0
+ */
def sum: Long = _sum
+ /**
+ * Returns the average of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def avg: Double = _sum.toDouble / _count
+
override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
- case o: LongAccumulator => _sum += o.sum
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ case o: LongAccumulator =>
+ _sum += o.sum
+ _count += o.count
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
private[spark] def setValue(newValue: Long): Unit = _sum = newValue
@@ -282,66 +325,68 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
}
+/**
+ * An [[AccumulatorV2 accumulator]] for computing the sum, count, and average of double precision
+ * floating point numbers.
+ *
+ * @since 2.0.0
+ */
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
private[this] var _sum = 0.0
-
- override def isZero: Boolean = _sum == 0.0
-
- override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
-
- override def add(v: jl.Double): Unit = _sum += v
-
- def add(v: Double): Unit = _sum += v
-
- def sum: Double = _sum
-
- override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
- case o: DoubleAccumulator => _sum += o.sum
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
- }
-
- private[spark] def setValue(newValue: Double): Unit = _sum = newValue
-
- override def localValue: jl.Double = _sum
-}
-
-
-class AverageAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
- private[this] var _sum = 0.0
private[this] var _count = 0L
- override def isZero: Boolean = _sum == 0.0 && _count == 0
+ override def isZero: Boolean = _count == 0L
- override def copyAndReset(): AverageAccumulator = new AverageAccumulator
+ override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
+ /**
+ * Adds v to the accumulator, i.e. increments sum by v and count by 1.
+ * @since 2.0.0
+ */
override def add(v: jl.Double): Unit = {
_sum += v
_count += 1
}
- def add(d: Double): Unit = {
- _sum += d
+ /**
+ * Adds v to the accumulator, i.e. increments sum by v and count by 1.
+ * @since 2.0.0
+ */
+ def add(v: Double): Unit = {
+ _sum += v
_count += 1
}
+ /**
+ * Returns the number of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def count: Long = _count
+
+ /**
+ * Returns the sum of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def sum: Double = _sum
+
+ /**
+ * Returns the average of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def avg: Double = _sum / _count
+
override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
- case o: AverageAccumulator =>
+ case o: DoubleAccumulator =>
_sum += o.sum
_count += o.count
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
- }
-
- override def localValue: jl.Double = if (_count == 0) {
- Double.NaN
- } else {
- _sum / _count
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
- def sum: Double = _sum
+ private[spark] def setValue(newValue: Double): Unit = _sum = newValue
- def count: Long = _count
+ override def localValue: jl.Double = _sum
}
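
Taken together, LongAccumulator (and DoubleAccumulator) now subsume the removed
AverageAccumulator: every add updates both sum and count, so avg falls out of the
same state. A minimal local sketch of the merged API, mirroring the new test
suite below:

    val acc = new LongAccumulator
    acc.add(1L)      // sum = 1, count = 1
    acc.add(3L)      // sum = 4, count = 2
    assert(acc.sum == 4L && acc.count == 2L && acc.avg == 2.0)

    // merge combines both sum and count, so the average stays correct.
    val other = new LongAccumulator
    other.add(5L)
    acc.merge(other) // sum = 9, count = 3
    assert(acc.avg == 3.0)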
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 58618b4192..e391599336 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1341,28 +1341,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/**
- * Create and register an average accumulator, which accumulates double inputs by recording the
- * total sum and total count, and produces the output as sum / count. Note that Double.NaN will be
- * returned if no input is added.
- */
- def averageAccumulator: AverageAccumulator = {
- val acc = new AverageAccumulator
- register(acc)
- acc
- }
-
- /**
- * Create and register an average accumulator, which accumulates double inputs by recording the
- * total sum and total count, and produces the output as sum / count. Note that Double.NaN will be
- * returned if no input is added.
- */
- def averageAccumulator(name: String): AverageAccumulator = {
- val acc = new AverageAccumulator
- register(acc, name)
- acc
- }
-
- /**
* Create and register a list accumulator, which starts with an empty list and accumulates inputs
* by adding them into the inner list.
*/
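
Callers of the removed averageAccumulator helpers can register a plain
DoubleAccumulator and read avg from it instead. A sketch, assuming a live
SparkContext sc and the register method used by the removed code above:

    val acc = new DoubleAccumulator
    sc.register(acc, "avg")
    sc.parallelize(Seq(1.0, 2.0, 3.0)).foreach(v => acc.add(v))
    // After the action completes, acc.avg == 2.0 on the driver. With no
    // inputs, avg is 0.0 / 0, i.e. Double.NaN, matching the old
    // AverageAccumulator behavior.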
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 09eb9c1dbd..0020096254 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -28,7 +28,7 @@ import scala.util.control.NonFatal
import org.scalatest.Matchers
import org.scalatest.exceptions.TestFailedException
-import org.apache.spark.AccumulatorParam.{ListAccumulatorParam, StringAccumulatorParam}
+import org.apache.spark.AccumulatorParam.StringAccumulatorParam
import org.apache.spark.scheduler._
import org.apache.spark.serializer.JavaSerializer
@@ -234,21 +234,6 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
acc.merge("kindness")
assert(acc.value === "kindness")
}
-
- test("list accumulator param") {
- val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers"))
- assert(acc.value === Seq.empty[Int])
- acc.add(Seq(1, 2))
- assert(acc.value === Seq(1, 2))
- acc += Seq(3, 4)
- assert(acc.value === Seq(1, 2, 3, 4))
- acc ++= Seq(5, 6)
- assert(acc.value === Seq(1, 2, 3, 4, 5, 6))
- acc.merge(Seq(7, 8))
- assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8))
- acc.setValue(Seq(9, 10))
- assert(acc.value === Seq(9, 10))
- }
}
private[spark] object AccumulatorSuite {
diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
new file mode 100644
index 0000000000..41cdd02492
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.{DoubleAccumulator, LongAccumulator, SparkFunSuite}
+
+class AccumulatorV2Suite extends SparkFunSuite {
+
+ test("LongAccumulator add/avg/sum/count/isZero") {
+ val acc = new LongAccumulator
+ assert(acc.isZero)
+ assert(acc.count == 0)
+ assert(acc.sum == 0)
+ assert(acc.avg.isNaN)
+
+ acc.add(0)
+ assert(!acc.isZero)
+ assert(acc.count == 1)
+ assert(acc.sum == 0)
+ assert(acc.avg == 0.0)
+
+ acc.add(1)
+ assert(acc.count == 2)
+ assert(acc.sum == 1)
+ assert(acc.avg == 0.5)
+
+ // Also test add using non-specialized add function
+ acc.add(new java.lang.Long(2))
+ assert(acc.count == 3)
+ assert(acc.sum == 3)
+ assert(acc.avg == 1.0)
+
+ // Test merging
+ val acc2 = new LongAccumulator
+ acc2.add(2)
+ acc.merge(acc2)
+ assert(acc.count == 4)
+ assert(acc.sum == 5)
+ assert(acc.avg == 1.25)
+ }
+
+ test("DoubleAccumulator add/avg/sum/count/isZero") {
+ val acc = new DoubleAccumulator
+ assert(acc.isZero)
+ assert(acc.count == 0)
+ assert(acc.sum == 0.0)
+ assert(acc.avg.isNaN)
+
+ acc.add(0.0)
+ assert(!acc.isZero)
+ assert(acc.count == 1)
+ assert(acc.sum == 0.0)
+ assert(acc.avg == 0.0)
+
+ acc.add(1.0)
+ assert(acc.count == 2)
+ assert(acc.sum == 1.0)
+ assert(acc.avg == 0.5)
+
+ // Also test add using non-specialized add function
+ acc.add(new java.lang.Double(2.0))
+ assert(acc.count == 3)
+ assert(acc.sum == 3.0)
+ assert(acc.avg == 1.0)
+
+ // Test merging
+ val acc2 = new DoubleAccumulator
+ acc2.add(2.0)
+ acc.merge(acc2)
+ assert(acc.count == 4)
+ assert(acc.sum == 5.0)
+ assert(acc.avg == 1.25)
+ }
+}
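
One behavioral subtlety the suite pins down: isZero is now defined on count
rather than sum, so adding a zero value still marks the accumulator as
non-zero. A minimal illustration:

    val acc = new LongAccumulator
    assert(acc.isZero)     // nothing added yet
    acc.add(0L)
    assert(!acc.isZero)    // count == 1, even though sum == 0
    assert(acc.avg == 0.0) // 0 / 1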