aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala/org/apache/spark/AccumulatorV2.scala')
-rw-r--r--core/src/main/scala/org/apache/spark/AccumulatorV2.scala137
1 files changed, 91 insertions, 46 deletions
diff --git a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
index c65108a55e..a6c64fd680 100644
--- a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
@@ -257,23 +257,66 @@ private[spark] object AccumulatorContext {
}
+/**
+ * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers.
+ *
+ * @since 2.0.0
+ */
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
private[this] var _sum = 0L
+ private[this] var _count = 0L
- override def isZero: Boolean = _sum == 0
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ override def isZero: Boolean = _count == 0L
override def copyAndReset(): LongAccumulator = new LongAccumulator
- override def add(v: jl.Long): Unit = _sum += v
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ override def add(v: jl.Long): Unit = {
+ _sum += v
+ _count += 1
+ }
+
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ def add(v: Long): Unit = {
+ _sum += v
+ _count += 1
+ }
- def add(v: Long): Unit = _sum += v
+ /**
+ * Returns the number of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def count: Long = _count
+ /**
+ * Returns the sum of elements added to the accumulator.
+ * @since 2.0.0
+ */
def sum: Long = _sum
+ /**
+ * Returns the average of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def avg: Double = _sum.toDouble / _count
+
override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
- case o: LongAccumulator => _sum += o.sum
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ case o: LongAccumulator =>
+ _sum += o.sum
+ _count += o.count
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
private[spark] def setValue(newValue: Long): Unit = _sum = newValue
@@ -282,66 +325,68 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
}
+/**
+ * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for double precision
+ * floating numbers.
+ *
+ * @since 2.0.0
+ */
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
private[this] var _sum = 0.0
-
- override def isZero: Boolean = _sum == 0.0
-
- override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
-
- override def add(v: jl.Double): Unit = _sum += v
-
- def add(v: Double): Unit = _sum += v
-
- def sum: Double = _sum
-
- override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
- case o: DoubleAccumulator => _sum += o.sum
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
- }
-
- private[spark] def setValue(newValue: Double): Unit = _sum = newValue
-
- override def localValue: jl.Double = _sum
-}
-
-
-class AverageAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
- private[this] var _sum = 0.0
private[this] var _count = 0L
- override def isZero: Boolean = _sum == 0.0 && _count == 0
+ override def isZero: Boolean = _count == 0L
- override def copyAndReset(): AverageAccumulator = new AverageAccumulator
+ override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
override def add(v: jl.Double): Unit = {
_sum += v
_count += 1
}
- def add(d: Double): Unit = {
- _sum += d
+ /**
+ * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+ * @since 2.0.0
+ */
+ def add(v: Double): Unit = {
+ _sum += v
_count += 1
}
+ /**
+ * Returns the number of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def count: Long = _count
+
+ /**
+ * Returns the sum of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def sum: Double = _sum
+
+ /**
+ * Returns the average of elements added to the accumulator.
+ * @since 2.0.0
+ */
+ def avg: Double = _sum / _count
+
override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
- case o: AverageAccumulator =>
+ case o: DoubleAccumulator =>
_sum += o.sum
_count += o.count
- case _ => throw new UnsupportedOperationException(
- s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
- }
-
- override def localValue: jl.Double = if (_count == 0) {
- Double.NaN
- } else {
- _sum / _count
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
- def sum: Double = _sum
+ private[spark] def setValue(newValue: Double): Unit = _sum = newValue
- def count: Long = _count
+ override def localValue: jl.Double = _sum
}