author     Wenchen Fan <wenchen@databricks.com>    2016-04-28 00:26:39 -0700
committer  Reynold Xin <rxin@databricks.com>       2016-04-28 00:26:39 -0700
commit     bf5496dbdac75ea69081c95a92a29771e635ea98 (patch)
tree       b6151d946b25171b5bbcd3252aa77f7f7a69f60c
parent     be317d4a90b3ca906fefeb438f89a09b1c7da5a8 (diff)
[SPARK-14654][CORE] New accumulator API
## What changes were proposed in this pull request?

This PR introduces a new accumulator API which is much simpler than before:

1. The type hierarchy is simplified; now we only have an `Accumulator` class.
2. The `initialValue` and `zeroValue` concepts are combined into a single concept: `zeroValue`.
3. There is only one `register` method; accumulator registration and cleanup registration are combined.
4. The `id`, `name` and `countFailedValues` fields are combined into an `AccumulatorMetadata`, which is provided during registration.

`SQLMetric` is a good example of the simplicity of this new API.

What we break:

1. No `setValue` anymore. In the new API the intermediate type can differ from the result type, so it is very hard to implement a general `setValue`.
2. An accumulator can no longer be serialized before it is registered.

Problems to be addressed in follow-ups:

1. With this new API, `AccumulableInfo` doesn't make a lot of sense: the partial output is not just partial updates, so we need to expose the intermediate value.
2. `ExceptionFailure` should not carry accumulator updates. Why would users care about accumulator updates for failed tasks? It looks like we only use this feature to update internal metrics; how about sending a heartbeat to update internal metrics after the failure event?
3. The public event `SparkListenerTaskEnd` carries a `TaskMetrics`. Ideally this `TaskMetrics` would not need to carry external accumulators, as the only method of `TaskMetrics` that can access external accumulators is `private[spark]`. However, `SQLListener` uses it to retrieve SQL metrics.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12612 from cloud-fan/acc.
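For illustration, a minimal sketch of how user code drives the new API, based on the `SparkContext` helpers and accumulator classes added in this patch (the local-mode setup and example data are assumptions for the sketch, not part of the change):

  import org.apache.spark.{SparkConf, SparkContext}

  object NewAccumulatorExample {
    def main(args: Array[String]): Unit = {
      // Local setup assumed purely for the example.
      val sc = new SparkContext(new SparkConf().setAppName("acc-example").setMaster("local[2]"))

      // Built-in accumulators are created and registered in a single call.
      val rowCount = sc.longAccumulator("rowCount")
      val names = sc.listAccumulator[String]("names")

      sc.parallelize(Seq("a", "b", "c")).foreach { s =>
        // Executors accumulate the input (IN) type ...
        rowCount.add(1L)
        names.add(s)
      }

      // ... and only the driver may read the merged output (OUT) value.
      println(s"rows = ${rowCount.value}, names = ${names.value}")
      sc.stop()
    }
  }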
-rw-r--r--  core/src/main/scala/org/apache/spark/Accumulable.scala | 64
-rw-r--r--  core/src/main/scala/org/apache/spark/Accumulator.scala | 67
-rw-r--r--  core/src/main/scala/org/apache/spark/ContextCleaner.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/NewAccumulator.scala | 391
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 107
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextImpl.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskEndReason.scala | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/InputMetrics.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala | 64
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala | 19
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala | 228
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 20
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Stage.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala | 7
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 6
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/status/api/v1/api.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala | 8
-rw-r--r--  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala | 17
-rw-r--r--  core/src/main/scala/org/apache/spark/util/JsonProtocol.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/AccumulatorSuite.scala | 132
-rw-r--r--  core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala | 24
-rw-r--r--  core/src/test/scala/org/apache/spark/SparkFunSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala | 85
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 71
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala | 7
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 16
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala | 11
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala | 12
-rw-r--r--  core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala | 7
-rw-r--r--  project/MimaExcludes.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala | 218
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala | 10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala | 32
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala | 6
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala | 2
73 files changed, 1071 insertions, 842 deletions
diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala
index e8f053c150..c76720c4bb 100644
--- a/core/src/main/scala/org/apache/spark/Accumulable.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulable.scala
@@ -63,7 +63,7 @@ class Accumulable[R, T] private (
param: AccumulableParam[R, T],
name: Option[String],
countFailedValues: Boolean) = {
- this(Accumulators.newId(), initialValue, param, name, countFailedValues)
+ this(AccumulatorContext.newId(), initialValue, param, name, countFailedValues)
}
private[spark] def this(initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = {
@@ -72,34 +72,23 @@ class Accumulable[R, T] private (
def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None)
- @volatile @transient private var value_ : R = initialValue // Current value on driver
- val zero = param.zero(initialValue) // Zero value to be passed to executors
- private var deserialized = false
-
- Accumulators.register(this)
-
- /**
- * Return a copy of this [[Accumulable]].
- *
- * The copy will have the same ID as the original and will not be registered with
- * [[Accumulators]] again. This method exists so that the caller can avoid passing the
- * same mutable instance around.
- */
- private[spark] def copy(): Accumulable[R, T] = {
- new Accumulable[R, T](id, initialValue, param, name, countFailedValues)
- }
+ val zero = param.zero(initialValue)
+ private[spark] val newAcc = new LegacyAccumulatorWrapper(initialValue, param)
+ newAcc.metadata = AccumulatorMetadata(id, name, countFailedValues)
+ // Register the new accumulator in ctor, to follow the previous behaviour.
+ AccumulatorContext.register(newAcc)
/**
* Add more data to this accumulator / accumulable
* @param term the data to add
*/
- def += (term: T) { value_ = param.addAccumulator(value_, term) }
+ def += (term: T) { newAcc.add(term) }
/**
* Add more data to this accumulator / accumulable
* @param term the data to add
*/
- def add(term: T) { value_ = param.addAccumulator(value_, term) }
+ def add(term: T) { newAcc.add(term) }
/**
* Merge two accumulable objects together
@@ -107,7 +96,7 @@ class Accumulable[R, T] private (
* Normally, a user will not want to use this version, but will instead call `+=`.
* @param term the other `R` that will get merged with this
*/
- def ++= (term: R) { value_ = param.addInPlace(value_, term)}
+ def ++= (term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }
/**
* Merge two accumulable objects together
@@ -115,18 +104,12 @@ class Accumulable[R, T] private (
* Normally, a user will not want to use this version, but will instead call `add`.
* @param term the other `R` that will get merged with this
*/
- def merge(term: R) { value_ = param.addInPlace(value_, term)}
+ def merge(term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }
/**
* Access the accumulator's current value; only allowed on driver.
*/
- def value: R = {
- if (!deserialized) {
- value_
- } else {
- throw new UnsupportedOperationException("Can't read accumulator value in task")
- }
- }
+ def value: R = newAcc.value
/**
* Get the current value of this accumulator from within a task.
@@ -137,14 +120,14 @@ class Accumulable[R, T] private (
* The typical use of this method is to directly mutate the local value, eg., to add
* an element to a Set.
*/
- def localValue: R = value_
+ def localValue: R = newAcc.localValue
/**
* Set the accumulator's value; only allowed on driver.
*/
def value_= (newValue: R) {
- if (!deserialized) {
- value_ = newValue
+ if (newAcc.isAtDriverSide) {
+ newAcc._value = newValue
} else {
throw new UnsupportedOperationException("Can't assign accumulator value in task")
}
@@ -153,7 +136,7 @@ class Accumulable[R, T] private (
/**
* Set the accumulator's value. For internal use only.
*/
- def setValue(newValue: R): Unit = { value_ = newValue }
+ def setValue(newValue: R): Unit = { newAcc._value = newValue }
/**
* Set the accumulator's value. For internal use only.
@@ -168,22 +151,7 @@ class Accumulable[R, T] private (
new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
}
- // Called by Java when deserializing an object
- private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
- in.defaultReadObject()
- value_ = zero
- deserialized = true
-
- // Automatically register the accumulator when it is deserialized with the task closure.
- // This is for external accumulators and internal ones that do not represent task level
- // metrics, e.g. internal SQL metrics, which are per-operator.
- val taskContext = TaskContext.get()
- if (taskContext != null) {
- taskContext.registerAccumulator(this)
- }
- }
-
- override def toString: String = if (value_ == null) "null" else value_.toString
+ override def toString: String = if (newAcc._value == null) "null" else newAcc._value.toString
}
diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala
index 0c17f014c9..9b007b9776 100644
--- a/core/src/main/scala/org/apache/spark/Accumulator.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulator.scala
@@ -68,73 +68,6 @@ class Accumulator[T] private[spark] (
extends Accumulable[T, T](initialValue, param, name, countFailedValues)
-// TODO: The multi-thread support in accumulators is kind of lame; check
-// if there's a more intuitive way of doing it right
-private[spark] object Accumulators extends Logging {
- /**
- * This global map holds the original accumulator objects that are created on the driver.
- * It keeps weak references to these objects so that accumulators can be garbage-collected
- * once the RDDs and user-code that reference them are cleaned up.
- * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
- */
- @GuardedBy("Accumulators")
- val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
-
- private val nextId = new AtomicLong(0L)
-
- /**
- * Return a globally unique ID for a new [[Accumulable]].
- * Note: Once you copy the [[Accumulable]] the ID is no longer unique.
- */
- def newId(): Long = nextId.getAndIncrement
-
- /**
- * Register an [[Accumulable]] created on the driver such that it can be used on the executors.
- *
- * All accumulators registered here can later be used as a container for accumulating partial
- * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does.
- * Note: if an accumulator is registered here, it should also be registered with the active
- * context cleaner for cleanup so as to avoid memory leaks.
- *
- * If an [[Accumulable]] with the same ID was already registered, this does nothing instead
- * of overwriting it. This happens when we copy accumulators, e.g. when we reconstruct
- * [[org.apache.spark.executor.TaskMetrics]] from accumulator updates.
- */
- def register(a: Accumulable[_, _]): Unit = synchronized {
- if (!originals.contains(a.id)) {
- originals(a.id) = new WeakReference[Accumulable[_, _]](a)
- }
- }
-
- /**
- * Unregister the [[Accumulable]] with the given ID, if any.
- */
- def remove(accId: Long): Unit = synchronized {
- originals.remove(accId)
- }
-
- /**
- * Return the [[Accumulable]] registered with the given ID, if any.
- */
- def get(id: Long): Option[Accumulable[_, _]] = synchronized {
- originals.get(id).map { weakRef =>
- // Since we are storing weak references, we must check whether the underlying data is valid.
- weakRef.get.getOrElse {
- throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
- }
- }
- }
-
- /**
- * Clear all registered [[Accumulable]]s. For testing only.
- */
- def clear(): Unit = synchronized {
- originals.clear()
- }
-
-}
-
-
/**
* A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add
* in is the same type as the accumulated value. An implicit AccumulatorParam object needs to be
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 76692ccec8..63a00a84af 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -144,7 +144,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
registerForCleanup(rdd, CleanRDD(rdd.id))
}
- def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+ def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = {
registerForCleanup(a, CleanAccum(a.id))
}
@@ -241,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
def doCleanupAccum(accId: Long, blocking: Boolean): Unit = {
try {
logDebug("Cleaning accumulator " + accId)
- Accumulators.remove(accId)
+ AccumulatorContext.remove(accId)
listeners.asScala.foreach(_.accumCleaned(accId))
logInfo("Cleaned accumulator " + accId)
} catch {
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 2bdbd3fae9..9eac05fdf9 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
*/
private[spark] case class Heartbeat(
executorId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])], // taskId -> accum updates
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], // taskId -> accumulator updates
blockManagerId: BlockManagerId)
/**
diff --git a/core/src/main/scala/org/apache/spark/NewAccumulator.scala b/core/src/main/scala/org/apache/spark/NewAccumulator.scala
new file mode 100644
index 0000000000..edb9b741a8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/NewAccumulator.scala
@@ -0,0 +1,391 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.{lang => jl}
+import java.io.ObjectInputStream
+import java.util.concurrent.atomic.AtomicLong
+import javax.annotation.concurrent.GuardedBy
+
+import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.util.Utils
+
+
+private[spark] case class AccumulatorMetadata(
+ id: Long,
+ name: Option[String],
+ countFailedValues: Boolean) extends Serializable
+
+
+/**
+ * The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of
+ * type `OUT`.
+ */
+abstract class NewAccumulator[IN, OUT] extends Serializable {
+ private[spark] var metadata: AccumulatorMetadata = _
+ private[this] var atDriverSide = true
+
+ private[spark] def register(
+ sc: SparkContext,
+ name: Option[String] = None,
+ countFailedValues: Boolean = false): Unit = {
+ if (this.metadata != null) {
+ throw new IllegalStateException("Cannot register an Accumulator twice.")
+ }
+ this.metadata = AccumulatorMetadata(AccumulatorContext.newId(), name, countFailedValues)
+ AccumulatorContext.register(this)
+ sc.cleaner.foreach(_.registerAccumulatorForCleanup(this))
+ }
+
+ /**
+ * Returns true if this accumulator has been registered. Note that all accumulators must be
+ * registered before ues, or it will throw exception.
+ */
+ final def isRegistered: Boolean =
+ metadata != null && AccumulatorContext.originals.containsKey(metadata.id)
+
+ private def assertMetadataNotNull(): Unit = {
+ if (metadata == null) {
+ throw new IllegalAccessError("The metadata of this accumulator has not been assigned yet.")
+ }
+ }
+
+ /**
+ * Returns the id of this accumulator, can only be called after registration.
+ */
+ final def id: Long = {
+ assertMetadataNotNull()
+ metadata.id
+ }
+
+ /**
+ * Returns the name of this accumulator, can only be called after registration.
+ */
+ final def name: Option[String] = {
+ assertMetadataNotNull()
+ metadata.name
+ }
+
+ /**
+ * Whether to accumulate values from failed tasks. This is set to true for system and time
+ * metrics like serialization time or bytes spilled, and false for things with absolute values
+ * like number of input rows. This should be used for internal metrics only.
+ */
+ private[spark] final def countFailedValues: Boolean = {
+ assertMetadataNotNull()
+ metadata.countFailedValues
+ }
+
+ /**
+ * Creates an [[AccumulableInfo]] representation of this [[NewAccumulator]] with the provided
+ * values.
+ */
+ private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+ val isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))
+ new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
+ }
+
+ final private[spark] def isAtDriverSide: Boolean = atDriverSide
+
+ /**
+ * Tells if this accumulator is zero value or not. e.g. for a counter accumulator, 0 is zero
+ * value; for a list accumulator, Nil is zero value.
+ */
+ def isZero(): Boolean
+
+ /**
+ * Creates a new copy of this accumulator, which is zero value. i.e. call `isZero` on the copy
+ * must return true.
+ */
+ def copyAndReset(): NewAccumulator[IN, OUT]
+
+ /**
+ * Takes the inputs and accumulates. e.g. it can be a simple `+=` for counter accumulator.
+ */
+ def add(v: IN): Unit
+
+ /**
+ * Merges another same-type accumulator into this one and update its state, i.e. this should be
+ * merge-in-place.
+ */
+ def merge(other: NewAccumulator[IN, OUT]): Unit
+
+ /**
+ * Access this accumulator's current value; only allowed on driver.
+ */
+ final def value: OUT = {
+ if (atDriverSide) {
+ localValue
+ } else {
+ throw new UnsupportedOperationException("Can't read accumulator value in task")
+ }
+ }
+
+ /**
+ * Defines the current value of this accumulator.
+ *
+ * This is NOT the global value of the accumulator. To get the global value after a
+ * completed operation on the dataset, call `value`.
+ */
+ def localValue: OUT
+
+ // Called by Java when serializing an object
+ final protected def writeReplace(): Any = {
+ if (atDriverSide) {
+ if (!isRegistered) {
+ throw new UnsupportedOperationException(
+ "Accumulator must be registered before send to executor")
+ }
+ val copy = copyAndReset()
+ assert(copy.isZero(), "copyAndReset must return a zero value copy")
+ copy.metadata = metadata
+ copy
+ } else {
+ this
+ }
+ }
+
+ // Called by Java when deserializing an object
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+ in.defaultReadObject()
+ if (atDriverSide) {
+ atDriverSide = false
+
+ // Automatically register the accumulator when it is deserialized with the task closure.
+ // This is for external accumulators and internal ones that do not represent task level
+ // metrics, e.g. internal SQL metrics, which are per-operator.
+ val taskContext = TaskContext.get()
+ if (taskContext != null) {
+ taskContext.registerAccumulator(this)
+ }
+ } else {
+ atDriverSide = true
+ }
+ }
+
+ override def toString: String = {
+ if (metadata == null) {
+ "Un-registered Accumulator: " + getClass.getSimpleName
+ } else {
+ getClass.getSimpleName + s"(id: $id, name: $name, value: $localValue)"
+ }
+ }
+}
+
+
+private[spark] object AccumulatorContext {
+
+ /**
+ * This global map holds the original accumulator objects that are created on the driver.
+ * It keeps weak references to these objects so that accumulators can be garbage-collected
+ * once the RDDs and user-code that reference them are cleaned up.
+ * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
+ */
+ @GuardedBy("AccumulatorContext")
+ val originals = new java.util.HashMap[Long, jl.ref.WeakReference[NewAccumulator[_, _]]]
+
+ private[this] val nextId = new AtomicLong(0L)
+
+ /**
+ * Return a globally unique ID for a new [[Accumulator]].
+ * Note: Once you copy the [[Accumulator]] the ID is no longer unique.
+ */
+ def newId(): Long = nextId.getAndIncrement
+
+ /**
+ * Register an [[Accumulator]] created on the driver such that it can be used on the executors.
+ *
+ * All accumulators registered here can later be used as a container for accumulating partial
+ * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does.
+ * Note: if an accumulator is registered here, it should also be registered with the active
+ * context cleaner for cleanup so as to avoid memory leaks.
+ *
+ * If an [[Accumulator]] with the same ID was already registered, this does nothing instead
+ * of overwriting it. We will never register same accumulator twice, this is just a sanity check.
+ */
+ def register(a: NewAccumulator[_, _]): Unit = synchronized {
+ if (!originals.containsKey(a.id)) {
+ originals.put(a.id, new jl.ref.WeakReference[NewAccumulator[_, _]](a))
+ }
+ }
+
+ /**
+ * Unregister the [[Accumulator]] with the given ID, if any.
+ */
+ def remove(id: Long): Unit = synchronized {
+ originals.remove(id)
+ }
+
+ /**
+ * Return the [[Accumulator]] registered with the given ID, if any.
+ */
+ def get(id: Long): Option[NewAccumulator[_, _]] = synchronized {
+ Option(originals.get(id)).map { ref =>
+ // Since we are storing weak references, we must check whether the underlying data is valid.
+ val acc = ref.get
+ if (acc eq null) {
+ throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
+ }
+ acc
+ }
+ }
+
+ /**
+ * Clear all registered [[Accumulator]]s. For testing only.
+ */
+ def clear(): Unit = synchronized {
+ originals.clear()
+ }
+}
+
+
+class LongAccumulator extends NewAccumulator[jl.Long, jl.Long] {
+ private[this] var _sum = 0L
+
+ override def isZero(): Boolean = _sum == 0
+
+ override def copyAndReset(): LongAccumulator = new LongAccumulator
+
+ override def add(v: jl.Long): Unit = _sum += v
+
+ def add(v: Long): Unit = _sum += v
+
+ def sum: Long = _sum
+
+ override def merge(other: NewAccumulator[jl.Long, jl.Long]): Unit = other match {
+ case o: LongAccumulator => _sum += o.sum
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ private[spark] def setValue(newValue: Long): Unit = _sum = newValue
+
+ override def localValue: jl.Long = _sum
+}
+
+
+class DoubleAccumulator extends NewAccumulator[jl.Double, jl.Double] {
+ private[this] var _sum = 0.0
+
+ override def isZero(): Boolean = _sum == 0.0
+
+ override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
+
+ override def add(v: jl.Double): Unit = _sum += v
+
+ def add(v: Double): Unit = _sum += v
+
+ def sum: Double = _sum
+
+ override def merge(other: NewAccumulator[jl.Double, jl.Double]): Unit = other match {
+ case o: DoubleAccumulator => _sum += o.sum
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ private[spark] def setValue(newValue: Double): Unit = _sum = newValue
+
+ override def localValue: jl.Double = _sum
+}
+
+
+class AverageAccumulator extends NewAccumulator[jl.Double, jl.Double] {
+ private[this] var _sum = 0.0
+ private[this] var _count = 0L
+
+ override def isZero(): Boolean = _sum == 0.0 && _count == 0
+
+ override def copyAndReset(): AverageAccumulator = new AverageAccumulator
+
+ override def add(v: jl.Double): Unit = {
+ _sum += v
+ _count += 1
+ }
+
+ def add(d: Double): Unit = {
+ _sum += d
+ _count += 1
+ }
+
+ override def merge(other: NewAccumulator[jl.Double, jl.Double]): Unit = other match {
+ case o: AverageAccumulator =>
+ _sum += o.sum
+ _count += o.count
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ override def localValue: jl.Double = if (_count == 0) {
+ Double.NaN
+ } else {
+ _sum / _count
+ }
+
+ def sum: Double = _sum
+
+ def count: Long = _count
+}
+
+
+class ListAccumulator[T] extends NewAccumulator[T, java.util.List[T]] {
+ private[this] val _list: java.util.List[T] = new java.util.ArrayList[T]
+
+ override def isZero(): Boolean = _list.isEmpty
+
+ override def copyAndReset(): ListAccumulator[T] = new ListAccumulator
+
+ override def add(v: T): Unit = _list.add(v)
+
+ override def merge(other: NewAccumulator[T, java.util.List[T]]): Unit = other match {
+ case o: ListAccumulator[T] => _list.addAll(o.localValue)
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ override def localValue: java.util.List[T] = java.util.Collections.unmodifiableList(_list)
+
+ private[spark] def setValue(newValue: java.util.List[T]): Unit = {
+ _list.clear()
+ _list.addAll(newValue)
+ }
+}
+
+
+class LegacyAccumulatorWrapper[R, T](
+ initialValue: R,
+ param: org.apache.spark.AccumulableParam[R, T]) extends NewAccumulator[T, R] {
+ private[spark] var _value = initialValue // Current value on driver
+
+ override def isZero(): Boolean = _value == param.zero(initialValue)
+
+ override def copyAndReset(): LegacyAccumulatorWrapper[R, T] = {
+ val acc = new LegacyAccumulatorWrapper(initialValue, param)
+ acc._value = param.zero(initialValue)
+ acc
+ }
+
+ override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
+
+ override def merge(other: NewAccumulator[T, R]): Unit = other match {
+ case o: LegacyAccumulatorWrapper[R, T] => _value = param.addInPlace(_value, o.localValue)
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ override def localValue: R = _value
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f322a770bf..865989aee0 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1217,10 +1217,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add"
* values to using the `+=` method. Only the driver can access the accumulator's `value`.
*/
- def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] =
- {
+ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] = {
val acc = new Accumulator(initialValue, param)
- cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1232,7 +1231,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T])
: Accumulator[T] = {
val acc = new Accumulator(initialValue, param, Some(name))
- cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1245,7 +1244,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T])
: Accumulable[R, T] = {
val acc = new Accumulable(initialValue, param)
- cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1259,7 +1258,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T])
: Accumulable[R, T] = {
val acc = new Accumulable(initialValue, param, Some(name))
- cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1273,7 +1272,101 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(initialValue: R): Accumulable[R, T] = {
val param = new GrowableAccumulableParam[R, T]
val acc = new Accumulable(initialValue, param)
- cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
+ acc
+ }
+
+ /**
+ * Register the given accumulator. Note that accumulators must be registered before use, or it
+ * will throw exception.
+ */
+ def register(acc: NewAccumulator[_, _]): Unit = {
+ acc.register(this)
+ }
+
+ /**
+ * Register the given accumulator with given name. Note that accumulators must be registered
+ * before use, or it will throw exception.
+ */
+ def register(acc: NewAccumulator[_, _], name: String): Unit = {
+ acc.register(this, name = Some(name))
+ }
+
+ /**
+ * Create and register a long accumulator, which starts with 0 and accumulates inputs by `+=`.
+ */
+ def longAccumulator: LongAccumulator = {
+ val acc = new LongAccumulator
+ register(acc)
+ acc
+ }
+
+ /**
+ * Create and register a long accumulator, which starts with 0 and accumulates inputs by `+=`.
+ */
+ def longAccumulator(name: String): LongAccumulator = {
+ val acc = new LongAccumulator
+ register(acc, name)
+ acc
+ }
+
+ /**
+ * Create and register a double accumulator, which starts with 0 and accumulates inputs by `+=`.
+ */
+ def doubleAccumulator: DoubleAccumulator = {
+ val acc = new DoubleAccumulator
+ register(acc)
+ acc
+ }
+
+ /**
+ * Create and register a double accumulator, which starts with 0 and accumulates inputs by `+=`.
+ */
+ def doubleAccumulator(name: String): DoubleAccumulator = {
+ val acc = new DoubleAccumulator
+ register(acc, name)
+ acc
+ }
+
+ /**
+ * Create and register an average accumulator, which accumulates double inputs by recording the
+ * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be
+ * returned if no input is added.
+ */
+ def averageAccumulator: AverageAccumulator = {
+ val acc = new AverageAccumulator
+ register(acc)
+ acc
+ }
+
+ /**
+ * Create and register an average accumulator, which accumulates double inputs by recording the
+ * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be
+ * returned if no input is added.
+ */
+ def averageAccumulator(name: String): AverageAccumulator = {
+ val acc = new AverageAccumulator
+ register(acc, name)
+ acc
+ }
+
+ /**
+ * Create and register a list accumulator, which starts with empty list and accumulates inputs
+ * by adding them into the inner list.
+ */
+ def listAccumulator[T]: ListAccumulator[T] = {
+ val acc = new ListAccumulator[T]
+ register(acc)
+ acc
+ }
+
+ /**
+ * Create and register a list accumulator, which starts with empty list and accumulates inputs
+ * by adding them into the inner list.
+ */
+ def listAccumulator[T](name: String): ListAccumulator[T] = {
+ val acc = new ListAccumulator[T]
+ register(acc, name)
acc
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index e7940bd9ed..9e53257462 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -188,6 +188,6 @@ abstract class TaskContext extends Serializable {
* Register an accumulator that belongs to this task. Accumulators must call this method when
* deserializing in executors.
*/
- private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit
+ private[spark] def registerAccumulator(a: NewAccumulator[_, _]): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 43e555670d..bc3807f5db 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -122,7 +122,7 @@ private[spark] class TaskContextImpl(
override def getMetricsSources(sourceName: String): Seq[Source] =
metricsSystem.getSourcesByName(sourceName)
- private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = {
+ private[spark] override def registerAccumulator(a: NewAccumulator[_, _]): Unit = {
taskMetrics.registerAccumulator(a)
}
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 7487cfe9c5..82ba2d0c27 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -19,10 +19,7 @@ package org.apache.spark
import java.io.{ObjectInputStream, ObjectOutputStream}
-import scala.util.Try
-
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.storage.BlockManagerId
@@ -120,18 +117,10 @@ case class ExceptionFailure(
stackTrace: Array[StackTraceElement],
fullStackTrace: String,
private val exceptionWrapper: Option[ThrowableSerializationWrapper],
- accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo])
+ accumUpdates: Seq[AccumulableInfo] = Seq.empty,
+ private[spark] var accums: Seq[NewAccumulator[_, _]] = Nil)
extends TaskFailedReason {
- @deprecated("use accumUpdates instead", "2.0.0")
- val metrics: Option[TaskMetrics] = {
- if (accumUpdates.nonEmpty) {
- Try(TaskMetrics.fromAccumulatorUpdates(accumUpdates)).toOption
- } else {
- None
- }
- }
-
/**
* `preserveCause` is used to keep the exception itself so it is available to the
* driver. This may be set to `false` in the event that the exception is not in fact
@@ -149,6 +138,11 @@ case class ExceptionFailure(
this(e, accumUpdates, preserveCause = true)
}
+ private[spark] def withAccums(accums: Seq[NewAccumulator[_, _]]): ExceptionFailure = {
+ this.accums = accums
+ this
+ }
+
def exception: Option[Throwable] = exceptionWrapper.flatMap(w => Option(w.exception))
override def toErrorString: String =
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 650f05c309..4d61f7e232 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -353,22 +353,24 @@ private[spark] class Executor(
logError(s"Exception in $taskName (TID $taskId)", t)
// Collect latest accumulator values to report back to the driver
- val accumulatorUpdates: Seq[AccumulableInfo] =
+ val accums: Seq[NewAccumulator[_, _]] =
if (task != null) {
task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.collectAccumulatorUpdates(taskFailed = true)
} else {
- Seq.empty[AccumulableInfo]
+ Seq.empty
}
+ val accUpdates = accums.map(acc => acc.toInfo(Some(acc.localValue), None))
+
val serializedTaskEndReason = {
try {
- ser.serialize(new ExceptionFailure(t, accumulatorUpdates))
+ ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
- ser.serialize(new ExceptionFailure(t, accumulatorUpdates, preserveCause = false))
+ ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
}
}
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
@@ -476,14 +478,14 @@ private[spark] class Executor(
/** Reports heartbeat and metrics for active tasks to the driver. */
private def reportHeartBeat(): Unit = {
// list of (task id, accumUpdates) to send back to the driver
- val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulableInfo])]()
+ val accumUpdates = new ArrayBuffer[(Long, Seq[NewAccumulator[_, _]])]()
val curGCTime = computeTotalGcTime()
for (taskRunner <- runningTasks.values().asScala) {
if (taskRunner.task != null) {
taskRunner.task.metrics.mergeShuffleReadMetrics()
taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
- accumUpdates += ((taskRunner.taskId, taskRunner.task.metrics.accumulatorUpdates()))
+ accumUpdates += ((taskRunner.taskId, taskRunner.task.metrics.accumulators()))
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
index 535352e7dd..6f7160ac0d 100644
--- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
import org.apache.spark.annotation.DeveloperApi
@@ -40,20 +40,18 @@ object DataReadMethod extends Enumeration with Serializable {
*/
@DeveloperApi
class InputMetrics private[spark] () extends Serializable {
- import InternalAccumulator._
-
- private[executor] val _bytesRead = TaskMetrics.createLongAccum(input.BYTES_READ)
- private[executor] val _recordsRead = TaskMetrics.createLongAccum(input.RECORDS_READ)
+ private[executor] val _bytesRead = new LongAccumulator
+ private[executor] val _recordsRead = new LongAccumulator
/**
* Total number of bytes read.
*/
- def bytesRead: Long = _bytesRead.localValue
+ def bytesRead: Long = _bytesRead.sum
/**
* Total number of records read.
*/
- def recordsRead: Long = _recordsRead.localValue
+ def recordsRead: Long = _recordsRead.sum
private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v)
private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
index 586c98b156..db3924cb69 100644
--- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
import org.apache.spark.annotation.DeveloperApi
@@ -39,20 +39,18 @@ object DataWriteMethod extends Enumeration with Serializable {
*/
@DeveloperApi
class OutputMetrics private[spark] () extends Serializable {
- import InternalAccumulator._
-
- private[executor] val _bytesWritten = TaskMetrics.createLongAccum(output.BYTES_WRITTEN)
- private[executor] val _recordsWritten = TaskMetrics.createLongAccum(output.RECORDS_WRITTEN)
+ private[executor] val _bytesWritten = new LongAccumulator
+ private[executor] val _recordsWritten = new LongAccumulator
/**
* Total number of bytes written.
*/
- def bytesWritten: Long = _bytesWritten.localValue
+ def bytesWritten: Long = _bytesWritten.sum
/**
* Total number of records written.
*/
- def recordsWritten: Long = _recordsWritten.localValue
+ def recordsWritten: Long = _recordsWritten.sum
private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v)
private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v)
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
index f012a74db6..fa962108c3 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
import org.apache.spark.annotation.DeveloperApi
@@ -28,52 +28,44 @@ import org.apache.spark.annotation.DeveloperApi
*/
@DeveloperApi
class ShuffleReadMetrics private[spark] () extends Serializable {
- import InternalAccumulator._
-
- private[executor] val _remoteBlocksFetched =
- TaskMetrics.createIntAccum(shuffleRead.REMOTE_BLOCKS_FETCHED)
- private[executor] val _localBlocksFetched =
- TaskMetrics.createIntAccum(shuffleRead.LOCAL_BLOCKS_FETCHED)
- private[executor] val _remoteBytesRead =
- TaskMetrics.createLongAccum(shuffleRead.REMOTE_BYTES_READ)
- private[executor] val _localBytesRead =
- TaskMetrics.createLongAccum(shuffleRead.LOCAL_BYTES_READ)
- private[executor] val _fetchWaitTime =
- TaskMetrics.createLongAccum(shuffleRead.FETCH_WAIT_TIME)
- private[executor] val _recordsRead =
- TaskMetrics.createLongAccum(shuffleRead.RECORDS_READ)
+ private[executor] val _remoteBlocksFetched = new LongAccumulator
+ private[executor] val _localBlocksFetched = new LongAccumulator
+ private[executor] val _remoteBytesRead = new LongAccumulator
+ private[executor] val _localBytesRead = new LongAccumulator
+ private[executor] val _fetchWaitTime = new LongAccumulator
+ private[executor] val _recordsRead = new LongAccumulator
/**
* Number of remote blocks fetched in this shuffle by this task.
*/
- def remoteBlocksFetched: Int = _remoteBlocksFetched.localValue
+ def remoteBlocksFetched: Long = _remoteBlocksFetched.sum
/**
* Number of local blocks fetched in this shuffle by this task.
*/
- def localBlocksFetched: Int = _localBlocksFetched.localValue
+ def localBlocksFetched: Long = _localBlocksFetched.sum
/**
* Total number of remote bytes read from the shuffle by this task.
*/
- def remoteBytesRead: Long = _remoteBytesRead.localValue
+ def remoteBytesRead: Long = _remoteBytesRead.sum
/**
* Shuffle data that was read from the local disk (as opposed to from a remote executor).
*/
- def localBytesRead: Long = _localBytesRead.localValue
+ def localBytesRead: Long = _localBytesRead.sum
/**
* Time the task spent waiting for remote shuffle blocks. This only includes the time
* blocking on shuffle input data. For instance if block B is being fetched while the task is
* still not finished processing block A, it is not considered to be blocking on block B.
*/
- def fetchWaitTime: Long = _fetchWaitTime.localValue
+ def fetchWaitTime: Long = _fetchWaitTime.sum
/**
* Total number of records read from the shuffle by this task.
*/
- def recordsRead: Long = _recordsRead.localValue
+ def recordsRead: Long = _recordsRead.sum
/**
* Total bytes fetched in the shuffle by this task (both remote and local).
@@ -83,10 +75,10 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
/**
* Number of blocks fetched in this shuffle by this task (remote or local).
*/
- def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched
+ def totalBlocksFetched: Long = remoteBlocksFetched + localBlocksFetched
- private[spark] def incRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.add(v)
- private[spark] def incLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.add(v)
+ private[spark] def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched.add(v)
+ private[spark] def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched.add(v)
private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v)
private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v)
private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v)
@@ -104,12 +96,12 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
* [[TempShuffleReadMetrics]] into `this`.
*/
private[spark] def setMergeValues(metrics: Seq[TempShuffleReadMetrics]): Unit = {
- _remoteBlocksFetched.setValue(_remoteBlocksFetched.zero)
- _localBlocksFetched.setValue(_localBlocksFetched.zero)
- _remoteBytesRead.setValue(_remoteBytesRead.zero)
- _localBytesRead.setValue(_localBytesRead.zero)
- _fetchWaitTime.setValue(_fetchWaitTime.zero)
- _recordsRead.setValue(_recordsRead.zero)
+ _remoteBlocksFetched.setValue(0)
+ _localBlocksFetched.setValue(0)
+ _remoteBytesRead.setValue(0)
+ _localBytesRead.setValue(0)
+ _fetchWaitTime.setValue(0)
+ _recordsRead.setValue(0)
metrics.foreach { metric =>
_remoteBlocksFetched.add(metric.remoteBlocksFetched)
_localBlocksFetched.add(metric.localBlocksFetched)
@@ -127,22 +119,22 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
* last.
*/
private[spark] class TempShuffleReadMetrics {
- private[this] var _remoteBlocksFetched = 0
- private[this] var _localBlocksFetched = 0
+ private[this] var _remoteBlocksFetched = 0L
+ private[this] var _localBlocksFetched = 0L
private[this] var _remoteBytesRead = 0L
private[this] var _localBytesRead = 0L
private[this] var _fetchWaitTime = 0L
private[this] var _recordsRead = 0L
- def incRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched += v
- def incLocalBlocksFetched(v: Int): Unit = _localBlocksFetched += v
+ def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v
+ def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v
def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v
def incLocalBytesRead(v: Long): Unit = _localBytesRead += v
def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v
def incRecordsRead(v: Long): Unit = _recordsRead += v
- def remoteBlocksFetched: Int = _remoteBlocksFetched
- def localBlocksFetched: Int = _localBlocksFetched
+ def remoteBlocksFetched: Long = _remoteBlocksFetched
+ def localBlocksFetched: Long = _localBlocksFetched
def remoteBytesRead: Long = _remoteBytesRead
def localBytesRead: Long = _localBytesRead
def fetchWaitTime: Long = _fetchWaitTime
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
index 7326fba841..0e70a4f522 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
import org.apache.spark.annotation.DeveloperApi
@@ -28,29 +28,24 @@ import org.apache.spark.annotation.DeveloperApi
*/
@DeveloperApi
class ShuffleWriteMetrics private[spark] () extends Serializable {
- import InternalAccumulator._
-
- private[executor] val _bytesWritten =
- TaskMetrics.createLongAccum(shuffleWrite.BYTES_WRITTEN)
- private[executor] val _recordsWritten =
- TaskMetrics.createLongAccum(shuffleWrite.RECORDS_WRITTEN)
- private[executor] val _writeTime =
- TaskMetrics.createLongAccum(shuffleWrite.WRITE_TIME)
+ private[executor] val _bytesWritten = new LongAccumulator
+ private[executor] val _recordsWritten = new LongAccumulator
+ private[executor] val _writeTime = new LongAccumulator
/**
* Number of bytes written for the shuffle by this task.
*/
- def bytesWritten: Long = _bytesWritten.localValue
+ def bytesWritten: Long = _bytesWritten.sum
/**
* Total number of records written to the shuffle by this task.
*/
- def recordsWritten: Long = _recordsWritten.localValue
+ def recordsWritten: Long = _recordsWritten.sum
/**
* Time the task spent blocking on writes to disk or buffer cache, in nanoseconds.
*/
- def writeTime: Long = _writeTime.localValue
+ def writeTime: Long = _writeTime.sum
private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 8513d053f2..0b64917219 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,10 +17,9 @@
package org.apache.spark.executor
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}
import org.apache.spark._
-import org.apache.spark.AccumulatorParam.{IntAccumulatorParam, LongAccumulatorParam, UpdatedBlockStatusesAccumulatorParam}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.AccumulableInfo
@@ -42,53 +41,51 @@ import org.apache.spark.storage.{BlockId, BlockStatus}
*/
@DeveloperApi
class TaskMetrics private[spark] () extends Serializable {
- import InternalAccumulator._
-
// Each metric is internally represented as an accumulator
- private val _executorDeserializeTime = TaskMetrics.createLongAccum(EXECUTOR_DESERIALIZE_TIME)
- private val _executorRunTime = TaskMetrics.createLongAccum(EXECUTOR_RUN_TIME)
- private val _resultSize = TaskMetrics.createLongAccum(RESULT_SIZE)
- private val _jvmGCTime = TaskMetrics.createLongAccum(JVM_GC_TIME)
- private val _resultSerializationTime = TaskMetrics.createLongAccum(RESULT_SERIALIZATION_TIME)
- private val _memoryBytesSpilled = TaskMetrics.createLongAccum(MEMORY_BYTES_SPILLED)
- private val _diskBytesSpilled = TaskMetrics.createLongAccum(DISK_BYTES_SPILLED)
- private val _peakExecutionMemory = TaskMetrics.createLongAccum(PEAK_EXECUTION_MEMORY)
- private val _updatedBlockStatuses = TaskMetrics.createBlocksAccum(UPDATED_BLOCK_STATUSES)
+ private val _executorDeserializeTime = new LongAccumulator
+ private val _executorRunTime = new LongAccumulator
+ private val _resultSize = new LongAccumulator
+ private val _jvmGCTime = new LongAccumulator
+ private val _resultSerializationTime = new LongAccumulator
+ private val _memoryBytesSpilled = new LongAccumulator
+ private val _diskBytesSpilled = new LongAccumulator
+ private val _peakExecutionMemory = new LongAccumulator
+ private val _updatedBlockStatuses = new BlockStatusesAccumulator
/**
* Time taken on the executor to deserialize this task.
*/
- def executorDeserializeTime: Long = _executorDeserializeTime.localValue
+ def executorDeserializeTime: Long = _executorDeserializeTime.sum
/**
* Time the executor spends actually running the task (including fetching shuffle data).
*/
- def executorRunTime: Long = _executorRunTime.localValue
+ def executorRunTime: Long = _executorRunTime.sum
/**
* The number of bytes this task transmitted back to the driver as the TaskResult.
*/
- def resultSize: Long = _resultSize.localValue
+ def resultSize: Long = _resultSize.sum
/**
* Amount of time the JVM spent in garbage collection while executing this task.
*/
- def jvmGCTime: Long = _jvmGCTime.localValue
+ def jvmGCTime: Long = _jvmGCTime.sum
/**
* Amount of time spent serializing the task result.
*/
- def resultSerializationTime: Long = _resultSerializationTime.localValue
+ def resultSerializationTime: Long = _resultSerializationTime.sum
/**
* The number of in-memory bytes spilled by this task.
*/
- def memoryBytesSpilled: Long = _memoryBytesSpilled.localValue
+ def memoryBytesSpilled: Long = _memoryBytesSpilled.sum
/**
* The number of on-disk bytes spilled by this task.
*/
- def diskBytesSpilled: Long = _diskBytesSpilled.localValue
+ def diskBytesSpilled: Long = _diskBytesSpilled.sum
/**
* Peak memory used by internal data structures created during shuffles, aggregations and
@@ -96,7 +93,7 @@ class TaskMetrics private[spark] () extends Serializable {
* across all such data structures created in this task. For SQL jobs, this only tracks all
* unsafe operators and ExternalSort.
*/
- def peakExecutionMemory: Long = _peakExecutionMemory.localValue
+ def peakExecutionMemory: Long = _peakExecutionMemory.sum
/**
* Storage statuses of any blocks that have been updated as a result of this task.
@@ -114,7 +111,7 @@ class TaskMetrics private[spark] () extends Serializable {
private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v)
private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v)
private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v)
- private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
+ private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit =
_updatedBlockStatuses.add(v)
private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
_updatedBlockStatuses.setValue(v)
@@ -175,124 +172,143 @@ class TaskMetrics private[spark] () extends Serializable {
}
// Only used for test
- private[spark] val testAccum =
- sys.props.get("spark.testing").map(_ => TaskMetrics.createLongAccum(TEST_ACCUM))
-
- @transient private[spark] lazy val internalAccums: Seq[Accumulable[_, _]] = {
- val in = inputMetrics
- val out = outputMetrics
- val sr = shuffleReadMetrics
- val sw = shuffleWriteMetrics
- Seq(_executorDeserializeTime, _executorRunTime, _resultSize, _jvmGCTime,
- _resultSerializationTime, _memoryBytesSpilled, _diskBytesSpilled, _peakExecutionMemory,
- _updatedBlockStatuses, sr._remoteBlocksFetched, sr._localBlocksFetched, sr._remoteBytesRead,
- sr._localBytesRead, sr._fetchWaitTime, sr._recordsRead, sw._bytesWritten, sw._recordsWritten,
- sw._writeTime, in._bytesRead, in._recordsRead, out._bytesWritten, out._recordsWritten) ++
- testAccum
- }
+ private[spark] val testAccum = sys.props.get("spark.testing").map(_ => new LongAccumulator)
+
+
+ import InternalAccumulator._
+ @transient private[spark] lazy val nameToAccums = LinkedHashMap(
+ EXECUTOR_DESERIALIZE_TIME -> _executorDeserializeTime,
+ EXECUTOR_RUN_TIME -> _executorRunTime,
+ RESULT_SIZE -> _resultSize,
+ JVM_GC_TIME -> _jvmGCTime,
+ RESULT_SERIALIZATION_TIME -> _resultSerializationTime,
+ MEMORY_BYTES_SPILLED -> _memoryBytesSpilled,
+ DISK_BYTES_SPILLED -> _diskBytesSpilled,
+ PEAK_EXECUTION_MEMORY -> _peakExecutionMemory,
+ UPDATED_BLOCK_STATUSES -> _updatedBlockStatuses,
+ shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched,
+ shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched,
+ shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead,
+ shuffleRead.LOCAL_BYTES_READ -> shuffleReadMetrics._localBytesRead,
+ shuffleRead.FETCH_WAIT_TIME -> shuffleReadMetrics._fetchWaitTime,
+ shuffleRead.RECORDS_READ -> shuffleReadMetrics._recordsRead,
+ shuffleWrite.BYTES_WRITTEN -> shuffleWriteMetrics._bytesWritten,
+ shuffleWrite.RECORDS_WRITTEN -> shuffleWriteMetrics._recordsWritten,
+ shuffleWrite.WRITE_TIME -> shuffleWriteMetrics._writeTime,
+ input.BYTES_READ -> inputMetrics._bytesRead,
+ input.RECORDS_READ -> inputMetrics._recordsRead,
+ output.BYTES_WRITTEN -> outputMetrics._bytesWritten,
+ output.RECORDS_WRITTEN -> outputMetrics._recordsWritten
+ ) ++ testAccum.map(TEST_ACCUM -> _)
+
+ @transient private[spark] lazy val internalAccums: Seq[NewAccumulator[_, _]] =
+ nameToAccums.values.toIndexedSeq
/* ========================== *
| OTHER THINGS |
* ========================== */
- private[spark] def registerForCleanup(sc: SparkContext): Unit = {
- internalAccums.foreach { accum =>
- sc.cleaner.foreach(_.registerAccumulatorForCleanup(accum))
+ private[spark] def register(sc: SparkContext): Unit = {
+ nameToAccums.foreach {
+ case (name, acc) => acc.register(sc, name = Some(name), countFailedValues = true)
}
}
/**
* External accumulators registered with this task.
*/
- @transient private lazy val externalAccums = new ArrayBuffer[Accumulable[_, _]]
+ @transient private lazy val externalAccums = new ArrayBuffer[NewAccumulator[_, _]]
- private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = {
+ private[spark] def registerAccumulator(a: NewAccumulator[_, _]): Unit = {
externalAccums += a
}
- /**
- * Return the latest updates of accumulators in this task.
- *
- * The [[AccumulableInfo.update]] field is always defined and the [[AccumulableInfo.value]]
- * field is always empty, since this represents the partial updates recorded in this task,
- * not the aggregated value across multiple tasks.
- */
- def accumulatorUpdates(): Seq[AccumulableInfo] = {
- (internalAccums ++ externalAccums).map { a => a.toInfo(Some(a.localValue), None) }
- }
+ private[spark] def accumulators(): Seq[NewAccumulator[_, _]] = internalAccums ++ externalAccums
}
-/**
- * Internal subclass of [[TaskMetrics]] which is used only for posting events to listeners.
- * Its purpose is to obviate the need for the driver to reconstruct the original accumulators,
- * which might have been garbage-collected. See SPARK-13407 for more details.
- *
- * Instances of this class should be considered read-only and users should not call `inc*()` or
- * `set*()` methods. While we could override the setter methods to throw
- * UnsupportedOperationException, we choose not to do so because the overrides would quickly become
- * out-of-date when new metrics are added.
- */
-private[spark] class ListenerTaskMetrics(accumUpdates: Seq[AccumulableInfo]) extends TaskMetrics {
-
- override def accumulatorUpdates(): Seq[AccumulableInfo] = accumUpdates
-
- override private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = {
- throw new UnsupportedOperationException("This TaskMetrics is read-only")
- }
-}
private[spark] object TaskMetrics extends Logging {
+ import InternalAccumulator._
/**
* Create an empty task metrics that doesn't register its accumulators.
*/
def empty: TaskMetrics = {
- val metrics = new TaskMetrics
- metrics.internalAccums.foreach(acc => Accumulators.remove(acc.id))
- metrics
+ val tm = new TaskMetrics
+ tm.nameToAccums.foreach { case (name, acc) =>
+ acc.metadata = AccumulatorMetadata(AccumulatorContext.newId(), Some(name), true)
+ }
+ tm
+ }
+
+ def registered: TaskMetrics = {
+ val tm = empty
+ tm.internalAccums.foreach(AccumulatorContext.register)
+ tm
}
/**
- * Create a new accumulator representing an internal task metric.
+ * Construct a [[TaskMetrics]] object from a list of [[AccumulableInfo]]; called on the driver
+ * only. The returned [[TaskMetrics]] is only used to read some internal metrics, so we don't
+ * need to handle any external accumulator info that is passed in.
*/
- private def newMetric[T](
- initialValue: T,
- name: String,
- param: AccumulatorParam[T]): Accumulator[T] = {
- new Accumulator[T](initialValue, param, Some(name), countFailedValues = true)
+ def fromAccumulatorInfos(infos: Seq[AccumulableInfo]): TaskMetrics = {
+ val tm = new TaskMetrics
+ infos.filter(info => info.name.isDefined && info.update.isDefined).foreach { info =>
+ val name = info.name.get
+ val value = info.update.get
+ if (name == UPDATED_BLOCK_STATUSES) {
+ tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, BlockStatus)]])
+ } else {
+ tm.nameToAccums.get(name).foreach(
+ _.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long])
+ )
+ }
+ }
+ tm
}
- def createLongAccum(name: String): Accumulator[Long] = {
- newMetric(0L, name, LongAccumulatorParam)
- }
+ /**
+ * Construct a [[TaskMetrics]] object from a list of accumulator updates; called on the driver only.
+ */
+ def fromAccumulators(accums: Seq[NewAccumulator[_, _]]): TaskMetrics = {
+ val tm = new TaskMetrics
+ val (internalAccums, externalAccums) =
+ accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get))
+
+ internalAccums.foreach { acc =>
+ val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[NewAccumulator[Any, Any]]
+ tmAcc.metadata = acc.metadata
+ tmAcc.merge(acc.asInstanceOf[NewAccumulator[Any, Any]])
+ }
- def createIntAccum(name: String): Accumulator[Int] = {
- newMetric(0, name, IntAccumulatorParam)
+ tm.externalAccums ++= externalAccums
+ tm
}
+}
+
+
+private[spark] class BlockStatusesAccumulator
+ extends NewAccumulator[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] {
+ private[this] var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)]
- def createBlocksAccum(name: String): Accumulator[Seq[(BlockId, BlockStatus)]] = {
- newMetric(Nil, name, UpdatedBlockStatusesAccumulatorParam)
+ override def isZero(): Boolean = _seq.isEmpty
+
+ override def copyAndReset(): BlockStatusesAccumulator = new BlockStatusesAccumulator
+
+ override def add(v: (BlockId, BlockStatus)): Unit = _seq += v
+
+ override def merge(other: NewAccumulator[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]])
+ : Unit = other match {
+ case o: BlockStatusesAccumulator => _seq ++= o.localValue
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
- /**
- * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only.
- *
- * Executors only send accumulator updates back to the driver, not [[TaskMetrics]]. However, we
- * need the latter to post task end events to listeners, so we need to reconstruct the metrics
- * on the driver.
- *
- * This assumes the provided updates contain the initial set of accumulators representing
- * internal task level metrics.
- */
- def fromAccumulatorUpdates(accumUpdates: Seq[AccumulableInfo]): TaskMetrics = {
- val definedAccumUpdates = accumUpdates.filter(_.update.isDefined)
- val metrics = new ListenerTaskMetrics(definedAccumUpdates)
- // We don't register this [[ListenerTaskMetrics]] for cleanup, and this is only used to post
- // event, we should un-register all accumulators immediately.
- metrics.internalAccums.foreach(acc => Accumulators.remove(acc.id))
- definedAccumUpdates.filter(_.internal).foreach { accum =>
- metrics.internalAccums.find(_.name == accum.name).foreach(_.setValueAny(accum.update.get))
- }
- metrics
+ override def localValue: Seq[(BlockId, BlockStatus)] = _seq
+
+ def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = {
+ _seq.clear()
+ _seq ++= newValue
}
}
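
BlockStatusesAccumulator above exercises the whole contract a collection-style accumulator has to satisfy under the new API: isZero, copyAndReset, add, merge and localValue. As a point of reference, a minimal sketch (not part of this patch) of another accumulator written against the same contract; the class name, the element type, and the assumption that these five members are the only abstract ones are illustrative, inferred from what BlockStatusesAccumulator overrides.

// Hedged sketch only -- assumes NewAccumulator lives in org.apache.spark as the
// file path above suggests, and that the overridden members below make up the
// full abstract contract.
private[spark] class StringSetAccumulator
  extends NewAccumulator[String, Set[String]] {
  private[this] val _set = scala.collection.mutable.HashSet.empty[String]

  // "Zero" means nothing has been added yet.
  override def isZero(): Boolean = _set.isEmpty

  // A fresh, empty copy is what gets serialized out to executors.
  override def copyAndReset(): StringSetAccumulator = new StringSetAccumulator

  override def add(v: String): Unit = _set += v

  override def merge(other: NewAccumulator[String, Set[String]]): Unit = other match {
    case o: StringSetAccumulator => _set ++= o.localValue
    case _ => throw new UnsupportedOperationException(
      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
  }

  override def localValue: Set[String] = _set.toSet
}
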
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b7fb608ea5..a96d5f6fbf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -209,7 +209,7 @@ class DAGScheduler(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Seq[AccumulableInfo],
+ accumUpdates: Seq[NewAccumulator[_, _]],
taskInfo: TaskInfo): Unit = {
eventProcessLoop.post(
CompletionEvent(task, reason, result, accumUpdates, taskInfo))
@@ -1088,21 +1088,19 @@ class DAGScheduler(
val task = event.task
val stage = stageIdToStage(task.stageId)
try {
- event.accumUpdates.foreach { ainfo =>
- assert(ainfo.update.isDefined, "accumulator from task should have a partial value")
- val id = ainfo.id
- val partialValue = ainfo.update.get
+ event.accumUpdates.foreach { updates =>
+ val id = updates.id
// Find the corresponding accumulator on the driver and update it
- val acc: Accumulable[Any, Any] = Accumulators.get(id) match {
- case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
+ val acc: NewAccumulator[Any, Any] = AccumulatorContext.get(id) match {
+ case Some(accum) => accum.asInstanceOf[NewAccumulator[Any, Any]]
case None =>
throw new SparkException(s"attempted to access non-existent accumulator $id")
}
- acc ++= partialValue
+ acc.merge(updates.asInstanceOf[NewAccumulator[Any, Any]])
// To avoid UI cruft, ignore cases where value wasn't updated
- if (acc.name.isDefined && partialValue != acc.zero) {
+ if (acc.name.isDefined && !updates.isZero()) {
stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value))
- event.taskInfo.accumulables += acc.toInfo(Some(partialValue), Some(acc.value))
+ event.taskInfo.accumulables += acc.toInfo(Some(updates.value), Some(acc.value))
}
}
} catch {
@@ -1131,7 +1129,7 @@ class DAGScheduler(
val taskMetrics: TaskMetrics =
if (event.accumUpdates.nonEmpty) {
try {
- TaskMetrics.fromAccumulatorUpdates(event.accumUpdates)
+ TaskMetrics.fromAccumulators(event.accumUpdates)
} catch {
case NonFatal(e) =>
logError(s"Error when attempting to reconstruct metrics for task $taskId", e)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index a3845c6acd..e57a2246d8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -71,7 +71,7 @@ private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Seq[AccumulableInfo],
+ accumUpdates: Seq[NewAccumulator[_, _]],
taskInfo: TaskInfo)
extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 080ea6c33a..7618dfeeed 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -21,18 +21,15 @@ import java.util.Properties
import javax.annotation.Nullable
import scala.collection.Map
-import scala.collection.mutable
import com.fasterxml.jackson.annotation.JsonTypeInfo
import org.apache.spark.{SparkConf, TaskEndReason}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{Distribution, Utils}
@DeveloperApi
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 02185bf631..2f972b064b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -112,7 +112,7 @@ private[scheduler] abstract class Stage(
numPartitionsToCompute: Int,
taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = {
val metrics = new TaskMetrics
- metrics.registerForCleanup(rdd.sparkContext)
+ metrics.register(rdd.sparkContext)
_latestInfo = StageInfo.fromStage(
this, nextAttemptId, Some(numPartitionsToCompute), metrics, taskLocalityPreferences)
nextAttemptId += 1
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index eb10f3e69b..e7ca6efd84 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.util.Properties
import scala.collection.mutable.HashMap
-import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
+import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
@@ -52,7 +52,7 @@ private[spark] abstract class Task[T](
val stageAttemptId: Int,
val partitionId: Int,
// The default value is only used in tests.
- val metrics: TaskMetrics = TaskMetrics.empty,
+ val metrics: TaskMetrics = TaskMetrics.registered,
@transient var localProperties: Properties = new Properties) extends Serializable {
/**
@@ -153,11 +153,11 @@ private[spark] abstract class Task[T](
* Collect the latest values of accumulators used in this task. If the task failed,
* filter out the accumulators whose values should not be included on failures.
*/
- def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulableInfo] = {
+ def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[NewAccumulator[_, _]] = {
if (context != null) {
- context.taskMetrics.accumulatorUpdates().filter { a => !taskFailed || a.countFailedValues }
+ context.taskMetrics.accumulators().filter { a => !taskFailed || a.countFailedValues }
} else {
- Seq.empty[AccumulableInfo]
+ Seq.empty
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 03135e63d7..b472c5511b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.SparkEnv
+import org.apache.spark.{NewAccumulator, SparkEnv}
import org.apache.spark.storage.BlockId
import org.apache.spark.util.Utils
@@ -36,7 +36,7 @@ private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark] class DirectTaskResult[T](
var valueBytes: ByteBuffer,
- var accumUpdates: Seq[AccumulableInfo])
+ var accumUpdates: Seq[NewAccumulator[_, _]])
extends TaskResult[T] with Externalizable {
private var valueObjectDeserialized = false
@@ -61,9 +61,9 @@ private[spark] class DirectTaskResult[T](
if (numUpdates == 0) {
accumUpdates = null
} else {
- val _accumUpdates = new ArrayBuffer[AccumulableInfo]
+ val _accumUpdates = new ArrayBuffer[NewAccumulator[_, _]]
for (i <- 0 until numUpdates) {
- _accumUpdates += in.readObject.asInstanceOf[AccumulableInfo]
+ _accumUpdates += in.readObject.asInstanceOf[NewAccumulator[_, _]]
}
accumUpdates = _accumUpdates
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index ae7ef46abb..b438c285fd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -93,9 +93,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
// we would have to serialize the result again after updating the size.
result.accumUpdates = result.accumUpdates.map { a =>
if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
- assert(a.update == Some(0L),
- "task result size should not have been set on the executors")
- a.copy(update = Some(size.toLong))
+ val acc = a.asInstanceOf[LongAccumulator]
+ assert(acc.sum == 0L, "task result size should not have been set on the executors")
+ acc.setValue(size.toLong)
+ acc
} else {
a
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 647d44a0f0..75a0c56311 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -17,6 +17,7 @@
package org.apache.spark.scheduler
+import org.apache.spark.NewAccumulator
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.BlockManagerId
@@ -66,7 +67,7 @@ private[spark] trait TaskScheduler {
*/
def executorHeartbeatReceived(
execId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
blockManagerId: BlockManagerId): Boolean
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index f31ec2af4e..776a3226cc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -384,13 +384,14 @@ private[spark] class TaskSchedulerImpl(
*/
override def executorHeartbeatReceived(
execId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
blockManagerId: BlockManagerId): Boolean = {
// (taskId, stageId, stageAttemptId, accumUpdates)
val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized {
accumUpdates.flatMap { case (id, updates) =>
taskIdToTaskSetManager.get(id).map { taskSetMgr =>
- (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, updates)
+ (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId,
+ updates.map(acc => acc.toInfo(Some(acc.value), None)))
}
}
}
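
The heartbeat path above still feeds AccumulableInfo to its downstream consumers, so live accumulators are converted on the fly via toInfo. A small hedged sketch of that conversion pulled out as a standalone helper; the helper name is hypothetical, the toInfo call is the one used in the hunk.

// Hypothetical helper mirroring the per-heartbeat conversion above.
def toInfos(updates: Seq[NewAccumulator[_, _]]): Seq[AccumulableInfo] =
  updates.map(acc => acc.toInfo(Some(acc.value), None))
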
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 6e08cdd87a..b79f643a74 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -647,7 +647,7 @@ private[spark] class TaskSetManager(
info.markFailed()
val index = info.index
copiesRunning(index) -= 1
- var accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]
+ var accumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty
val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " +
reason.asInstanceOf[TaskFailedReason].toErrorString
val failureException: Option[Throwable] = reason match {
@@ -663,7 +663,7 @@ private[spark] class TaskSetManager(
case ef: ExceptionFailure =>
// ExceptionFailure's might have accumulator updates
- accumUpdates = ef.accumUpdates
+ accumUpdates = ef.accums
if (ef.className == classOf[NotSerializableException].getName) {
// If the task result wasn't serializable, there's no point in trying to re-execute it.
logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying"
@@ -788,7 +788,7 @@ private[spark] class TaskSetManager(
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.dagScheduler.taskEnded(
- tasks(index), Resubmitted, null, Seq.empty[AccumulableInfo], info)
+ tasks(index), Resubmitted, null, Seq.empty, info)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
index 8daca6c390..c04b483831 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -266,7 +266,13 @@ private[spark] object SerializationDebugger extends Logging {
(o, desc)
} else {
// write place
- findObjectAndDescriptor(desc.invokeWriteReplace(o))
+ val replaced = desc.invokeWriteReplace(o)
+ // `writeReplace` may return the same object.
+ if (replaced eq o) {
+ (o, desc)
+ } else {
+ findObjectAndDescriptor(replaced)
+ }
}
}
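
The eq check added above guards against writeReplace implementations that hand back the receiver itself, which would otherwise send findObjectAndDescriptor into unbounded recursion. An illustrative class (not from the patch) of the kind of object the guard protects against:

// Illustrative only: returning `this` from writeReplace is legal, so the
// debugger must not recurse when the replacement is the same object.
class SelfReplacing extends Serializable {
  private def writeReplace(): AnyRef = this
}
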
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
index ff28796a60..32e332a9ad 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
@@ -186,8 +186,8 @@ class OutputMetrics private[spark](
val recordsWritten: Long)
class ShuffleReadMetrics private[spark](
- val remoteBlocksFetched: Int,
- val localBlocksFetched: Int,
+ val remoteBlocksFetched: Long,
+ val localBlocksFetched: Long,
val fetchWaitTime: Long,
val remoteBytesRead: Long,
val localBytesRead: Long,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1c4921666f..f2d06c7ea8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -801,7 +801,7 @@ private[spark] class BlockManager(
reportBlockStatus(blockId, info, putBlockStatus)
}
Option(TaskContext.get()).foreach { c =>
- c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus)))
+ c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus)
}
}
logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
@@ -958,7 +958,7 @@ private[spark] class BlockManager(
reportBlockStatus(blockId, info, putBlockStatus)
}
Option(TaskContext.get()).foreach { c =>
- c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus)))
+ c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus)
}
logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
if (level.replication > 1) {
@@ -1257,7 +1257,7 @@ private[spark] class BlockManager(
}
if (blockIsUpdated) {
Option(TaskContext.get()).foreach { c =>
- c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, status)))
+ c.taskMetrics().incUpdatedBlockStatuses(blockId -> status)
}
}
status.storageLevel
@@ -1311,7 +1311,7 @@ private[spark] class BlockManager(
reportBlockStatus(blockId, info, removeBlockStatus)
}
Option(TaskContext.get()).foreach { c =>
- c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, removeBlockStatus)))
+ c.taskMetrics().incUpdatedBlockStatuses(blockId -> removeBlockStatus)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 9ab7d96e29..945830c8bf 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -375,26 +375,21 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
execSummary.taskTime += info.duration
stageData.numActiveTasks -= 1
- val (errorMessage, accums): (Option[String], Seq[AccumulableInfo]) =
+ val errorMessage: Option[String] =
taskEnd.reason match {
case org.apache.spark.Success =>
stageData.completedIndices.add(info.index)
stageData.numCompleteTasks += 1
- (None, taskEnd.taskMetrics.accumulatorUpdates())
+ None
case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates
stageData.numFailedTasks += 1
- (Some(e.toErrorString), e.accumUpdates)
+ Some(e.toErrorString)
case e: TaskFailedReason => // All other failure cases
stageData.numFailedTasks += 1
- (Some(e.toErrorString), Seq.empty[AccumulableInfo])
+ Some(e.toErrorString)
}
- val taskMetrics =
- if (accums.nonEmpty) {
- Some(TaskMetrics.fromAccumulatorUpdates(accums))
- } else {
- None
- }
+ val taskMetrics = Option(taskEnd.taskMetrics)
taskMetrics.foreach { m =>
val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.metrics)
updateAggregateMetrics(stageData, info.executorId, m, oldMetrics)
@@ -503,7 +498,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
new StageUIData
})
val taskData = stageData.taskData.get(taskId)
- val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates)
+ val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates)
taskData.foreach { t =>
if (!t.taskInfo.finished) {
updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics)
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index a613fbc5cc..aeab71d9df 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -840,7 +840,9 @@ private[spark] object JsonProtocol {
// Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x
val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates")
.map(_.extract[List[JValue]].map(accumulableInfoFromJson))
- .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulatorUpdates())
+ .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators().map(acc => {
+ acc.toInfo(Some(acc.localValue), None)
+ }))
ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates)
case `taskResultLost` => TaskResultLost
case `taskKilled` => TaskKilled
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 6063476936..5f97e58845 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -28,17 +28,17 @@ import scala.util.control.NonFatal
import org.scalatest.Matchers
import org.scalatest.exceptions.TestFailedException
-import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.AccumulatorParam.{ListAccumulatorParam, StringAccumulatorParam}
import org.apache.spark.scheduler._
import org.apache.spark.serializer.JavaSerializer
class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
- import AccumulatorParam._
+ import AccumulatorSuite.createLongAccum
override def afterEach(): Unit = {
try {
- Accumulators.clear()
+ AccumulatorContext.clear()
} finally {
super.afterEach()
}
@@ -59,9 +59,30 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
}
}
+ test("accumulator serialization") {
+ val ser = new JavaSerializer(new SparkConf).newInstance()
+ val acc = createLongAccum("x")
+ acc.add(5)
+ assert(acc.value == 5)
+ assert(acc.isAtDriverSide)
+
+ // serialize and de-serialize it, to simulate sending the accumulator to an executor.
+ val acc2 = ser.deserialize[LongAccumulator](ser.serialize(acc))
+ // value is reset on the executors
+ assert(acc2.localValue == 0)
+ assert(!acc2.isAtDriverSide)
+
+ acc2.add(10)
+ // serialize and de-serialize it again, to simulate sending the accumulator back to the driver.
+ val acc3 = ser.deserialize[LongAccumulator](ser.serialize(acc2))
+ // value is not reset on the driver
+ assert(acc3.value == 10)
+ assert(acc3.isAtDriverSide)
+ }
+
test ("basic accumulation") {
sc = new SparkContext("local", "test")
- val acc : Accumulator[Int] = sc.accumulator(0)
+ val acc: Accumulator[Int] = sc.accumulator(0)
val d = sc.parallelize(1 to 20)
d.foreach{x => acc += x}
@@ -75,7 +96,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
test("value not assignable from tasks") {
sc = new SparkContext("local", "test")
- val acc : Accumulator[Int] = sc.accumulator(0)
+ val acc: Accumulator[Int] = sc.accumulator(0)
val d = sc.parallelize(1 to 20)
an [Exception] should be thrownBy {d.foreach{x => acc.value = x}}
@@ -169,14 +190,13 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
System.gc()
assert(ref.get.isEmpty)
- Accumulators.remove(accId)
- assert(!Accumulators.originals.get(accId).isDefined)
+ AccumulatorContext.remove(accId)
+ assert(!AccumulatorContext.originals.containsKey(accId))
}
test("get accum") {
- sc = new SparkContext("local", "test")
// Don't register with SparkContext for cleanup
- var acc = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true)
+ var acc = createLongAccum("a")
val accId = acc.id
val ref = WeakReference(acc)
assert(ref.get.isDefined)
@@ -188,44 +208,16 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
// Getting a garbage collected accum should throw error
intercept[IllegalAccessError] {
- Accumulators.get(accId)
+ AccumulatorContext.get(accId)
}
// Getting a normal accumulator. Note: this has to be separate because referencing an
// accumulator above in an `assert` would keep it from being garbage collected.
- val acc2 = new Accumulable[Long, Long](0L, LongAccumulatorParam, None, true)
- assert(Accumulators.get(acc2.id) === Some(acc2))
+ val acc2 = createLongAccum("b")
+ assert(AccumulatorContext.get(acc2.id) === Some(acc2))
// Getting an accumulator that does not exist should return None
- assert(Accumulators.get(100000).isEmpty)
- }
-
- test("copy") {
- val acc1 = new Accumulable[Long, Long](456L, LongAccumulatorParam, Some("x"), false)
- val acc2 = acc1.copy()
- assert(acc1.id === acc2.id)
- assert(acc1.value === acc2.value)
- assert(acc1.name === acc2.name)
- assert(acc1.countFailedValues === acc2.countFailedValues)
- assert(acc1 !== acc2)
- // Modifying one does not affect the other
- acc1.add(44L)
- assert(acc1.value === 500L)
- assert(acc2.value === 456L)
- acc2.add(144L)
- assert(acc1.value === 500L)
- assert(acc2.value === 600L)
- }
-
- test("register multiple accums with same ID") {
- val acc1 = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true)
- // `copy` will create a new Accumulable and register it.
- val acc2 = acc1.copy()
- assert(acc1 !== acc2)
- assert(acc1.id === acc2.id)
- // The second one does not override the first one
- assert(Accumulators.originals.size === 1)
- assert(Accumulators.get(acc1.id) === Some(acc1))
+ assert(AccumulatorContext.get(100000).isEmpty)
}
test("string accumulator param") {
@@ -257,38 +249,33 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
acc.setValue(Seq(9, 10))
assert(acc.value === Seq(9, 10))
}
-
- test("value is reset on the executors") {
- val acc1 = new Accumulator(0, IntAccumulatorParam, Some("thing"))
- val acc2 = new Accumulator(0L, LongAccumulatorParam, Some("thing2"))
- val externalAccums = Seq(acc1, acc2)
- val taskMetrics = new TaskMetrics
- // Set some values; these should not be observed later on the "executors"
- acc1.setValue(10)
- acc2.setValue(20L)
- taskMetrics.testAccum.get.setValue(30L)
- // Simulate the task being serialized and sent to the executors.
- val dummyTask = new DummyTask(taskMetrics, externalAccums)
- val serInstance = new JavaSerializer(new SparkConf).newInstance()
- val taskSer = Task.serializeWithDependencies(
- dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance)
- // Now we're on the executors.
- // Deserialize the task and assert that its accumulators are zero'ed out.
- val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
- val taskDeser = serInstance.deserialize[DummyTask](
- taskBytes, Thread.currentThread.getContextClassLoader)
- // Assert that executors see only zeros
- taskDeser.externalAccums.foreach { a => assert(a.localValue == a.zero) }
- taskDeser.metrics.internalAccums.foreach { a => assert(a.localValue == a.zero) }
- }
-
}
private[spark] object AccumulatorSuite {
-
import InternalAccumulator._
/**
+ * Create a long accumulator and register it with [[AccumulatorContext]].
+ */
+ def createLongAccum(
+ name: String,
+ countFailedValues: Boolean = false,
+ initValue: Long = 0,
+ id: Long = AccumulatorContext.newId()): LongAccumulator = {
+ val acc = new LongAccumulator
+ acc.setValue(initValue)
+ acc.metadata = AccumulatorMetadata(id, Some(name), countFailedValues)
+ AccumulatorContext.register(acc)
+ acc
+ }
+
+ /**
+ * Make an [[AccumulableInfo]] out of a [[NewAccumulator]] with the intent to use the
+ * info as an accumulator update.
+ */
+ def makeInfo(a: NewAccumulator[_, _]): AccumulableInfo = a.toInfo(Some(a.localValue), None)
+
+ /**
* Run one or more Spark jobs and verify that in at least one job the peak execution memory
* accumulator is updated afterwards.
*/
@@ -340,7 +327,6 @@ private class SaveInfoListener extends SparkListener {
if (jobCompletionCallback != null) {
jobCompletionSem.acquire()
if (exception != null) {
- exception = null
throw exception
}
}
@@ -377,13 +363,3 @@ private class SaveInfoListener extends SparkListener {
(taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo
}
}
-
-
-/**
- * A dummy [[Task]] that contains internal and external [[Accumulator]]s.
- */
-private[spark] class DummyTask(
- metrics: TaskMetrics,
- val externalAccums: Seq[Accumulator[_]]) extends Task[Int](0, 0, 0, metrics) {
- override def runTask(c: TaskContext): Int = 1
-}
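
Taken together, createLongAccum and the new serialization test above describe the intended lifecycle: a registered accumulator keeps its value on the driver, is reset when shipped to an executor, and keeps the executor-side value when shipped back. A hedged usage sketch assuming a JavaSerializer as in the test and the org.apache.spark package locations implied by the file paths in this patch:

// Sketch only -- mirrors the "accumulator serialization" test added above.
val acc = AccumulatorSuite.createLongAccum("records")   // registered, driver side
acc.add(5)
assert(acc.isAtDriverSide && acc.value == 5)

val ser = new JavaSerializer(new SparkConf).newInstance()
val onExecutor = ser.deserialize[LongAccumulator](ser.serialize(acc))
assert(!onExecutor.isAtDriverSide)
assert(onExecutor.localValue == 0)                      // reset for the executor side

onExecutor.add(10)
val backOnDriver = ser.deserialize[LongAccumulator](ser.serialize(onExecutor))
assert(backOnDriver.isAtDriverSide && backOnDriver.value == 10)
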
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index 4d2b3e7f3b..1adc90ab1e 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -211,10 +211,10 @@ class HeartbeatReceiverSuite
private def triggerHeartbeat(
executorId: String,
executorShouldReregister: Boolean): Unit = {
- val metrics = new TaskMetrics
+ val metrics = TaskMetrics.empty
val blockManagerId = BlockManagerId(executorId, "localhost", 12345)
val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
- Heartbeat(executorId, Array(1L -> metrics.accumulatorUpdates()), blockManagerId))
+ Heartbeat(executorId, Array(1L -> metrics.accumulators()), blockManagerId))
if (executorShouldReregister) {
assert(response.reregisterBlockManager)
} else {
@@ -222,7 +222,7 @@ class HeartbeatReceiverSuite
// Additionally verify that the scheduler callback is called with the correct parameters
verify(scheduler).executorHeartbeatReceived(
Matchers.eq(executorId),
- Matchers.eq(Array(1L -> metrics.accumulatorUpdates())),
+ Matchers.eq(Array(1L -> metrics.accumulators())),
Matchers.eq(blockManagerId))
}
}
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
index b074b95424..e4474bb813 100644
--- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.executor.TaskMetrics
@@ -29,7 +30,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
override def afterEach(): Unit = {
try {
- Accumulators.clear()
+ AccumulatorContext.clear()
} finally {
super.afterEach()
}
@@ -37,9 +38,8 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
test("internal accumulators in TaskContext") {
val taskContext = TaskContext.empty()
- val accumUpdates = taskContext.taskMetrics.accumulatorUpdates()
+ val accumUpdates = taskContext.taskMetrics.accumulators()
assert(accumUpdates.size > 0)
- assert(accumUpdates.forall(_.internal))
val testAccum = taskContext.taskMetrics.testAccum.get
assert(accumUpdates.exists(_.id == testAccum.id))
}
@@ -51,7 +51,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
sc.addSparkListener(listener)
// Have each task add 1 to the internal accumulator
val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
- TaskContext.get().taskMetrics().testAccum.get += 1
+ TaskContext.get().taskMetrics().testAccum.get.add(1)
iter
}
// Register asserts in job completion callback to avoid flakiness
@@ -87,17 +87,17 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
val rdd = sc.parallelize(1 to 100, numPartitions)
.map { i => (i, i) }
.mapPartitions { iter =>
- TaskContext.get().taskMetrics().testAccum.get += 1
+ TaskContext.get().taskMetrics().testAccum.get.add(1)
iter
}
.reduceByKey { case (x, y) => x + y }
.mapPartitions { iter =>
- TaskContext.get().taskMetrics().testAccum.get += 10
+ TaskContext.get().taskMetrics().testAccum.get.add(10)
iter
}
.repartition(numPartitions * 2)
.mapPartitions { iter =>
- TaskContext.get().taskMetrics().testAccum.get += 100
+ TaskContext.get().taskMetrics().testAccum.get.add(100)
iter
}
// Register asserts in job completion callback to avoid flakiness
@@ -127,7 +127,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
// This should retry both stages in the scheduler. Note that we only want to fail the
// first stage attempt because we want the stage to eventually succeed.
val x = sc.parallelize(1 to 100, numPartitions)
- .mapPartitions { iter => TaskContext.get().taskMetrics().testAccum.get += 1; iter }
+ .mapPartitions { iter => TaskContext.get().taskMetrics().testAccum.get.add(1); iter }
.groupBy(identity)
val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId
val rdd = x.mapPartitionsWithIndex { case (i, iter) =>
@@ -183,18 +183,18 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
private val myCleaner = new SaveAccumContextCleaner(this)
override def cleaner: Option[ContextCleaner] = Some(myCleaner)
}
- assert(Accumulators.originals.isEmpty)
+ assert(AccumulatorContext.originals.isEmpty)
sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count()
val numInternalAccums = TaskMetrics.empty.internalAccums.length
// We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage
- assert(Accumulators.originals.size === numInternalAccums * 2)
+ assert(AccumulatorContext.originals.size === numInternalAccums * 2)
val accumsRegistered = sc.cleaner match {
case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup
case _ => Seq.empty[Long]
}
// Make sure the same set of accumulators is registered for cleanup
assert(accumsRegistered.size === numInternalAccums * 2)
- assert(accumsRegistered.toSet === Accumulators.originals.keys.toSet)
+ assert(accumsRegistered.toSet === AccumulatorContext.originals.keySet().asScala)
}
/**
@@ -212,7 +212,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) {
private val accumsRegistered = new ArrayBuffer[Long]
- override def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+ override def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = {
accumsRegistered += a.id
super.registerAccumulatorForCleanup(a)
}
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index 3228752b96..4aae2c9b4a 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -34,7 +34,7 @@ private[spark] abstract class SparkFunSuite
protected override def afterAll(): Unit = {
try {
// Avoid leaking map entries in tests that use accumulators without SparkContext
- Accumulators.clear()
+ AccumulatorContext.clear()
} finally {
super.afterAll()
}
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
index ee70419727..94f6e1a3a7 100644
--- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -20,14 +20,11 @@ package org.apache.spark.executor
import org.scalatest.Assertions
import org.apache.spark._
-import org.apache.spark.scheduler.AccumulableInfo
-import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel, TestBlockId}
+import org.apache.spark.storage.{BlockStatus, StorageLevel, TestBlockId}
class TaskMetricsSuite extends SparkFunSuite {
- import AccumulatorParam._
import StorageLevel._
- import TaskMetricsSuite._
test("mutating values") {
val tm = new TaskMetrics
@@ -59,8 +56,8 @@ class TaskMetricsSuite extends SparkFunSuite {
tm.incPeakExecutionMemory(8L)
val block1 = (TestBlockId("a"), BlockStatus(MEMORY_ONLY, 1L, 2L))
val block2 = (TestBlockId("b"), BlockStatus(MEMORY_ONLY, 3L, 4L))
- tm.incUpdatedBlockStatuses(Seq(block1))
- tm.incUpdatedBlockStatuses(Seq(block2))
+ tm.incUpdatedBlockStatuses(block1)
+ tm.incUpdatedBlockStatuses(block2)
// assert new values exist
assert(tm.executorDeserializeTime == 1L)
assert(tm.executorRunTime == 2L)
@@ -194,18 +191,19 @@ class TaskMetricsSuite extends SparkFunSuite {
}
test("additional accumulables") {
- val tm = new TaskMetrics
- val acc1 = new Accumulator(0, IntAccumulatorParam, Some("a"))
- val acc2 = new Accumulator(0, IntAccumulatorParam, Some("b"))
- val acc3 = new Accumulator(0, IntAccumulatorParam, Some("c"))
- val acc4 = new Accumulator(0, IntAccumulatorParam, Some("d"), countFailedValues = true)
+ val tm = TaskMetrics.empty
+ val acc1 = AccumulatorSuite.createLongAccum("a")
+ val acc2 = AccumulatorSuite.createLongAccum("b")
+ val acc3 = AccumulatorSuite.createLongAccum("c")
+ val acc4 = AccumulatorSuite.createLongAccum("d", true)
tm.registerAccumulator(acc1)
tm.registerAccumulator(acc2)
tm.registerAccumulator(acc3)
tm.registerAccumulator(acc4)
- acc1 += 1
- acc2 += 2
- val newUpdates = tm.accumulatorUpdates().map { a => (a.id, a) }.toMap
+ acc1.add(1)
+ acc2.add(2)
+ val newUpdates = tm.accumulators()
+ .map(a => (a.id, a.asInstanceOf[NewAccumulator[Any, Any]])).toMap
assert(newUpdates.contains(acc1.id))
assert(newUpdates.contains(acc2.id))
assert(newUpdates.contains(acc3.id))
@@ -214,46 +212,14 @@ class TaskMetricsSuite extends SparkFunSuite {
assert(newUpdates(acc2.id).name === Some("b"))
assert(newUpdates(acc3.id).name === Some("c"))
assert(newUpdates(acc4.id).name === Some("d"))
- assert(newUpdates(acc1.id).update === Some(1))
- assert(newUpdates(acc2.id).update === Some(2))
- assert(newUpdates(acc3.id).update === Some(0))
- assert(newUpdates(acc4.id).update === Some(0))
+ assert(newUpdates(acc1.id).value === 1)
+ assert(newUpdates(acc2.id).value === 2)
+ assert(newUpdates(acc3.id).value === 0)
+ assert(newUpdates(acc4.id).value === 0)
assert(!newUpdates(acc3.id).countFailedValues)
assert(newUpdates(acc4.id).countFailedValues)
- assert(newUpdates.values.map(_.update).forall(_.isDefined))
- assert(newUpdates.values.map(_.value).forall(_.isEmpty))
assert(newUpdates.size === tm.internalAccums.size + 4)
}
-
- test("from accumulator updates") {
- val accumUpdates1 = TaskMetrics.empty.internalAccums.map { a =>
- AccumulableInfo(a.id, a.name, Some(3L), None, true, a.countFailedValues)
- }
- val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1)
- assertUpdatesEquals(metrics1.accumulatorUpdates(), accumUpdates1)
- // Test this with additional accumulators to ensure that we do not crash when handling
- // updates from unregistered accumulators. In practice, all accumulators created
- // on the driver, internal or not, should be registered with `Accumulators` at some point.
- val param = IntAccumulatorParam
- val registeredAccums = Seq(
- new Accumulator(0, param, Some("a"), countFailedValues = true),
- new Accumulator(0, param, Some("b"), countFailedValues = false))
- val unregisteredAccums = Seq(
- new Accumulator(0, param, Some("c"), countFailedValues = true),
- new Accumulator(0, param, Some("d"), countFailedValues = false))
- registeredAccums.foreach(Accumulators.register)
- registeredAccums.foreach(a => assert(Accumulators.originals.contains(a.id)))
- unregisteredAccums.foreach(a => Accumulators.remove(a.id))
- unregisteredAccums.foreach(a => assert(!Accumulators.originals.contains(a.id)))
- // set some values in these accums
- registeredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) }
- unregisteredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) }
- val registeredAccumInfos = registeredAccums.map(makeInfo)
- val unregisteredAccumInfos = unregisteredAccums.map(makeInfo)
- val accumUpdates2 = accumUpdates1 ++ registeredAccumInfos ++ unregisteredAccumInfos
- // Simply checking that this does not crash:
- TaskMetrics.fromAccumulatorUpdates(accumUpdates2)
- }
}
@@ -264,21 +230,14 @@ private[spark] object TaskMetricsSuite extends Assertions {
* Note: this does NOT check accumulator ID equality.
*/
def assertUpdatesEquals(
- updates1: Seq[AccumulableInfo],
- updates2: Seq[AccumulableInfo]): Unit = {
+ updates1: Seq[NewAccumulator[_, _]],
+ updates2: Seq[NewAccumulator[_, _]]): Unit = {
assert(updates1.size === updates2.size)
- updates1.zip(updates2).foreach { case (info1, info2) =>
+ updates1.zip(updates2).foreach { case (acc1, acc2) =>
// do not assert ID equals here
- assert(info1.name === info2.name)
- assert(info1.update === info2.update)
- assert(info1.value === info2.value)
- assert(info1.countFailedValues === info2.countFailedValues)
+ assert(acc1.name === acc2.name)
+ assert(acc1.countFailedValues === acc2.countFailedValues)
+ assert(acc1.value == acc2.value)
}
}
-
- /**
- * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the
- * info as an accumulator update.
- */
- def makeInfo(a: Accumulable[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index b76c0a4bd1..9912d1f3bc 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -112,7 +112,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
override def stop() = {}
override def executorHeartbeatReceived(
execId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
blockManagerId: BlockManagerId): Boolean = true
override def submitTasks(taskSet: TaskSet) = {
// normally done by TaskSetManager
@@ -277,8 +277,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
taskSet.tasks(i),
result._1,
result._2,
- Seq(new AccumulableInfo(
- accumId, Some(""), Some(1), None, internal = false, countFailedValues = false))))
+ Seq(AccumulatorSuite.createLongAccum("", initValue = 1, id = accumId))))
}
}
}
@@ -484,7 +483,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
override def defaultParallelism(): Int = 2
override def executorHeartbeatReceived(
execId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
blockManagerId: BlockManagerId): Boolean = true
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
override def applicationAttemptId(): Option[String] = None
@@ -997,10 +996,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
// complete two tasks
runEvent(makeCompletionEvent(
taskSets(0).tasks(0), Success, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(0)))
+ Seq.empty, createFakeTaskInfoWithId(0)))
runEvent(makeCompletionEvent(
taskSets(0).tasks(1), Success, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(1)))
+ Seq.empty, createFakeTaskInfoWithId(1)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
// verify stage exists
assert(scheduler.stageIdToStage.contains(0))
@@ -1009,10 +1008,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
// finish other 2 tasks
runEvent(makeCompletionEvent(
taskSets(0).tasks(2), Success, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(2)))
+ Seq.empty, createFakeTaskInfoWithId(2)))
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), Success, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(3)))
+ Seq.empty, createFakeTaskInfoWithId(3)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(sparkListener.endedTasks.size == 4)
@@ -1023,14 +1022,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
// with a speculative task and make sure the event is sent out
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), Success, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(5)))
+ Seq.empty, createFakeTaskInfoWithId(5)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(sparkListener.endedTasks.size == 5)
// make sure non successful tasks also send out event
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), UnknownReason, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(6)))
+ Seq.empty, createFakeTaskInfoWithId(6)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(sparkListener.endedTasks.size == 6)
}
@@ -1613,37 +1612,43 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
test("accumulator not calculated for resubmitted result stage") {
// just for register
- val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)
+ val accum = AccumulatorSuite.createLongAccum("a")
val finalRdd = new MyRDD(sc, 1, Nil)
submit(finalRdd, Array(0))
completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42)))
completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42)))
assert(results === Map(0 -> 42))
- val accVal = Accumulators.originals(accum.id).get.get.value
-
- assert(accVal === 1)
-
+ assert(accum.value === 1)
assertDataStructuresEmpty()
}
test("accumulators are updated on exception failures") {
- val acc1 = sc.accumulator(0L, "ingenieur")
- val acc2 = sc.accumulator(0L, "boulanger")
- val acc3 = sc.accumulator(0L, "agriculteur")
- assert(Accumulators.get(acc1.id).isDefined)
- assert(Accumulators.get(acc2.id).isDefined)
- assert(Accumulators.get(acc3.id).isDefined)
- val accInfo1 = acc1.toInfo(Some(15L), None)
- val accInfo2 = acc2.toInfo(Some(13L), None)
- val accInfo3 = acc3.toInfo(Some(18L), None)
- val accumUpdates = Seq(accInfo1, accInfo2, accInfo3)
- val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates)
+ val acc1 = AccumulatorSuite.createLongAccum("ingenieur")
+ val acc2 = AccumulatorSuite.createLongAccum("boulanger")
+ val acc3 = AccumulatorSuite.createLongAccum("agriculteur")
+ assert(AccumulatorContext.get(acc1.id).isDefined)
+ assert(AccumulatorContext.get(acc2.id).isDefined)
+ assert(AccumulatorContext.get(acc3.id).isDefined)
+ val accUpdate1 = new LongAccumulator
+ accUpdate1.metadata = acc1.metadata
+ accUpdate1.setValue(15)
+ val accUpdate2 = new LongAccumulator
+ accUpdate2.metadata = acc2.metadata
+ accUpdate2.setValue(13)
+ val accUpdate3 = new LongAccumulator
+ accUpdate3.metadata = acc3.metadata
+ accUpdate3.setValue(18)
+ val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3)
+ val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo)
+ val exceptionFailure = new ExceptionFailure(
+ new SparkException("fondue?"),
+ accumInfo).copy(accums = accumUpdates)
submit(new MyRDD(sc, 1, Nil), Array(0))
runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result"))
- assert(Accumulators.get(acc1.id).get.value === 15L)
- assert(Accumulators.get(acc2.id).get.value === 13L)
- assert(Accumulators.get(acc3.id).get.value === 18L)
+ assert(AccumulatorContext.get(acc1.id).get.value === 15L)
+ assert(AccumulatorContext.get(acc2.id).get.value === 13L)
+ assert(AccumulatorContext.get(acc3.id).get.value === 18L)
}
test("reduce tasks should be placed locally with map output") {
@@ -2007,12 +2012,12 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
task: Task[_],
reason: TaskEndReason,
result: Any,
- extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo],
+ extraAccumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty,
taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = {
val accumUpdates = reason match {
- case Success => task.metrics.accumulatorUpdates()
- case ef: ExceptionFailure => ef.accumUpdates
- case _ => Seq.empty[AccumulableInfo]
+ case Success => task.metrics.accumulators()
+ case ef: ExceptionFailure => ef.accums
+ case _ => Seq.empty
}
CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index 9971d48a52..16027d944f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.scheduler
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{LocalSparkContext, NewAccumulator, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.BlockManagerId
-class ExternalClusterManagerSuite extends SparkFunSuite with LocalSparkContext
-{
+class ExternalClusterManagerSuite extends SparkFunSuite with LocalSparkContext {
test("launch of backend and scheduler") {
val conf = new SparkConf().setMaster("myclusterManager").
setAppName("testcm").set("spark.driver.allowMultipleContexts", "true")
@@ -68,6 +67,6 @@ private class DummyTaskScheduler extends TaskScheduler {
override def applicationAttemptId(): Option[String] = None
def executorHeartbeatReceived(
execId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
blockManagerId: BlockManagerId): Boolean = true
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index d55f6f60ec..9aca4dbc23 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -162,18 +162,17 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}.count()
// The one that counts failed values should be 4x the one that didn't,
// since we ran each task 4 times
- assert(Accumulators.get(acc1.id).get.value === 40L)
- assert(Accumulators.get(acc2.id).get.value === 10L)
+ assert(AccumulatorContext.get(acc1.id).get.value === 40L)
+ assert(AccumulatorContext.get(acc2.id).get.value === 10L)
}
test("failed tasks collect only accumulators whose values count during failures") {
sc = new SparkContext("local", "test")
- val param = AccumulatorParam.LongAccumulatorParam
- val acc1 = new Accumulator(0L, param, Some("x"), countFailedValues = true)
- val acc2 = new Accumulator(0L, param, Some("y"), countFailedValues = false)
+ val acc1 = AccumulatorSuite.createLongAccum("x", true)
+ val acc2 = AccumulatorSuite.createLongAccum("y", false)
// Create a dummy task. We won't end up running this; we just want to collect
// accumulator updates from it.
- val taskMetrics = new TaskMetrics
+ val taskMetrics = TaskMetrics.empty
val task = new Task[Int](0, 0, 0) {
context = new TaskContextImpl(0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
@@ -186,12 +185,11 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
// First, simulate task success. This should give us all the accumulators.
val accumUpdates1 = task.collectAccumulatorUpdates(taskFailed = false)
- val accumUpdates2 = (taskMetrics.internalAccums ++ Seq(acc1, acc2))
- .map(TaskMetricsSuite.makeInfo)
+ val accumUpdates2 = taskMetrics.internalAccums ++ Seq(acc1, acc2)
TaskMetricsSuite.assertUpdatesEquals(accumUpdates1, accumUpdates2)
// Now, simulate task failures. This should give us only the accums that count failed values.
val accumUpdates3 = task.collectAccumulatorUpdates(taskFailed = true)
- val accumUpdates4 = (taskMetrics.internalAccums ++ Seq(acc1)).map(TaskMetricsSuite.makeInfo)
+ val accumUpdates4 = taskMetrics.internalAccums ++ Seq(acc1)
TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4)
}
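Editor's sketch: the createLongAccum helper used above is not part of this hunk, so here is a minimal registration path assembled only from calls that appear elsewhere in this patch (register(...) as in SQLMetrics, setValue as in SQLListenerSuite, AccumulatorContext.get as above); defaults and package placement are assumptions.

  val acc = new LongAccumulator                                  // replaces new Accumulator(0L, param, ...)
  acc.register(sc, name = Some("x"), countFailedValues = true)   // sc: SparkContext; registration and cleanup in one call
  acc.setValue(10L)                                              // executor-side copies report back via merge()
  assert(AccumulatorContext.get(acc.id).isDefined)               // driver-side registry, formerly Accumulators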
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index b5385c11a9..9e472f900b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -241,8 +241,8 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
assert(resultGetter.taskResults.size === 1)
val resBefore = resultGetter.taskResults.head
val resAfter = captor.getValue
- val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
- val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
+ val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
+ val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
assert(resSizeBefore.exists(_ == 0L))
assert(resSizeAfter.exists(_.toString.toLong > 0L))
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ecf4b76da5..339fc4254d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -37,7 +37,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Seq[AccumulableInfo],
+ accumUpdates: Seq[NewAccumulator[_, _]],
taskInfo: TaskInfo) {
taskScheduler.endedTasks(taskInfo.index) = reason
}
@@ -166,8 +166,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val taskSet = FakeTask.createTaskSet(1)
val clock = new ManualClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
- val accumUpdates =
- taskSet.tasks.head.metrics.internalAccums.map { a => a.toInfo(Some(0L), None) }
+ val accumUpdates = taskSet.tasks.head.metrics.internalAccums
// Offer a host with NO_PREF as the constraint,
// we should get a nopref task immediately since that's what we only have
@@ -185,8 +184,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
val taskSet = FakeTask.createTaskSet(3)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
- val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task =>
- task.metrics.internalAccums.map { a => a.toInfo(Some(0L), None) }
+ val accumUpdatesByTask: Array[Seq[NewAccumulator[_, _]]] = taskSet.tasks.map { task =>
+ task.metrics.internalAccums
}
// First three offers should all find tasks
@@ -792,7 +791,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
private def createTaskResult(
id: Int,
- accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]): DirectTaskResult[Int] = {
+ accumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty): DirectTaskResult[Int] = {
val valueSer = SparkEnv.get.serializer.newInstance()
new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates)
}
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 221124829f..ce7d51d1c3 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -183,7 +183,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
test("test executor id to summary") {
val conf = new SparkConf()
val listener = new JobProgressListener(conf)
- val taskMetrics = new TaskMetrics()
+ val taskMetrics = TaskMetrics.empty
val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
assert(listener.stageIdToData.size === 0)
@@ -230,7 +230,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
test("test task success vs failure counting for different task end reasons") {
val conf = new SparkConf()
val listener = new JobProgressListener(conf)
- val metrics = new TaskMetrics()
+ val metrics = TaskMetrics.empty
val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
val task = new ShuffleMapTask(0)
@@ -269,7 +269,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
val execId = "exe-1"
def makeTaskMetrics(base: Int): TaskMetrics = {
- val taskMetrics = new TaskMetrics
+ val taskMetrics = TaskMetrics.empty
val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics
val inputMetrics = taskMetrics.inputMetrics
@@ -300,9 +300,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L)))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array(
- (1234L, 0, 0, makeTaskMetrics(0).accumulatorUpdates()),
- (1235L, 0, 0, makeTaskMetrics(100).accumulatorUpdates()),
- (1236L, 1, 0, makeTaskMetrics(200).accumulatorUpdates()))))
+ (1234L, 0, 0, makeTaskMetrics(0).accumulators().map(AccumulatorSuite.makeInfo)),
+ (1235L, 0, 0, makeTaskMetrics(100).accumulators().map(AccumulatorSuite.makeInfo)),
+ (1236L, 1, 0, makeTaskMetrics(200).accumulators().map(AccumulatorSuite.makeInfo)))))
var stage0Data = listener.stageIdToData.get((0, 0)).get
var stage1Data = listener.stageIdToData.get((1, 0)).get
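Editor's note: AccumulatorSuite.makeInfo is a test helper whose body is not shown in this patch; a plausible shape, mirroring the a.toInfo(Some(a.localValue), None) conversion that SQLListener performs later in this diff, is sketched below (assumption only).

  def makeInfo(a: NewAccumulator[_, _]): AccumulableInfo =
    a.toInfo(Some(a.localValue), None)   // sketch only; the real helper may differ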
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index d3b6cdfe86..6fda7378e6 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -85,7 +85,8 @@ class JsonProtocolSuite extends SparkFunSuite {
// Use custom accum ID for determinism
val accumUpdates =
makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true, hasOutput = true)
- .accumulatorUpdates().zipWithIndex.map { case (a, i) => a.copy(id = i) }
+ .accumulators().map(AccumulatorSuite.makeInfo)
+ .zipWithIndex.map { case (a, i) => a.copy(id = i) }
SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates)))
}
@@ -385,7 +386,7 @@ class JsonProtocolSuite extends SparkFunSuite {
// "Task Metrics" field, if it exists.
val tm = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = true, hasOutput = true)
val tmJson = JsonProtocol.taskMetricsToJson(tm)
- val accumUpdates = tm.accumulatorUpdates()
+ val accumUpdates = tm.accumulators().map(AccumulatorSuite.makeInfo)
val exception = new SparkException("sentimental")
val exceptionFailure = new ExceptionFailure(exception, accumUpdates)
val exceptionFailureJson = JsonProtocol.taskEndReasonToJson(exceptionFailure)
@@ -813,7 +814,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
hasHadoopInput: Boolean,
hasOutput: Boolean,
hasRecords: Boolean = true) = {
- val t = new TaskMetrics
+ val t = TaskMetrics.empty
t.setExecutorDeserializeTime(a)
t.setExecutorRunTime(b)
t.setResultSize(c)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0f8648f890..6fc49a08fe 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -688,6 +688,18 @@ object MimaExcludes {
) ++ Seq(
// [SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable")
+ ) ++ Seq(
+ // SPARK-14654: New accumulator API
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ExceptionFailure$"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.apply"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.metrics"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.copy"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.this"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.remoteBlocksFetched"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.totalBlocksFetched"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.localBlocksFetched"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.remoteBlocksFetched"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.localBlocksFetched")
)
case v if v.startsWith("1.6") =>
Seq(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 520ceaaaea..d6516f26a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -106,7 +106,7 @@ private[sql] case class RDDScanExec(
override val nodeName: String) extends LeafExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
@@ -147,7 +147,7 @@ private[sql] case class RowDataSourceScanExec(
extends DataSourceScanExec with CodegenSupport {
private[sql] override lazy val metrics =
- Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
val outputUnsafeRows = relation match {
case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
@@ -216,7 +216,7 @@ private[sql] case class BatchedDataSourceScanExec(
extends DataSourceScanExec with CodegenSupport {
private[sql] override lazy val metrics =
- Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
protected override def doExecute(): RDD[InternalRow] = {
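Editor's sketch of the operator-side pattern that all of the createLongMetric to createMetric renames below follow; only the metric declaration, the longMetric lookup, and the += call are taken from this patch, the iterator body is illustrative.

  private[sql] override lazy val metrics = Map(
    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")   // now a plain SQLMetric, no LongSQLMetric cast
    rdd.mapPartitionsInternal { iter =>               // rdd: the node's input RDD (placeholder)
      iter.map { row =>
        numOutputRows += 1                            // direct Long update, no localValue indirection
        row
      }
    }
  }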
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 7c4756663a..c201822d44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -40,7 +40,7 @@ case class ExpandExec(
extends UnaryExecNode with CodegenSupport {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
// The GroupExpressions can output data with arbitrary partitioning, so set it
// as UNKNOWN partitioning
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 10cfec3330..934bc38dc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -56,7 +56,7 @@ case class GenerateExec(
extends UnaryExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def producedAttributes: AttributeSet = AttributeSet(output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index 4ab447a47b..c5e78b0333 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -31,7 +31,7 @@ private[sql] case class LocalTableScanExec(
rows: Seq[InternalRow]) extends LeafExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
private val unsafeRows: Array[InternalRow] = {
val proj = UnsafeProjection.create(output, output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 861ff3cd15..0bbe970420 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric}
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.ThreadUtils
@@ -77,7 +77,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Return all metrics containing metrics of this SparkPlan.
*/
- private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty
+ private[sql] def metrics: Map[String, SQLMetric] = Map.empty
/**
* Reset all the metrics.
@@ -89,8 +89,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Return a LongSQLMetric according to the name.
*/
- private[sql] def longMetric(name: String): LongSQLMetric =
- metrics(name).asInstanceOf[LongSQLMetric]
+ private[sql] def longMetric(name: String): SQLMetric = metrics(name)
// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index cb4b1cfeb9..f84070a0c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -55,8 +55,7 @@ private[sql] object SparkPlanInfo {
case _ => plan.children ++ plan.subqueries
}
val metrics = plan.metrics.toSeq.map { case (key, metric) =>
- new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
- Utils.getFormattedClassName(metric.param))
+ new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
}
new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 362d0d7a72..484923428f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.unsafe.Platform
/**
@@ -42,7 +42,7 @@ import org.apache.spark.unsafe.Platform
*/
private[sql] class UnsafeRowSerializer(
numFields: Int,
- dataSize: LongSQLMetric = null) extends Serializer with Serializable {
+ dataSize: SQLMetric = null) extends Serializer with Serializable {
override def newInstance(): SerializerInstance =
new UnsafeRowSerializerInstance(numFields, dataSize)
override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true
@@ -50,7 +50,7 @@ private[sql] class UnsafeRowSerializer(
private class UnsafeRowSerializerInstance(
numFields: Int,
- dataSize: LongSQLMetric) extends SerializerInstance {
+ dataSize: SQLMetric) extends SerializerInstance {
/**
* Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
* length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
@@ -60,13 +60,10 @@ private class UnsafeRowSerializerInstance(
private[this] val dOut: DataOutputStream =
new DataOutputStream(new BufferedOutputStream(out))
- // LongSQLMetricParam.add() is faster than LongSQLMetric.+=
- val localDataSize = if (dataSize != null) dataSize.localValue else null
-
override def writeValue[T: ClassTag](value: T): SerializationStream = {
val row = value.asInstanceOf[UnsafeRow]
- if (localDataSize != null) {
- localDataSize.add(row.getSizeInBytes)
+ if (dataSize != null) {
+ dataSize.add(row.getSizeInBytes)
}
dOut.writeInt(row.getSizeInBytes)
row.writeToStream(dOut, writeBuffer)
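Editor's note: with SQLMetric.add writing straight to a private Long, the cached localValue is no longer needed; a hedged wiring sketch using the constructor signature shown above (schema is a placeholder):

  val dataSize = SQLMetrics.createSizeMetric(sparkContext, "data size")
  val serializer = new UnsafeRowSerializer(numFields = schema.length, dataSize)
  // writeValue() then calls dataSize.add(row.getSizeInBytes) per row, as in the hunk above.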
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 6a03bd08c5..15b4abe806 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -52,11 +52,7 @@ trait CodegenSupport extends SparkPlan {
* @return name of the variable representing the metric
*/
def metricTerm(ctx: CodegenContext, name: String): String = {
- val metric = ctx.addReferenceObj(name, longMetric(name))
- val value = ctx.freshName("metricValue")
- val cls = classOf[LongSQLMetricValue].getName
- ctx.addMutableState(cls, value, s"$value = ($cls) $metric.localValue();")
- value
+ ctx.addReferenceObj(name, longMetric(name))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
index 3169e0a2fd..2e74d59c5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
@@ -46,7 +46,7 @@ case class SortBasedAggregateExec(
AttributeSet(aggregateBufferAttributes)
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index c35d781d3e..f392b135ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
/**
* An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been
@@ -35,7 +35,7 @@ class SortBasedAggregationIterator(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
- numOutputRows: LongSQLMetric)
+ numOutputRows: SQLMetric)
extends AggregationIterator(
groupingExpressions,
valueAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 16362f756f..d0ba37ee13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.unsafe.KVIterator
@@ -51,7 +51,7 @@ case class TungstenAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
@@ -309,8 +309,8 @@ case class TungstenAggregate(
def finishAggregate(
hashMap: UnsafeFixedWidthAggregationMap,
sorter: UnsafeKVExternalSorter,
- peakMemory: LongSQLMetricValue,
- spillSize: LongSQLMetricValue): KVIterator[UnsafeRow, UnsafeRow] = {
+ peakMemory: SQLMetric,
+ spillSize: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = {
// update peak execution memory
val mapMemory = hashMap.getPeakMemoryUsedBytes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 9db5087fe0..243aa15deb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
@@ -86,9 +86,9 @@ class TungstenAggregationIterator(
originalInputAttributes: Seq[Attribute],
inputIter: Iterator[InternalRow],
testFallbackStartsAt: Option[(Int, Int)],
- numOutputRows: LongSQLMetric,
- peakMemory: LongSQLMetric,
- spillSize: LongSQLMetric)
+ numOutputRows: SQLMetric,
+ peakMemory: SQLMetric,
+ spillSize: SQLMetric)
extends AggregationIterator(
groupingExpressions,
originalInputAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 83f527f555..77be613b83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -103,7 +103,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
}
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -229,7 +229,7 @@ case class SampleExec(
override def output: Seq[Attribute] = child.output
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
@@ -322,7 +322,7 @@ case class RangeExec(
extends LeafExecNode with CodegenSupport {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
// output attributes should not affect the results
override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index cb957b9666..577c34ba61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{Accumulable, Accumulator, Accumulators}
+import org.apache.spark.{Accumulable, Accumulator, AccumulatorContext}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -204,7 +204,7 @@ private[sql] case class InMemoryRelation(
Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats)
private[sql] def uncache(blocking: Boolean): Unit = {
- Accumulators.remove(batchStats.id)
+ AccumulatorContext.remove(batchStats.id)
cachedColumnBuffers.unpersist(blocking)
_cachedColumnBuffers = null
}
@@ -217,7 +217,7 @@ private[sql] case class InMemoryTableScanExec(
extends LeafExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = attributes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 573ca195ac..b6ecd3cb06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -38,10 +38,10 @@ case class BroadcastExchangeExec(
child: SparkPlan) extends Exchange {
override private[sql] lazy val metrics = Map(
- "dataSize" -> SQLMetrics.createLongMetric(sparkContext, "data size (bytes)"),
- "collectTime" -> SQLMetrics.createLongMetric(sparkContext, "time to collect (ms)"),
- "buildTime" -> SQLMetrics.createLongMetric(sparkContext, "time to build (ms)"),
- "broadcastTime" -> SQLMetrics.createLongMetric(sparkContext, "time to broadcast (ms)"))
+ "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
+ "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"),
+ "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
+ "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"))
override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index b0a6b8f28a..587c603192 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -46,7 +46,7 @@ case class BroadcastHashJoinExec(
extends BinaryExecNode with HashJoin with CodegenSupport {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 51afa0017d..a659bf26e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -35,7 +35,7 @@ case class BroadcastNestedLoopJoinExec(
condition: Option[Expression]) extends BinaryExecNode {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
/** BuildRight means the right relation <=> the broadcast relation. */
private val (streamed, broadcast) = buildSide match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 67f59197ad..8d7ecc442a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -86,7 +86,7 @@ case class CartesianProductExec(
override def output: Seq[Attribute] = left.output ++ right.output
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index d6feedc272..9c173d7bf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.{IntegralType, LongType}
trait HashJoin {
@@ -201,7 +201,7 @@ trait HashJoin {
protected def join(
streamedIter: Iterator[InternalRow],
hashed: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+ numOutputRows: SQLMetric): Iterator[InternalRow] = {
val joinedIter = joinType match {
case Inner =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index a242a078f6..3ef2fec352 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -40,7 +40,7 @@ case class ShuffledHashJoinExec(
extends BinaryExecNode with HashJoin {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index a4c5491aff..775f8ac508 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.collection.BitSet
/**
@@ -41,7 +41,7 @@ case class SortMergeJoinExec(
right: SparkPlan) extends BinaryExecNode with CodegenSupport {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = {
joinType match {
@@ -734,7 +734,7 @@ private class LeftOuterIterator(
rightNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric)
+ numOutputRows: SQLMetric)
extends OneSideOuterIterator(
smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
@@ -750,7 +750,7 @@ private class RightOuterIterator(
leftNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric)
+ numOutputRows: SQLMetric)
extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
@@ -778,7 +778,7 @@ private abstract class OneSideOuterIterator(
bufferedSideNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric) extends RowIterator {
+ numOutputRows: SQLMetric) extends RowIterator {
// A row to store the joined result, reused many times
protected[this] val joinedRow: JoinedRow = new JoinedRow()
@@ -1016,7 +1016,7 @@ private class SortMergeFullOuterJoinScanner(
private class FullOuterIterator(
smjScanner: SortMergeFullOuterJoinScanner,
resultProj: InternalRow => InternalRow,
- numRows: LongSQLMetric) extends RowIterator {
+ numRows: SQLMetric) extends RowIterator {
private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
override def advanceNext(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
index 2708219ad3..adb81519db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
@@ -27,4 +27,4 @@ import org.apache.spark.annotation.DeveloperApi
class SQLMetricInfo(
val name: String,
val accumulatorId: Long,
- val metricParam: String)
+ val metricType: String)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 5755c00c1f..7bf9225272 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -19,200 +19,106 @@ package org.apache.spark.sql.execution.metric
import java.text.NumberFormat
-import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext}
+import org.apache.spark.{NewAccumulator, SparkContext}
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.util.Utils
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- *
- * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
- */
-private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
- name: String,
- val param: SQLMetricParam[R, T]) extends Accumulable[R, T](param.zero, param, Some(name)) {
- // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
- override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
- new AccumulableInfo(id, Some(name), update, value, true, countFailedValues,
- Some(SQLMetrics.ACCUM_IDENTIFIER))
- }
-
- def reset(): Unit = {
- this.value = param.zero
- }
-}
-
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
-
- /**
- * A function that defines how we aggregate the final accumulator results among all tasks,
- * and represent it in string for a SQL physical operator.
- */
- val stringValue: Seq[T] => String
-
- def zero: R
-}
+class SQLMetric(val metricType: String, initValue: Long = 0L) extends NewAccumulator[Long, Long] {
+ // This is a workaround for SPARK-11013.
+ // We may use -1 as the initial value of the accumulator. If the accumulator is valid, we will
+ // update it at the end of the task and the value will be at least 0. Then we can filter out the
+ // -1 values before calculating max, min, etc.
+ private[this] var _value = initValue
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricValue[T] extends Serializable {
+ override def copyAndReset(): SQLMetric = new SQLMetric(metricType, initValue)
- def value: T
-
- override def toString: String = value.toString
-}
-
-/**
- * A wrapper of Long to avoid boxing and unboxing when using Accumulator
- */
-private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {
-
- def add(incr: Long): LongSQLMetricValue = {
- _value += incr
- this
+ override def merge(other: NewAccumulator[Long, Long]): Unit = other match {
+ case o: SQLMetric => _value += o.localValue
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
- // Although there is a boxing here, it's fine because it's only called in SQLListener
- override def value: Long = _value
-
- // Needed for SQLListenerSuite
- override def equals(other: Any): Boolean = other match {
- case o: LongSQLMetricValue => value == o.value
- case _ => false
- }
+ override def isZero(): Boolean = _value == initValue
- override def hashCode(): Int = _value.hashCode()
-}
+ override def add(v: Long): Unit = _value += v
-/**
- * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
- * `+=` and `add`.
- */
-private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam)
- extends SQLMetric[LongSQLMetricValue, Long](name, param) {
+ def +=(v: Long): Unit = _value += v
- override def +=(term: Long): Unit = {
- localValue.add(term)
- }
+ override def localValue: Long = _value
- override def add(term: Long): Unit = {
- localValue.add(term)
+ // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
+ private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+ new AccumulableInfo(id, name, update, value, true, true, Some(SQLMetrics.ACCUM_IDENTIFIER))
}
-}
-
-private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long)
- extends SQLMetricParam[LongSQLMetricValue, Long] {
-
- override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
- override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
- r1.add(r2.value)
-
- override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
-
- override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue)
+ def reset(): Unit = _value = initValue
}
-private object LongSQLMetricParam
- extends LongSQLMetricParam(x => NumberFormat.getInstance().format(x.sum), 0L)
-
-private object StatisticsBytesSQLMetricParam extends LongSQLMetricParam(
- (values: Seq[Long]) => {
- // This is a workaround for SPARK-11013.
- // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
- // it at the end of task and the value will be at least 0.
- val validValues = values.filter(_ >= 0)
- val Seq(sum, min, med, max) = {
- val metric = if (validValues.length == 0) {
- Seq.fill(4)(0L)
- } else {
- val sorted = validValues.sorted
- Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
- }
- metric.map(Utils.bytesToString)
- }
- s"\n$sum ($min, $med, $max)"
- }, -1L)
-
-private object StatisticsTimingSQLMetricParam extends LongSQLMetricParam(
- (values: Seq[Long]) => {
- // This is a workaround for SPARK-11013.
- // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
- // it at the end of task and the value will be at least 0.
- val validValues = values.filter(_ >= 0)
- val Seq(sum, min, med, max) = {
- val metric = if (validValues.length == 0) {
- Seq.fill(4)(0L)
- } else {
- val sorted = validValues.sorted
- Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
- }
- metric.map(Utils.msDurationToString)
- }
- s"\n$sum ($min, $med, $max)"
- }, -1L)
private[sql] object SQLMetrics {
-
// Identifier for distinguishing SQL metrics from other accumulators
private[sql] val ACCUM_IDENTIFIER = "sql"
- private def createLongMetric(
- sc: SparkContext,
- name: String,
- param: LongSQLMetricParam): LongSQLMetric = {
- val acc = new LongSQLMetric(name, param)
- // This is an internal accumulator so we need to register it explicitly.
- Accumulators.register(acc)
- sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
- acc
- }
+ private[sql] val SUM_METRIC = "sum"
+ private[sql] val SIZE_METRIC = "size"
+ private[sql] val TIMING_METRIC = "timing"
- def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
- createLongMetric(sc, name, LongSQLMetricParam)
+ def createMetric(sc: SparkContext, name: String): SQLMetric = {
+ val acc = new SQLMetric(SUM_METRIC)
+ acc.register(sc, name = Some(name), countFailedValues = true)
+ acc
}
/**
* Create a metric to report the size information (including total, min, med, max) like data size,
* spill size, etc.
*/
- def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ def createSizeMetric(sc: SparkContext, name: String): SQLMetric = {
// The final result of this metric in physical operator UI may look like:
// data size total (min, med, max):
// 100GB (100MB, 1GB, 10GB)
- createLongMetric(sc, s"$name total (min, med, max)", StatisticsBytesSQLMetricParam)
+ val acc = new SQLMetric(SIZE_METRIC, -1)
+ acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+ acc
}
- def createTimingMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ def createTimingMetric(sc: SparkContext, name: String): SQLMetric = {
// The final result of this metric in physical operator UI may look like:
// duration (min, med, max):
// 5s (800ms, 1s, 2s)
- createLongMetric(sc, s"$name total (min, med, max)", StatisticsTimingSQLMetricParam)
- }
-
- def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = {
- val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam)
- val bytesSQLMetricParam = Utils.getFormattedClassName(StatisticsBytesSQLMetricParam)
- val timingsSQLMetricParam = Utils.getFormattedClassName(StatisticsTimingSQLMetricParam)
- val metricParam = metricParamName match {
- case `longSQLMetricParam` => LongSQLMetricParam
- case `bytesSQLMetricParam` => StatisticsBytesSQLMetricParam
- case `timingsSQLMetricParam` => StatisticsTimingSQLMetricParam
- }
- metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]
+ val acc = new SQLMetric(TIMING_METRIC, -1)
+ acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+ acc
}
/**
- * A metric that its value will be ignored. Use this one when we need a metric parameter but don't
- * care about the value.
+ * A function that defines how we aggregate the final accumulator results among all tasks
+ * and represent them as a string for a SQL physical operator.
*/
- val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam)
+ def stringValue(metricsType: String, values: Seq[Long]): String = {
+ if (metricsType == SUM_METRIC) {
+ NumberFormat.getInstance().format(values.sum)
+ } else {
+ val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
+ Utils.bytesToString
+ } else if (metricsType == TIMING_METRIC) {
+ Utils.msDurationToString
+ } else {
+ throw new IllegalStateException("unexpected metrics type: " + metricsType)
+ }
+
+ val validValues = values.filter(_ >= 0)
+ val Seq(sum, min, med, max) = {
+ val metric = if (validValues.length == 0) {
+ Seq.fill(4)(0L)
+ } else {
+ val sorted = validValues.sorted
+ Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+ }
+ metric.map(strFormat)
+ }
+ s"\n$sum ($min, $med, $max)"
+ }
+ }
}
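Editor's sketch pulling the pieces above together; the names and semantics come from the code in this file, the values are illustrative, and sc stands for a SparkContext.

  val rows = SQLMetrics.createMetric(sc, "number of output rows")   // metricType = "sum",    starts at 0
  val size = SQLMetrics.createSizeMetric(sc, "data size")           // metricType = "size",   starts at -1 (SPARK-11013 sentinel)
  val time = SQLMetrics.createTimingMetric(sc, "scan time")         // metricType = "timing", starts at -1

  rows.add(3L)
  rows += 2L                        // += and add are the same cheap Long update
  assert(rows.localValue == 5L)

  size.reset()                      // back to -1, so isZero() holds again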
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 5ae9e916ad..9118593c0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -164,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
taskEnd.taskInfo.taskId,
taskEnd.stageId,
taskEnd.stageAttemptId,
- taskEnd.taskMetrics.accumulatorUpdates(),
+ taskEnd.taskMetrics.accumulators().map(a => a.toInfo(Some(a.localValue), None)),
finishTask = true)
}
}
@@ -296,7 +296,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
}
}.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
- executionUIData.accumulatorMetrics(accumulatorId).metricParam)
+ executionUIData.accumulatorMetrics(accumulatorId).metricType)
case None =>
// This execution has been dropped
Map.empty
@@ -305,11 +305,11 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
private def mergeAccumulatorUpdates(
accumulatorUpdates: Seq[(Long, Any)],
- paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = {
+ metricTypeFunc: Long => String): Map[Long, String] = {
accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
- val param = paramFunc(accumulatorId)
- (accumulatorId,
- param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value)))
+ val metricType = metricTypeFunc(accumulatorId)
+ accumulatorId ->
+ SQLMetrics.stringValue(metricType, values.map(_._2.asInstanceOf[Long]))
}
}
@@ -337,7 +337,7 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
// Filter out accumulators that are not SQL metrics
// For now we assume all SQL metrics are Long's that have been JSON serialized as String's
if (a.metadata == Some(SQLMetrics.ACCUM_IDENTIFIER)) {
- val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L))
+ val newValue = a.update.map(_.toString.toLong).getOrElse(0L)
Some(a.copy(update = Some(newValue)))
} else {
None
@@ -403,7 +403,7 @@ private[ui] class SQLExecutionUIData(
private[ui] case class SQLPlanMetric(
name: String,
accumulatorId: Long,
- metricParam: SQLMetricParam[SQLMetricValue[Any], Any])
+ metricType: String)
/**
* Store all accumulatorUpdates for all tasks in a Spark stage.
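Editor's sketch of the rendering path mergeAccumulatorUpdates now takes: raw Long updates are grouped by accumulator id, then formatted by SQLMetrics.stringValue using the looked-up metricType (the ids and values here are made up).

  val updates: Seq[(Long, Any)] = Seq((7L, 1024L), (7L, 4096L), (7L, -1L))  // (accumulatorId, update)
  val metricTypeOf: Long => String = _ => SQLMetrics.SIZE_METRIC            // normally read from accumulatorMetrics

  updates.groupBy(_._1).map { case (id, values) =>
    id -> SQLMetrics.stringValue(metricTypeOf(id), values.map(_._2.asInstanceOf[Long]))
  }                                                                         // the -1 sentinel is filtered out before min/med/max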
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 1959f1e368..8f5681bfc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -80,8 +80,7 @@ private[sql] object SparkPlanGraph {
planInfo.nodeName match {
case "WholeStageCodegen" =>
val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId,
- SQLMetrics.getMetricParam(metric.metricParam))
+ SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
}
val cluster = new SparkPlanGraphCluster(
@@ -106,8 +105,7 @@ private[sql] object SparkPlanGraph {
edges += SparkPlanGraphEdge(node.id, parent.id)
case name =>
val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId,
- SQLMetrics.getMetricParam(metric.metricParam))
+ SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
}
val node = new SparkPlanGraphNode(
nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 4aea21e52a..0e6356b578 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -22,7 +22,7 @@ import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.Accumulators
+import org.apache.spark.AccumulatorContext
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
@@ -333,11 +333,11 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
sql("SELECT * FROM t1").count()
sql("SELECT * FROM t2").count()
- Accumulators.synchronized {
- val accsSize = Accumulators.originals.size
+ AccumulatorContext.synchronized {
+ val accsSize = AccumulatorContext.originals.size
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
- assert((accsSize - 2) == Accumulators.originals.size)
+ assert((accsSize - 2) == AccumulatorContext.originals.size)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 1859c6e7ad..8de4d8bbd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -37,8 +37,8 @@ import org.apache.spark.util.{JsonProtocol, Utils}
class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
- test("LongSQLMetric should not box Long") {
- val l = SQLMetrics.createLongMetric(sparkContext, "long")
+ test("SQLMetric should not box Long") {
+ val l = SQLMetrics.createMetric(sparkContext, "long")
val f = () => {
l += 1L
l.add(1L)
@@ -300,12 +300,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}
test("metrics can be loaded by history server") {
- val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam)
+ val metric = SQLMetrics.createMetric(sparkContext, "zanzibar")
metric += 10L
val metricInfo = metric.toInfo(Some(metric.localValue), None)
metricInfo.update match {
- case Some(v: LongSQLMetricValue) => assert(v.value === 10L)
- case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}")
+ case Some(v: Long) => assert(v === 10L)
+ case Some(v) => fail(s"metric value was not a Long: ${v.getClass.getName}")
case _ => fail("metric update is missing")
}
assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 09bd7f6e8f..8572ed16aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -21,18 +21,19 @@ import java.util.Properties
import org.mockito.Mockito.{mock, when}
-import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.ui.SparkUI
class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
+ import org.apache.spark.AccumulatorSuite.makeInfo
private def createTestDataFrame: DataFrame = {
Seq(
@@ -72,9 +73,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = {
val metrics = mock(classOf[TaskMetrics])
- when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) =>
- new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)),
- value = None, internal = true, countFailedValues = true)
+ when(metrics.accumulators()).thenReturn(accumulatorUpdates.map { case (id, update) =>
+ val acc = new LongAccumulator
+ acc.metadata = AccumulatorMetadata(id, Some(""), true)
+ acc.setValue(update)
+ acc
}.toSeq)
metrics
}
@@ -130,16 +133,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+ (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates())
+ (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 0, 0,
+ createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
@@ -149,8 +153,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+ (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
@@ -189,8 +193,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+ (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
@@ -358,7 +362,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
val stageSubmitted = SparkListenerStageSubmitted(stageInfo)
// This task has both accumulators that are SQL metrics and accumulators that are not.
// The listener should only track the ones that are actually SQL metrics.
- val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella")
+ val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella")
val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None)
val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)
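Taken together, the SQLListenerSuite hunks replace the mock's hand-built AccumulableInfo objects with real LongAccumulator instances that carry their own AccumulatorMetadata and are converted back to AccumulableInfo only at the listener boundary, via the test-only AccumulatorSuite.makeInfo helper imported above. A rough sketch of that flow, not part of the patch; the id/value pairs and the empty executor id are arbitrary sample data:

    import org.apache.spark.{AccumulatorMetadata, LongAccumulator}
    import org.apache.spark.AccumulatorSuite.makeInfo
    import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate

    // One registered-looking accumulator per (id -> value) pair, mirroring createTaskMetrics above.
    val updates = Map(1L -> 10L, 2L -> 20L)
    val accums = updates.map { case (id, value) =>
      val acc = new LongAccumulator
      acc.metadata = AccumulatorMetadata(id, Some(""), true)  // (id, name, countFailedValues)
      acc.setValue(value)
      acc
    }.toSeq

    // The listener event still consumes AccumulableInfo, hence the .map(makeInfo) in the hunks above.
    val event = SparkListenerExecutorMetricsUpdate("", Seq(
      // (task id, stage id, stage attempt, accum updates)
      (0L, 0, 0, accums.map(makeInfo))
    ))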
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index eb25ea0629..8a0578c1ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -96,7 +96,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
case w: WholeStageCodegenExec => w.child.longMetric("numOutputRows")
case other => other.longMetric("numOutputRows")
}
- metrics += metric.value.value
+ metrics += metric.value
}
}
sqlContext.listenerManager.register(listener)
@@ -126,9 +126,9 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
- metrics += qe.executedPlan.longMetric("dataSize").value.value
+ metrics += qe.executedPlan.longMetric("dataSize").value
val bottomAgg = qe.executedPlan.children(0).children(0)
- metrics += bottomAgg.longMetric("dataSize").value.value
+ metrics += bottomAgg.longMetric("dataSize").value
}
}
sqlContext.listenerManager.register(listener)
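The DataFrameCallbackSuite change is purely a call-site simplification: longMetric(...).value now returns the Long directly, so the extra .value unwrapping disappears. A small listener sketch in the same shape as the test above, not part of the patch, assuming a SharedSQLContext-style test where sqlContext is available:

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.sql.execution.QueryExecution
    import org.apache.spark.sql.util.QueryExecutionListener

    val metrics = ArrayBuffer.empty[Long]
    val listener = new QueryExecutionListener {
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
        // A single .value read suffices under the new accumulator API.
        metrics += qe.executedPlan.longMetric("dataSize").value
      }
    }
    sqlContext.listenerManager.register(listener)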
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 007c3384e5..b52b96a804 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -55,7 +55,7 @@ case class HiveTableScanExec(
"Partition pruning predicates only supported for partitioned tables.")
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def producedAttributes: AttributeSet = outputSet ++
AttributeSet(partitionPruningPred.flatMap(_.references))