From bb9ab56b960153d374d7e8838f62a18e7e45481e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 2 May 2016 21:12:48 -0700 Subject: [SPARK-15079] Support average/count/sum in Long/DoubleAccumulator ## What changes were proposed in this pull request? This patch removes AverageAccumulator and adds the ability to compute average to LongAccumulator and DoubleAccumulator. The patch also improves documentation for the two accumulators. ## How was this patch tested? Added unit tests for this. Author: Reynold Xin Closes #12858 from rxin/SPARK-15079. --- .../main/scala/org/apache/spark/Accumulator.scala | 17 --- .../scala/org/apache/spark/AccumulatorV2.scala | 137 ++++++++++++++------- .../main/scala/org/apache/spark/SparkContext.scala | 22 ---- .../scala/org/apache/spark/AccumulatorSuite.scala | 17 +-- .../org/apache/spark/util/AccumulatorV2Suite.scala | 89 +++++++++++++ 5 files changed, 181 insertions(+), 101 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala (limited to 'core/src') diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala index e52d36b7b5..23245043e2 100644 --- a/core/src/main/scala/org/apache/spark/Accumulator.scala +++ b/core/src/main/scala/org/apache/spark/Accumulator.scala @@ -17,9 +17,6 @@ package org.apache.spark -import org.apache.spark.storage.{BlockId, BlockStatus} - - /** * A simpler value of [[Accumulable]] where the result type being accumulated is the same * as the types of elements being merged, i.e. variables that are only "added" to through an @@ -117,18 +114,4 @@ object AccumulatorParam { def addInPlace(t1: String, t2: String): String = t2 def zero(initialValue: String): String = "" } - - // Note: this is expensive as it makes a copy of the list every time the caller adds an item. - // A better way to use this is to first accumulate the values yourself then them all at once. 
- @deprecated("use AccumulatorV2", "2.0.0") - private[spark] class ListAccumulatorParam[T] extends AccumulatorParam[Seq[T]] { - def addInPlace(t1: Seq[T], t2: Seq[T]): Seq[T] = t1 ++ t2 - def zero(initialValue: Seq[T]): Seq[T] = Seq.empty[T] - } - - // For the internal metric that records what blocks are updated in a particular task - @deprecated("use AccumulatorV2", "2.0.0") - private[spark] object UpdatedBlockStatusesAccumulatorParam - extends ListAccumulatorParam[(BlockId, BlockStatus)] - } diff --git a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala index c65108a55e..a6c64fd680 100644 --- a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala @@ -257,23 +257,66 @@ private[spark] object AccumulatorContext { } +/** + * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers. + * + * @since 2.0.0 + */ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { private[this] var _sum = 0L + private[this] var _count = 0L - override def isZero: Boolean = _sum == 0 + /** + * Returns whether this accumulator is zero, i.e. its element count is 0. + * @since 2.0.0 + */ + override def isZero: Boolean = _count == 0L override def copyAndReset(): LongAccumulator = new LongAccumulator - override def add(v: jl.Long): Unit = _sum += v + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + override def add(v: jl.Long): Unit = { + _sum += v + _count += 1 + } + + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + def add(v: Long): Unit = { + _sum += v + _count += 1 + } - def add(v: Long): Unit = _sum += v + /** + * Returns the number of elements added to the accumulator. + * @since 2.0.0 + */ + def count: Long = _count + /** + * Returns the sum of elements added to the accumulator. 
+ * @since 2.0.0 + */ def sum: Long = _sum + /** + * Returns the average of elements added to the accumulator. + * @since 2.0.0 + */ + def avg: Double = _sum.toDouble / _count + override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match { - case o: LongAccumulator => _sum += o.sum - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + case o: LongAccumulator => + _sum += o.sum + _count += o.count + case _ => + throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } private[spark] def setValue(newValue: Long): Unit = _sum = newValue @@ -282,66 +325,68 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { } +/** + * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for double precision + * floating numbers. + * + * @since 2.0.0 + */ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { private[this] var _sum = 0.0 - - override def isZero: Boolean = _sum == 0.0 - - override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator - - override def add(v: jl.Double): Unit = _sum += v - - def add(v: Double): Unit = _sum += v - - def sum: Double = _sum - - override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match { - case o: DoubleAccumulator => _sum += o.sum - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") - } - - private[spark] def setValue(newValue: Double): Unit = _sum = newValue - - override def localValue: jl.Double = _sum -} - - -class AverageAccumulator extends AccumulatorV2[jl.Double, jl.Double] { - private[this] var _sum = 0.0 private[this] var _count = 0L - override def isZero: Boolean = _sum == 0.0 && _count == 0 + override def isZero: Boolean = _count == 0L - override def copyAndReset(): AverageAccumulator = new AverageAccumulator + override def 
copyAndReset(): DoubleAccumulator = new DoubleAccumulator + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ override def add(v: jl.Double): Unit = { _sum += v _count += 1 } - def add(d: Double): Unit = { - _sum += d + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + def add(v: Double): Unit = { + _sum += v _count += 1 } + /** + * Returns the number of elements added to the accumulator. + * @since 2.0.0 + */ + def count: Long = _count + + /** + * Returns the sum of elements added to the accumulator. + * @since 2.0.0 + */ + def sum: Double = _sum + + /** + * Returns the average of elements added to the accumulator. + * @since 2.0.0 + */ + def avg: Double = _sum / _count + override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match { - case o: AverageAccumulator => + case o: DoubleAccumulator => _sum += o.sum _count += o.count - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") - } - - override def localValue: jl.Double = if (_count == 0) { - Double.NaN - } else { - _sum / _count + case _ => + throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - def sum: Double = _sum + private[spark] def setValue(newValue: Double): Unit = _sum = newValue - def count: Long = _count + override def localValue: jl.Double = _sum } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 58618b4192..e391599336 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1340,28 +1340,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli acc } - /** - * Create and register an average accumulator, which accumulates double inputs by recording the - * 
total sum and total count, and produce the output by sum / total. Note that Double.NaN will be - * returned if no input is added. - */ - def averageAccumulator: AverageAccumulator = { - val acc = new AverageAccumulator - register(acc) - acc - } - - /** - * Create and register an average accumulator, which accumulates double inputs by recording the - * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be - * returned if no input is added. - */ - def averageAccumulator(name: String): AverageAccumulator = { - val acc = new AverageAccumulator - register(acc, name) - acc - } - /** * Create and register a list accumulator, which starts with empty list and accumulates inputs * by adding them into the inner list. diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 09eb9c1dbd..0020096254 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -28,7 +28,7 @@ import scala.util.control.NonFatal import org.scalatest.Matchers import org.scalatest.exceptions.TestFailedException -import org.apache.spark.AccumulatorParam.{ListAccumulatorParam, StringAccumulatorParam} +import org.apache.spark.AccumulatorParam.StringAccumulatorParam import org.apache.spark.scheduler._ import org.apache.spark.serializer.JavaSerializer @@ -234,21 +234,6 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex acc.merge("kindness") assert(acc.value === "kindness") } - - test("list accumulator param") { - val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], Some("numbers")) - assert(acc.value === Seq.empty[Int]) - acc.add(Seq(1, 2)) - assert(acc.value === Seq(1, 2)) - acc += Seq(3, 4) - assert(acc.value === Seq(1, 2, 3, 4)) - acc ++= Seq(5, 6) - assert(acc.value === Seq(1, 2, 3, 4, 5, 6)) - acc.merge(Seq(7, 8)) - assert(acc.value === Seq(1, 2, 
3, 4, 5, 6, 7, 8)) - acc.setValue(Seq(9, 10)) - assert(acc.value === Seq(9, 10)) - } } private[spark] object AccumulatorSuite { diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala new file mode 100644 index 0000000000..41cdd02492 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import org.apache.spark.{DoubleAccumulator, LongAccumulator, SparkFunSuite} + +class AccumulatorV2Suite extends SparkFunSuite { + + test("LongAccumulator add/avg/sum/count/isZero") { + val acc = new LongAccumulator + assert(acc.isZero) + assert(acc.count == 0) + assert(acc.sum == 0) + assert(acc.avg.isNaN) + + acc.add(0) + assert(!acc.isZero) + assert(acc.count == 1) + assert(acc.sum == 0) + assert(acc.avg == 0.0) + + acc.add(1) + assert(acc.count == 2) + assert(acc.sum == 1) + assert(acc.avg == 0.5) + + // Also test add using non-specialized add function + acc.add(new java.lang.Long(2)) + assert(acc.count == 3) + assert(acc.sum == 3) + assert(acc.avg == 1.0) + + // Test merging + val acc2 = new LongAccumulator + acc2.add(2) + acc.merge(acc2) + assert(acc.count == 4) + assert(acc.sum == 5) + assert(acc.avg == 1.25) + } + + test("DoubleAccumulator add/avg/sum/count/isZero") { + val acc = new DoubleAccumulator + assert(acc.isZero) + assert(acc.count == 0) + assert(acc.sum == 0.0) + assert(acc.avg.isNaN) + + acc.add(0.0) + assert(!acc.isZero) + assert(acc.count == 1) + assert(acc.sum == 0.0) + assert(acc.avg == 0.0) + + acc.add(1.0) + assert(acc.count == 2) + assert(acc.sum == 1.0) + assert(acc.avg == 0.5) + + // Also test add using non-specialized add function + acc.add(new java.lang.Double(2.0)) + assert(acc.count == 3) + assert(acc.sum == 3.0) + assert(acc.avg == 1.0) + + // Test merging + val acc2 = new DoubleAccumulator + acc2.add(2.0) + acc.merge(acc2) + assert(acc.count == 4) + assert(acc.sum == 5.0) + assert(acc.avg == 1.25) + } +} -- cgit v1.2.3