aboutsummaryrefslogtreecommitdiff
path: root/core/src
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-05-02 21:12:48 -0700
committerReynold Xin <rxin@databricks.com>2016-05-02 21:12:48 -0700
commitbb9ab56b960153d374d7e8838f62a18e7e45481e (patch)
treef24cadbd8818550f4a5ecb31fb3d406e7fce5e6b /core/src
parent8028f3a0b4003af15ed44d9ef4727b56f4b10534 (diff)
downloadspark-bb9ab56b960153d374d7e8838f62a18e7e45481e.tar.gz
spark-bb9ab56b960153d374d7e8838f62a18e7e45481e.tar.bz2
spark-bb9ab56b960153d374d7e8838f62a18e7e45481e.zip
[SPARK-15079] Support average/count/sum in Long/DoubleAccumulator
## What changes were proposed in this pull request? This patch removes AverageAccumulator and adds the ability to compute average to LongAccumulator and DoubleAccumulator. The patch also improves documentation for the two accumulators. ## How was this patch tested? Added unit tests for this. Author: Reynold Xin <rxin@databricks.com> Closes #12858 from rxin/SPARK-15079.
Diffstat (limited to 'core/src')
-rw-r--r--core/src/main/scala/org/apache/spark/Accumulator.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/AccumulatorV2.scala137
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala22
-rw-r--r--core/src/test/scala/org/apache/spark/AccumulatorSuite.scala17
-rw-r--r--core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala89
5 files changed, 181 insertions, 101 deletions
diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala
index e52d36b7b5..23245043e2 100644
--- a/core/src/main/scala/org/apache/spark/Accumulator.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulator.scala
@@ -17,9 +17,6 @@
package org.apache.spark
-import org.apache.spark.storage.{BlockId, BlockStatus}
-
-
/**
* A simpler value of [[Accumulable]] where the result type being accumulated is the same
* as the types of elements being merged, i.e. variables that are only "added" to through an
@@ -117,18 +114,4 @@ object AccumulatorParam {
def addInPlace(t1: String, t2: String): String = t2
def zero(initialValue: String): String = ""
}
-
- // Note: this is expensive as it makes a copy of the list every time the caller adds an item.
- // A better way to use this is to first accumulate the values yourself then them all at once.
- @deprecated("use AccumulatorV2", "2.0.0")
- private[spark] class ListAccumulatorParam[T] extends AccumulatorParam[Seq[T]] {
- def addInPlace(t1: Seq[T], t2: Seq[T]): Seq[T] = t1 ++ t2
- def zero(initialValue: Seq[T]): Seq[T] = Seq.empty[T]
- }
-
- // For the internal metric that records what blocks are updated in a particular task
- @deprecated("use AccumulatorV2", "2.0.0")
- private[spark] object UpdatedBlockStatusesAccumulatorParam
- extends ListAccumulatorParam[(BlockId, BlockStatus)]
-
}
diff --git a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
index c65108a55e..a6c64fd680 100644
--- a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
@@ -257,23 +257,66 @@ private[spark] object AccumulatorContext {
}
+/**
+ * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers.
+ *
+ * @since 2.0.0
+ */
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
private[this] var _sum = 0L
+ private[this] var _count = 0L
- override def isZero: Boolean = _sum == 0
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ override def isZero: Boolean = _count == 0L
override def copyAndReset(): LongAccumulator = new LongAccumulator
- override def add(v: jl.Long): Unit = _sum += v
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ override def add(v: jl.Long): Unit = {
+ _sum += v
+ _count += 1
+ }
+
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ def add(v: Long): Unit = {
+ _sum += v
+ _count += 1
+ }
- def add(v: Long): Unit = _sum += v
+ /**
+ * Returns the number of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def count: Long = _count
+ /**
+ * Returns the sum of elements added to the accumulator.
+ * @since 2.0.0
+ */
def sum: Long = _sum
+ /**
+ * Returns the average of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def avg: Double = _sum.toDouble / _count
+
override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
- case o: LongAccumulator => _sum += o.sum
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ case o: LongAccumulator =>
+ _sum += o.sum
+ _count += o.count
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
private[spark] def setValue(newValue: Long): Unit = _sum = newValue
@@ -282,66 +325,68 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
}
+/**
+ * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for double precision
+ * floating numbers.
+ *
+ * @since 2.0.0
+ */
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
private[this] var _sum = 0.0
-
- override def isZero: Boolean = _sum == 0.0
-
- override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
-
- override def add(v: jl.Double): Unit = _sum += v
-
- def add(v: Double): Unit = _sum += v
-
- def sum: Double = _sum
-
- override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
- case o: DoubleAccumulator => _sum += o.sum
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
- }
-
- private[spark] def setValue(newValue: Double): Unit = _sum = newValue
-
- override def localValue: jl.Double = _sum
-}
-
-
-class AverageAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
- private[this] var _sum = 0.0
private[this] var _count = 0L
- override def isZero: Boolean = _sum == 0.0 && _count == 0
+ override def isZero: Boolean = _count == 0L
- override def copyAndReset(): AverageAccumulator = new AverageAccumulator
+ override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
override def add(v: jl.Double): Unit = {
_sum += v
_count += 1
}
- def add(d: Double): Unit = {
- _sum += d
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ def add(v: Double): Unit = {
+ _sum += v
_count += 1
}
+ /**
+ * Returns the number of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def count: Long = _count
+
+ /**
+ * Returns the sum of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def sum: Double = _sum
+
+ /**
+ * Returns the average of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def avg: Double = _sum / _count
+
override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
- case o: AverageAccumulator =>
+ case o: DoubleAccumulator =>
_sum += o.sum
_count += o.count
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
- }
-
- override def localValue: jl.Double = if (_count == 0) {
- Double.NaN
- } else {
- _sum / _count
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
- def sum: Double = _sum
+ private[spark] def setValue(newValue: Double): Unit = _sum = newValue
- def count: Long = _count
+ override def localValue: jl.Double = _sum
}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 58618b4192..e391599336 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1341,28 +1341,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/**
- * Create and register an average accumulator, which accumulates double inputs by recording the
- * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be
- * returned if no input is added.
- */
- def averageAccumulator: AverageAccumulator = {
- val acc = new AverageAccumulator
- register(acc)
- acc
- }
-
- /**
- * Create and register an average accumulator, which accumulates double inputs by recording the
- * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be
- * returned if no input is added.
- */
- def averageAccumulator(name: String): AverageAccumulator = {
- val acc = new AverageAccumulator
- register(acc, name)
- acc
- }
-
- /**
* Create and register a list accumulator, which starts with empty list and accumulates inputs
* by adding them into the inner list.
*/
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 09eb9c1dbd..0020096254 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -28,7 +28,7 @@ import scala.util.control.NonFatal
import org.scalatest.Matchers
import org.scalatest.exceptions.TestFailedException
-import org.apache.spark.AccumulatorParam.{ListAccumulatorParam, StringAccumulatorParam}
+import org.apache.spark.AccumulatorParam.StringAccumulatorParam
import org.apache.spark.scheduler._
import org.apache.spark.serializer.JavaSerializer
@@ -234,21 +234,6 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
acc.merge("kindness")
assert(acc.value === "kindness")
}
-
- test("list accumulator param") {
- val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers"))
- assert(acc.value === Seq.empty[Int])
- acc.add(Seq(1, 2))
- assert(acc.value === Seq(1, 2))
- acc += Seq(3, 4)
- assert(acc.value === Seq(1, 2, 3, 4))
- acc ++= Seq(5, 6)
- assert(acc.value === Seq(1, 2, 3, 4, 5, 6))
- acc.merge(Seq(7, 8))
- assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8))
- acc.setValue(Seq(9, 10))
- assert(acc.value === Seq(9, 10))
- }
}
private[spark] object AccumulatorSuite {
diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
new file mode 100644
index 0000000000..41cdd02492
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.{DoubleAccumulator, LongAccumulator, SparkFunSuite}
+
+class AccumulatorV2Suite extends SparkFunSuite {
+
+ test("LongAccumulator add/avg/sum/count/isZero") {
+ val acc = new LongAccumulator
+ assert(acc.isZero)
+ assert(acc.count == 0)
+ assert(acc.sum == 0)
+ assert(acc.avg.isNaN)
+
+ acc.add(0)
+ assert(!acc.isZero)
+ assert(acc.count == 1)
+ assert(acc.sum == 0)
+ assert(acc.avg == 0.0)
+
+ acc.add(1)
+ assert(acc.count == 2)
+ assert(acc.sum == 1)
+ assert(acc.avg == 0.5)
+
+ // Also test add using non-specialized add function
+ acc.add(new java.lang.Long(2))
+ assert(acc.count == 3)
+ assert(acc.sum == 3)
+ assert(acc.avg == 1.0)
+
+ // Test merging
+ val acc2 = new LongAccumulator
+ acc2.add(2)
+ acc.merge(acc2)
+ assert(acc.count == 4)
+ assert(acc.sum == 5)
+ assert(acc.avg == 1.25)
+ }
+
+ test("DoubleAccumulator add/avg/sum/count/isZero") {
+ val acc = new DoubleAccumulator
+ assert(acc.isZero)
+ assert(acc.count == 0)
+ assert(acc.sum == 0.0)
+ assert(acc.avg.isNaN)
+
+ acc.add(0.0)
+ assert(!acc.isZero)
+ assert(acc.count == 1)
+ assert(acc.sum == 0.0)
+ assert(acc.avg == 0.0)
+
+ acc.add(1.0)
+ assert(acc.count == 2)
+ assert(acc.sum == 1.0)
+ assert(acc.avg == 0.5)
+
+ // Also test add using non-specialized add function
+ acc.add(new java.lang.Double(2.0))
+ assert(acc.count == 3)
+ assert(acc.sum == 3.0)
+ assert(acc.avg == 1.0)
+
+ // Test merging
+ val acc2 = new DoubleAccumulator
+ acc2.add(2.0)
+ acc.merge(acc2)
+ assert(acc.count == 4)
+ assert(acc.sum == 5.0)
+ assert(acc.avg == 1.25)
+ }
+}