aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSandeep Singh <sandeep@techaddict.me>2016-05-12 11:12:09 +0800
committerWenchen Fan <wenchen@databricks.com>2016-05-12 11:12:09 +0800
commitff92eb2e80f2f38d10ac524ced82bb3f94b5b2bf (patch)
treeb5916751f27515ffae96f4e8342dea22be9eb9ea
parentdb573fc743d12446dd0421fb45d00c2f541eaf9a (diff)
downloadspark-ff92eb2e80f2f38d10ac524ced82bb3f94b5b2bf.tar.gz
spark-ff92eb2e80f2f38d10ac524ced82bb3f94b5b2bf.tar.bz2
spark-ff92eb2e80f2f38d10ac524ced82bb3f94b5b2bf.zip
[SPARK-15080][CORE] Break copyAndReset into copy and reset
## What changes were proposed in this pull request?

Break `copyAndReset` into two methods, `copy` and `reset`, instead of just one.

## How was this patch tested?

Existing tests.

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #12936 from techaddict/SPARK-15080.
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala75
-rw-r--r--core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala17
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala13
4 files changed, 96 insertions, 19 deletions
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 1893167cf7..5bb505bf09 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -291,12 +291,20 @@ private[spark] object TaskMetrics extends Logging {
private[spark] class BlockStatusesAccumulator
extends AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] {
- private[this] var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)]
+ private var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)]
override def isZero(): Boolean = _seq.isEmpty
override def copyAndReset(): BlockStatusesAccumulator = new BlockStatusesAccumulator
+ override def copy(): BlockStatusesAccumulator = {
+ val newAcc = new BlockStatusesAccumulator
+ newAcc._seq = _seq.clone()
+ newAcc
+ }
+
+ override def reset(): Unit = _seq.clear()
+
override def add(v: (BlockId, BlockStatus)): Unit = _seq += v
override def merge(other: AccumulatorV2[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]])
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index c4879036f6..0cf9df084f 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -112,7 +112,22 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
* Creates a new copy of this accumulator, which is zero value. i.e. call `isZero` on the copy
* must return true.
*/
- def copyAndReset(): AccumulatorV2[IN, OUT]
+ def copyAndReset(): AccumulatorV2[IN, OUT] = {
+ val copyAcc = copy()
+ copyAcc.reset()
+ copyAcc
+ }
+
+ /**
+ * Creates a new copy of this accumulator.
+ */
+ def copy(): AccumulatorV2[IN, OUT]
+
+ /**
+ * Resets this accumulator, which is zero value. i.e. call `isZero` must
+ * return true.
+ */
+ def reset(): Unit
/**
* Takes the inputs and accumulates. e.g. it can be a simple `+=` for counter accumulator.
@@ -137,10 +152,10 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
throw new UnsupportedOperationException(
"Accumulator must be registered before send to executor")
}
- val copy = copyAndReset()
- assert(copy.isZero, "copyAndReset must return a zero value copy")
- copy.metadata = metadata
- copy
+ val copyAcc = copyAndReset()
+ assert(copyAcc.isZero, "copyAndReset must return a zero value copy")
+ copyAcc.metadata = metadata
+ copyAcc
} else {
this
}
@@ -249,8 +264,8 @@ private[spark] object AccumulatorContext {
* @since 2.0.0
*/
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
- private[this] var _sum = 0L
- private[this] var _count = 0L
+ private var _sum = 0L
+ private var _count = 0L
/**
* Adds v to the accumulator, i.e. increment sum by v and count by 1.
@@ -258,7 +273,17 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
*/
override def isZero: Boolean = _sum == 0L && _count == 0
- override def copyAndReset(): LongAccumulator = new LongAccumulator
+ override def copy(): LongAccumulator = {
+ val newAcc = new LongAccumulator
+ newAcc._count = this._count
+ newAcc._sum = this._sum
+ newAcc
+ }
+
+ override def reset(): Unit = {
+ _sum = 0L
+ _count = 0L
+ }
/**
* Adds v to the accumulator, i.e. increment sum by v and count by 1.
@@ -318,12 +343,22 @@ class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
* @since 2.0.0
*/
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
- private[this] var _sum = 0.0
- private[this] var _count = 0L
+ private var _sum = 0.0
+ private var _count = 0L
override def isZero: Boolean = _sum == 0.0 && _count == 0
- override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
+ override def copy(): DoubleAccumulator = {
+ val newAcc = new DoubleAccumulator
+ newAcc._count = this._count
+ newAcc._sum = this._sum
+ newAcc
+ }
+
+ override def reset(): Unit = {
+ _sum = 0.0
+ _count = 0L
+ }
/**
* Adds v to the accumulator, i.e. increment sum by v and count by 1.
@@ -377,12 +412,20 @@ class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
class ListAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {
- private[this] val _list: java.util.List[T] = new java.util.ArrayList[T]
+ private val _list: java.util.List[T] = new java.util.ArrayList[T]
override def isZero: Boolean = _list.isEmpty
override def copyAndReset(): ListAccumulator[T] = new ListAccumulator
+ override def copy(): ListAccumulator[T] = {
+ val newAcc = new ListAccumulator[T]
+ newAcc._list.addAll(_list)
+ newAcc
+ }
+
+ override def reset(): Unit = _list.clear()
+
override def add(v: T): Unit = _list.add(v)
override def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit = other match {
@@ -407,12 +450,16 @@ class LegacyAccumulatorWrapper[R, T](
override def isZero: Boolean = _value == param.zero(initialValue)
- override def copyAndReset(): LegacyAccumulatorWrapper[R, T] = {
+ override def copy(): LegacyAccumulatorWrapper[R, T] = {
val acc = new LegacyAccumulatorWrapper(initialValue, param)
- acc._value = param.zero(initialValue)
+ acc._value = _value
acc
}
+ override def reset(): Unit = {
+ _value = param.zero(initialValue)
+ }
+
override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
override def merge(other: AccumulatorV2[T, R]): Unit = other match {
diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
index ecaf4f0c64..439da1306f 100644
--- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
@@ -116,6 +116,15 @@ class AccumulatorV2Suite extends SparkFunSuite {
assert(acc.value.contains(2.0))
assert(!acc.isZero)
assert(acc.value.size() === 3)
+
+ val acc3 = acc.copy()
+ assert(acc3.value.contains(2.0))
+ assert(!acc3.isZero)
+ assert(acc3.value.size() === 3)
+
+ acc3.reset()
+ assert(acc3.isZero)
+ assert(acc3.value.isEmpty)
}
test("LegacyAccumulatorWrapper") {
@@ -144,5 +153,13 @@ class AccumulatorV2Suite extends SparkFunSuite {
acc.merge(acc2)
assert(acc.value === "baz")
assert(!acc.isZero)
+
+ val acc3 = acc.copy()
+ assert(acc3.value === "baz")
+ assert(!acc3.isZero)
+
+ acc3.reset()
+ assert(acc3.isZero)
+ assert(acc3.value === "")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 786110477d..d6de15494f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -30,8 +30,15 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato
// update it at the end of task and the value will be at least 0. Then we can filter out the -1
// values before calculate max, min, etc.
private[this] var _value = initValue
+ private var _zeroValue = initValue
- override def copyAndReset(): SQLMetric = new SQLMetric(metricType, initValue)
+ override def copy(): SQLMetric = {
+ val newAcc = new SQLMetric(metricType, _value)
+ newAcc._zeroValue = initValue
+ newAcc
+ }
+
+ override def reset(): Unit = _value = _zeroValue
override def merge(other: AccumulatorV2[Long, Long]): Unit = other match {
case o: SQLMetric => _value += o.value
@@ -39,7 +46,7 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato
s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
- override def isZero(): Boolean = _value == initValue
+ override def isZero(): Boolean = _value == _zeroValue
override def add(v: Long): Unit = _value += v
@@ -51,8 +58,6 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato
private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
new AccumulableInfo(id, name, update, value, true, true, Some(SQLMetrics.ACCUM_IDENTIFIER))
}
-
- def reset(): Unit = _value = initValue
}