diff options
author | Reynold Xin <rxin@databricks.com> | 2016-05-02 21:12:48 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-05-02 21:12:48 -0700 |
commit | bb9ab56b960153d374d7e8838f62a18e7e45481e (patch) | |
tree | f24cadbd8818550f4a5ecb31fb3d406e7fce5e6b /core/src/main/scala/org/apache/spark/AccumulatorV2.scala | |
parent | 8028f3a0b4003af15ed44d9ef4727b56f4b10534 (diff) | |
download | spark-bb9ab56b960153d374d7e8838f62a18e7e45481e.tar.gz spark-bb9ab56b960153d374d7e8838f62a18e7e45481e.tar.bz2 spark-bb9ab56b960153d374d7e8838f62a18e7e45481e.zip |
[SPARK-15079] Support average/count/sum in Long/DoubleAccumulator
## What changes were proposed in this pull request?
This patch removes AverageAccumulator and adds the ability to compute average to LongAccumulator and DoubleAccumulator. The patch also improves documentation for the two accumulators.
## How was this patch tested?
Added unit tests for this.
Author: Reynold Xin <rxin@databricks.com>
Closes #12858 from rxin/SPARK-15079.
Diffstat (limited to 'core/src/main/scala/org/apache/spark/AccumulatorV2.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/AccumulatorV2.scala | 137 |
1 files changed, 91 insertions, 46 deletions
diff --git a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala index c65108a55e..a6c64fd680 100644 --- a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala @@ -257,23 +257,66 @@ private[spark] object AccumulatorContext { } +/** + * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers. + * + * @since 2.0.0 + */ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { private[this] var _sum = 0L + private[this] var _count = 0L - override def isZero: Boolean = _sum == 0 + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + override def isZero: Boolean = _count == 0L override def copyAndReset(): LongAccumulator = new LongAccumulator - override def add(v: jl.Long): Unit = _sum += v + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + override def add(v: jl.Long): Unit = { + _sum += v + _count += 1 + } + + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + def add(v: Long): Unit = { + _sum += v + _count += 1 + } - def add(v: Long): Unit = _sum += v + /** + * Returns the number of elements added to the accumulator. + * @since 2.0.0 + */ + def count: Long = _count + /** + * Returns the sum of elements added to the accumulator. + * @since 2.0.0 + */ def sum: Long = _sum + /** + * Returns the average of elements added to the accumulator. + * @since 2.0.0 + */ + def avg: Double = _sum.toDouble / _count + override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match { - case o: LongAccumulator => _sum += o.sum - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + case o: LongAccumulator => + _sum += o.sum + _count += o.count + case _ => + throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } private[spark] def setValue(newValue: Long): Unit = _sum = newValue @@ -282,66 +325,68 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] { } +/** + * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for double precision + * floating numbers. + * + * @since 2.0.0 + */ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] { private[this] var _sum = 0.0 - - override def isZero: Boolean = _sum == 0.0 - - override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator - - override def add(v: jl.Double): Unit = _sum += v - - def add(v: Double): Unit = _sum += v - - def sum: Double = _sum - - override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match { - case o: DoubleAccumulator => _sum += o.sum - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") - } - - private[spark] def setValue(newValue: Double): Unit = _sum = newValue - - override def localValue: jl.Double = _sum -} - - -class AverageAccumulator extends AccumulatorV2[jl.Double, jl.Double] { - private[this] var _sum = 0.0 private[this] var _count = 0L - override def isZero: Boolean = _sum == 0.0 && _count == 0 + override def isZero: Boolean = _count == 0L - override def copyAndReset(): AverageAccumulator = new AverageAccumulator + override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ override def add(v: jl.Double): Unit = { _sum += v _count += 1 } - def add(d: Double): Unit = { - _sum += d + /** + * Adds v to the accumulator, i.e. increment sum by v and count by 1. + * @since 2.0.0 + */ + def add(v: Double): Unit = { + _sum += v _count += 1 } + /** + * Returns the number of elements added to the accumulator. + * @since 2.0.0 + */ + def count: Long = _count + + /** + * Returns the sum of elements added to the accumulator. + * @since 2.0.0 + */ + def sum: Double = _sum + + /** + * Returns the average of elements added to the accumulator. + * @since 2.0.0 + */ + def avg: Double = _sum / _count + override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match { - case o: AverageAccumulator => + case o: DoubleAccumulator => _sum += o.sum _count += o.count - case _ => throw new UnsupportedOperationException( - s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") - } - - override def localValue: jl.Double = if (_count == 0) { - Double.NaN - } else { - _sum / _count + case _ => + throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") } - def sum: Double = _sum + private[spark] def setValue(newValue: Double): Unit = _sum = newValue - def count: Long = _count + override def localValue: jl.Double = _sum } |