aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
diff options
context:
space:
mode:
author Reynold Xin <rxin@databricks.com> 2016-05-02 21:12:48 -0700
committer Reynold Xin <rxin@databricks.com> 2016-05-02 21:12:48 -0700
commit bb9ab56b960153d374d7e8838f62a18e7e45481e (patch)
tree f24cadbd8818550f4a5ecb31fb3d406e7fce5e6b /core/src/main/scala/org/apache/spark/AccumulatorV2.scala
parent 8028f3a0b4003af15ed44d9ef4727b56f4b10534 (diff)
download spark-bb9ab56b960153d374d7e8838f62a18e7e45481e.tar.gz
spark-bb9ab56b960153d374d7e8838f62a18e7e45481e.tar.bz2
spark-bb9ab56b960153d374d7e8838f62a18e7e45481e.zip
[SPARK-15079] Support average/count/sum in Long/DoubleAccumulator
## What changes were proposed in this pull request?

This patch removes AverageAccumulator and adds the ability to compute average to LongAccumulator and DoubleAccumulator. The patch also improves documentation for the two accumulators.

## How was this patch tested?

Added unit tests for this.

Author: Reynold Xin <rxin@databricks.com>

Closes #12858 from rxin/SPARK-15079.
Diffstat (limited to 'core/src/main/scala/org/apache/spark/AccumulatorV2.scala')
-rw-r--r-- core/src/main/scala/org/apache/spark/AccumulatorV2.scala 137
1 files changed, 91 insertions, 46 deletions
diff --git a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
index c65108a55e..a6c64fd680 100644
--- a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
@@ -257,23 +257,66 @@ private[spark] object AccumulatorContext {
}
+/**
+ * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers.
+ *
+ * @since 2.0.0
+ */
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
private[this] var _sum = 0L
+ private[this] var _count = 0L
- override def isZero: Boolean = _sum == 0
+ /**
+ * Returns whether the accumulator is in its zero state, i.e. count is 0.
+ * @since 2.0.0
+ */
+ override def isZero: Boolean = _count == 0L
override def copyAndReset(): LongAccumulator = new LongAccumulator
- override def add(v: jl.Long): Unit = _sum += v
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ override def add(v: jl.Long): Unit = {
+ _sum += v
+ _count += 1
+ }
+
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ def add(v: Long): Unit = {
+ _sum += v
+ _count += 1
+ }
- def add(v: Long): Unit = _sum += v
+ /**
+ * Returns the number of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def count: Long = _count
+ /**
+ * Returns the sum of elements added to the accumulator.
+ * @since 2.0.0
+ */
def sum: Long = _sum
+ /**
+ * Returns the average of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def avg: Double = _sum.toDouble / _count
+
override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
- case o: LongAccumulator => _sum += o.sum
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ case o: LongAccumulator =>
+ _sum += o.sum
+ _count += o.count
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
private[spark] def setValue(newValue: Long): Unit = _sum = newValue
@@ -282,66 +325,68 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
}
+/**
+ * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for double-precision
+ * floating-point numbers.
+ *
+ * @since 2.0.0
+ */
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
private[this] var _sum = 0.0
-
- override def isZero: Boolean = _sum == 0.0
-
- override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
-
- override def add(v: jl.Double): Unit = _sum += v
-
- def add(v: Double): Unit = _sum += v
-
- def sum: Double = _sum
-
- override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
- case o: DoubleAccumulator => _sum += o.sum
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
- }
-
- private[spark] def setValue(newValue: Double): Unit = _sum = newValue
-
- override def localValue: jl.Double = _sum
-}
-
-
-class AverageAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
- private[this] var _sum = 0.0
private[this] var _count = 0L
- override def isZero: Boolean = _sum == 0.0 && _count == 0
+ override def isZero: Boolean = _count == 0L
- override def copyAndReset(): AverageAccumulator = new AverageAccumulator
+ override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
override def add(v: jl.Double): Unit = {
_sum += v
_count += 1
}
- def add(d: Double): Unit = {
- _sum += d
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ def add(v: Double): Unit = {
+ _sum += v
_count += 1
}
+ /**
+ * Returns the number of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def count: Long = _count
+
+ /**
+ * Returns the sum of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def sum: Double = _sum
+
+ /**
+ * Returns the average of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def avg: Double = _sum / _count
+
override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
- case o: AverageAccumulator =>
+ case o: DoubleAccumulator =>
_sum += o.sum
_count += o.count
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
- }
-
- override def localValue: jl.Double = if (_count == 0) {
- Double.NaN
- } else {
- _sum / _count
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
- def sum: Double = _sum
+ private[spark] def setValue(newValue: Double): Unit = _sum = newValue
- def count: Long = _count
+ override def localValue: jl.Double = _sum
}