aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/Accumulable.scala64
-rw-r--r--core/src/main/scala/org/apache/spark/Accumulator.scala67
-rw-r--r--core/src/main/scala/org/apache/spark/ContextCleaner.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/NewAccumulator.scala391
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala107
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContext.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContextImpl.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/TaskEndReason.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/executor/InputMetrics.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala64
-rw-r--r--core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala19
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala228
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala20
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Stage.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/status/api/v1/api.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/util/JsonProtocol.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/AccumulatorSuite.scala132
-rw-r--r--core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala24
-rw-r--r--core/src/test/scala/org/apache/spark/SparkFunSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala85
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala71
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala16
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala11
-rw-r--r--core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala12
-rw-r--r--core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala7
-rw-r--r--project/MimaExcludes.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala7
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala218
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala16
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala32
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala6
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala2
73 files changed, 1071 insertions, 842 deletions
diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala
index e8f053c150..c76720c4bb 100644
--- a/core/src/main/scala/org/apache/spark/Accumulable.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulable.scala
@@ -63,7 +63,7 @@ class Accumulable[R, T] private (
param: AccumulableParam[R, T],
name: Option[String],
countFailedValues: Boolean) = {
- this(Accumulators.newId(), initialValue, param, name, countFailedValues)
+ this(AccumulatorContext.newId(), initialValue, param, name, countFailedValues)
}
private[spark] def this(initialValue: R, param: AccumulableParam[R, T], name: Option[String]) = {
@@ -72,34 +72,23 @@ class Accumulable[R, T] private (
def this(initialValue: R, param: AccumulableParam[R, T]) = this(initialValue, param, None)
- @volatile @transient private var value_ : R = initialValue // Current value on driver
- val zero = param.zero(initialValue) // Zero value to be passed to executors
- private var deserialized = false
-
- Accumulators.register(this)
-
- /**
- * Return a copy of this [[Accumulable]].
- *
- * The copy will have the same ID as the original and will not be registered with
- * [[Accumulators]] again. This method exists so that the caller can avoid passing the
- * same mutable instance around.
- */
- private[spark] def copy(): Accumulable[R, T] = {
- new Accumulable[R, T](id, initialValue, param, name, countFailedValues)
- }
+ val zero = param.zero(initialValue)
+ private[spark] val newAcc = new LegacyAccumulatorWrapper(initialValue, param)
+ newAcc.metadata = AccumulatorMetadata(id, name, countFailedValues)
+ // Register the new accumulator in ctor, to follow the previous behaviour.
+ AccumulatorContext.register(newAcc)
/**
* Add more data to this accumulator / accumulable
* @param term the data to add
*/
- def += (term: T) { value_ = param.addAccumulator(value_, term) }
+ def += (term: T) { newAcc.add(term) }
/**
* Add more data to this accumulator / accumulable
* @param term the data to add
*/
- def add(term: T) { value_ = param.addAccumulator(value_, term) }
+ def add(term: T) { newAcc.add(term) }
/**
* Merge two accumulable objects together
@@ -107,7 +96,7 @@ class Accumulable[R, T] private (
* Normally, a user will not want to use this version, but will instead call `+=`.
* @param term the other `R` that will get merged with this
*/
- def ++= (term: R) { value_ = param.addInPlace(value_, term)}
+ def ++= (term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }
/**
* Merge two accumulable objects together
@@ -115,18 +104,12 @@ class Accumulable[R, T] private (
* Normally, a user will not want to use this version, but will instead call `add`.
* @param term the other `R` that will get merged with this
*/
- def merge(term: R) { value_ = param.addInPlace(value_, term)}
+ def merge(term: R) { newAcc._value = param.addInPlace(newAcc._value, term) }
/**
* Access the accumulator's current value; only allowed on driver.
*/
- def value: R = {
- if (!deserialized) {
- value_
- } else {
- throw new UnsupportedOperationException("Can't read accumulator value in task")
- }
- }
+ def value: R = newAcc.value
/**
* Get the current value of this accumulator from within a task.
@@ -137,14 +120,14 @@ class Accumulable[R, T] private (
* The typical use of this method is to directly mutate the local value, eg., to add
* an element to a Set.
*/
- def localValue: R = value_
+ def localValue: R = newAcc.localValue
/**
* Set the accumulator's value; only allowed on driver.
*/
def value_= (newValue: R) {
- if (!deserialized) {
- value_ = newValue
+ if (newAcc.isAtDriverSide) {
+ newAcc._value = newValue
} else {
throw new UnsupportedOperationException("Can't assign accumulator value in task")
}
@@ -153,7 +136,7 @@ class Accumulable[R, T] private (
/**
* Set the accumulator's value. For internal use only.
*/
- def setValue(newValue: R): Unit = { value_ = newValue }
+ def setValue(newValue: R): Unit = { newAcc._value = newValue }
/**
* Set the accumulator's value. For internal use only.
@@ -168,22 +151,7 @@ class Accumulable[R, T] private (
new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
}
- // Called by Java when deserializing an object
- private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
- in.defaultReadObject()
- value_ = zero
- deserialized = true
-
- // Automatically register the accumulator when it is deserialized with the task closure.
- // This is for external accumulators and internal ones that do not represent task level
- // metrics, e.g. internal SQL metrics, which are per-operator.
- val taskContext = TaskContext.get()
- if (taskContext != null) {
- taskContext.registerAccumulator(this)
- }
- }
-
- override def toString: String = if (value_ == null) "null" else value_.toString
+ override def toString: String = if (newAcc._value == null) "null" else newAcc._value.toString
}
diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala
index 0c17f014c9..9b007b9776 100644
--- a/core/src/main/scala/org/apache/spark/Accumulator.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulator.scala
@@ -68,73 +68,6 @@ class Accumulator[T] private[spark] (
extends Accumulable[T, T](initialValue, param, name, countFailedValues)
-// TODO: The multi-thread support in accumulators is kind of lame; check
-// if there's a more intuitive way of doing it right
-private[spark] object Accumulators extends Logging {
- /**
- * This global map holds the original accumulator objects that are created on the driver.
- * It keeps weak references to these objects so that accumulators can be garbage-collected
- * once the RDDs and user-code that reference them are cleaned up.
- * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
- */
- @GuardedBy("Accumulators")
- val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
-
- private val nextId = new AtomicLong(0L)
-
- /**
- * Return a globally unique ID for a new [[Accumulable]].
- * Note: Once you copy the [[Accumulable]] the ID is no longer unique.
- */
- def newId(): Long = nextId.getAndIncrement
-
- /**
- * Register an [[Accumulable]] created on the driver such that it can be used on the executors.
- *
- * All accumulators registered here can later be used as a container for accumulating partial
- * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does.
- * Note: if an accumulator is registered here, it should also be registered with the active
- * context cleaner for cleanup so as to avoid memory leaks.
- *
- * If an [[Accumulable]] with the same ID was already registered, this does nothing instead
- * of overwriting it. This happens when we copy accumulators, e.g. when we reconstruct
- * [[org.apache.spark.executor.TaskMetrics]] from accumulator updates.
- */
- def register(a: Accumulable[_, _]): Unit = synchronized {
- if (!originals.contains(a.id)) {
- originals(a.id) = new WeakReference[Accumulable[_, _]](a)
- }
- }
-
- /**
- * Unregister the [[Accumulable]] with the given ID, if any.
- */
- def remove(accId: Long): Unit = synchronized {
- originals.remove(accId)
- }
-
- /**
- * Return the [[Accumulable]] registered with the given ID, if any.
- */
- def get(id: Long): Option[Accumulable[_, _]] = synchronized {
- originals.get(id).map { weakRef =>
- // Since we are storing weak references, we must check whether the underlying data is valid.
- weakRef.get.getOrElse {
- throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
- }
- }
- }
-
- /**
- * Clear all registered [[Accumulable]]s. For testing only.
- */
- def clear(): Unit = synchronized {
- originals.clear()
- }
-
-}
-
-
/**
* A simpler version of [[org.apache.spark.AccumulableParam]] where the only data type you can add
* in is the same type as the accumulated value. An implicit AccumulatorParam object needs to be
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index 76692ccec8..63a00a84af 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -144,7 +144,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
registerForCleanup(rdd, CleanRDD(rdd.id))
}
- def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+ def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = {
registerForCleanup(a, CleanAccum(a.id))
}
@@ -241,7 +241,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
def doCleanupAccum(accId: Long, blocking: Boolean): Unit = {
try {
logDebug("Cleaning accumulator " + accId)
- Accumulators.remove(accId)
+ AccumulatorContext.remove(accId)
listeners.asScala.foreach(_.accumCleaned(accId))
logInfo("Cleaned accumulator " + accId)
} catch {
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 2bdbd3fae9..9eac05fdf9 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
*/
private[spark] case class Heartbeat(
executorId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])], // taskId -> accum updates
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])], // taskId -> accumulator updates
blockManagerId: BlockManagerId)
/**
diff --git a/core/src/main/scala/org/apache/spark/NewAccumulator.scala b/core/src/main/scala/org/apache/spark/NewAccumulator.scala
new file mode 100644
index 0000000000..edb9b741a8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/NewAccumulator.scala
@@ -0,0 +1,391 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.{lang => jl}
+import java.io.ObjectInputStream
+import java.util.concurrent.atomic.AtomicLong
+import javax.annotation.concurrent.GuardedBy
+
+import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.util.Utils
+
+
+private[spark] case class AccumulatorMetadata(
+ id: Long,
+ name: Option[String],
+ countFailedValues: Boolean) extends Serializable
+
+
+/**
+ * The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of
+ * type `OUT`.
+ */
+abstract class NewAccumulator[IN, OUT] extends Serializable {
+ private[spark] var metadata: AccumulatorMetadata = _
+ private[this] var atDriverSide = true
+
+ private[spark] def register(
+ sc: SparkContext,
+ name: Option[String] = None,
+ countFailedValues: Boolean = false): Unit = {
+ if (this.metadata != null) {
+ throw new IllegalStateException("Cannot register an Accumulator twice.")
+ }
+ this.metadata = AccumulatorMetadata(AccumulatorContext.newId(), name, countFailedValues)
+ AccumulatorContext.register(this)
+ sc.cleaner.foreach(_.registerAccumulatorForCleanup(this))
+ }
+
+ /**
+ * Returns true if this accumulator has been registered. Note that all accumulators must be
+ * registered before ues, or it will throw exception.
+ */
+ final def isRegistered: Boolean =
+ metadata != null && AccumulatorContext.originals.containsKey(metadata.id)
+
+ private def assertMetadataNotNull(): Unit = {
+ if (metadata == null) {
+ throw new IllegalAccessError("The metadata of this accumulator has not been assigned yet.")
+ }
+ }
+
+ /**
+ * Returns the id of this accumulator, can only be called after registration.
+ */
+ final def id: Long = {
+ assertMetadataNotNull()
+ metadata.id
+ }
+
+ /**
+ * Returns the name of this accumulator, can only be called after registration.
+ */
+ final def name: Option[String] = {
+ assertMetadataNotNull()
+ metadata.name
+ }
+
+ /**
+ * Whether to accumulate values from failed tasks. This is set to true for system and time
+ * metrics like serialization time or bytes spilled, and false for things with absolute values
+ * like number of input rows. This should be used for internal metrics only.
+ */
+ private[spark] final def countFailedValues: Boolean = {
+ assertMetadataNotNull()
+ metadata.countFailedValues
+ }
+
+ /**
+ * Creates an [[AccumulableInfo]] representation of this [[NewAccumulator]] with the provided
+ * values.
+ */
+ private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+ val isInternal = name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))
+ new AccumulableInfo(id, name, update, value, isInternal, countFailedValues)
+ }
+
+ final private[spark] def isAtDriverSide: Boolean = atDriverSide
+
+ /**
+ * Tells if this accumulator is zero value or not. e.g. for a counter accumulator, 0 is zero
+ * value; for a list accumulator, Nil is zero value.
+ */
+ def isZero(): Boolean
+
+ /**
+ * Creates a new copy of this accumulator, which is zero value. i.e. call `isZero` on the copy
+ * must return true.
+ */
+ def copyAndReset(): NewAccumulator[IN, OUT]
+
+ /**
+ * Takes the inputs and accumulates. e.g. it can be a simple `+=` for counter accumulator.
+ */
+ def add(v: IN): Unit
+
+ /**
+ * Merges another same-type accumulator into this one and update its state, i.e. this should be
+ * merge-in-place.
+ */
+ def merge(other: NewAccumulator[IN, OUT]): Unit
+
+ /**
+ * Access this accumulator's current value; only allowed on driver.
+ */
+ final def value: OUT = {
+ if (atDriverSide) {
+ localValue
+ } else {
+ throw new UnsupportedOperationException("Can't read accumulator value in task")
+ }
+ }
+
+ /**
+ * Defines the current value of this accumulator.
+ *
+ * This is NOT the global value of the accumulator. To get the global value after a
+ * completed operation on the dataset, call `value`.
+ */
+ def localValue: OUT
+
+ // Called by Java when serializing an object
+ final protected def writeReplace(): Any = {
+ if (atDriverSide) {
+ if (!isRegistered) {
+ throw new UnsupportedOperationException(
+ "Accumulator must be registered before send to executor")
+ }
+ val copy = copyAndReset()
+ assert(copy.isZero(), "copyAndReset must return a zero value copy")
+ copy.metadata = metadata
+ copy
+ } else {
+ this
+ }
+ }
+
+ // Called by Java when deserializing an object
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+ in.defaultReadObject()
+ if (atDriverSide) {
+ atDriverSide = false
+
+ // Automatically register the accumulator when it is deserialized with the task closure.
+ // This is for external accumulators and internal ones that do not represent task level
+ // metrics, e.g. internal SQL metrics, which are per-operator.
+ val taskContext = TaskContext.get()
+ if (taskContext != null) {
+ taskContext.registerAccumulator(this)
+ }
+ } else {
+ atDriverSide = true
+ }
+ }
+
+ override def toString: String = {
+ if (metadata == null) {
+ "Un-registered Accumulator: " + getClass.getSimpleName
+ } else {
+ getClass.getSimpleName + s"(id: $id, name: $name, value: $localValue)"
+ }
+ }
+}
+
+
+private[spark] object AccumulatorContext {
+
+ /**
+ * This global map holds the original accumulator objects that are created on the driver.
+ * It keeps weak references to these objects so that accumulators can be garbage-collected
+ * once the RDDs and user-code that reference them are cleaned up.
+ * TODO: Don't use a global map; these should be tied to a SparkContext (SPARK-13051).
+ */
+ @GuardedBy("AccumulatorContext")
+ val originals = new java.util.HashMap[Long, jl.ref.WeakReference[NewAccumulator[_, _]]]
+
+ private[this] val nextId = new AtomicLong(0L)
+
+ /**
+ * Return a globally unique ID for a new [[Accumulator]].
+ * Note: Once you copy the [[Accumulator]] the ID is no longer unique.
+ */
+ def newId(): Long = nextId.getAndIncrement
+
+ /**
+ * Register an [[Accumulator]] created on the driver such that it can be used on the executors.
+ *
+ * All accumulators registered here can later be used as a container for accumulating partial
+ * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does.
+ * Note: if an accumulator is registered here, it should also be registered with the active
+ * context cleaner for cleanup so as to avoid memory leaks.
+ *
+ * If an [[Accumulator]] with the same ID was already registered, this does nothing instead
+ * of overwriting it. We will never register same accumulator twice, this is just a sanity check.
+ */
+ def register(a: NewAccumulator[_, _]): Unit = synchronized {
+ if (!originals.containsKey(a.id)) {
+ originals.put(a.id, new jl.ref.WeakReference[NewAccumulator[_, _]](a))
+ }
+ }
+
+ /**
+ * Unregister the [[Accumulator]] with the given ID, if any.
+ */
+ def remove(id: Long): Unit = synchronized {
+ originals.remove(id)
+ }
+
+ /**
+ * Return the [[Accumulator]] registered with the given ID, if any.
+ */
+ def get(id: Long): Option[NewAccumulator[_, _]] = synchronized {
+ Option(originals.get(id)).map { ref =>
+ // Since we are storing weak references, we must check whether the underlying data is valid.
+ val acc = ref.get
+ if (acc eq null) {
+ throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
+ }
+ acc
+ }
+ }
+
+ /**
+ * Clear all registered [[Accumulator]]s. For testing only.
+ */
+ def clear(): Unit = synchronized {
+ originals.clear()
+ }
+}
+
+
+class LongAccumulator extends NewAccumulator[jl.Long, jl.Long] {
+ private[this] var _sum = 0L
+
+ override def isZero(): Boolean = _sum == 0
+
+ override def copyAndReset(): LongAccumulator = new LongAccumulator
+
+ override def add(v: jl.Long): Unit = _sum += v
+
+ def add(v: Long): Unit = _sum += v
+
+ def sum: Long = _sum
+
+ override def merge(other: NewAccumulator[jl.Long, jl.Long]): Unit = other match {
+ case o: LongAccumulator => _sum += o.sum
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ private[spark] def setValue(newValue: Long): Unit = _sum = newValue
+
+ override def localValue: jl.Long = _sum
+}
+
+
+class DoubleAccumulator extends NewAccumulator[jl.Double, jl.Double] {
+ private[this] var _sum = 0.0
+
+ override def isZero(): Boolean = _sum == 0.0
+
+ override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
+
+ override def add(v: jl.Double): Unit = _sum += v
+
+ def add(v: Double): Unit = _sum += v
+
+ def sum: Double = _sum
+
+ override def merge(other: NewAccumulator[jl.Double, jl.Double]): Unit = other match {
+ case o: DoubleAccumulator => _sum += o.sum
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ private[spark] def setValue(newValue: Double): Unit = _sum = newValue
+
+ override def localValue: jl.Double = _sum
+}
+
+
+class AverageAccumulator extends NewAccumulator[jl.Double, jl.Double] {
+ private[this] var _sum = 0.0
+ private[this] var _count = 0L
+
+ override def isZero(): Boolean = _sum == 0.0 && _count == 0
+
+ override def copyAndReset(): AverageAccumulator = new AverageAccumulator
+
+ override def add(v: jl.Double): Unit = {
+ _sum += v
+ _count += 1
+ }
+
+ def add(d: Double): Unit = {
+ _sum += d
+ _count += 1
+ }
+
+ override def merge(other: NewAccumulator[jl.Double, jl.Double]): Unit = other match {
+ case o: AverageAccumulator =>
+ _sum += o.sum
+ _count += o.count
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ override def localValue: jl.Double = if (_count == 0) {
+ Double.NaN
+ } else {
+ _sum / _count
+ }
+
+ def sum: Double = _sum
+
+ def count: Long = _count
+}
+
+
+class ListAccumulator[T] extends NewAccumulator[T, java.util.List[T]] {
+ private[this] val _list: java.util.List[T] = new java.util.ArrayList[T]
+
+ override def isZero(): Boolean = _list.isEmpty
+
+ override def copyAndReset(): ListAccumulator[T] = new ListAccumulator
+
+ override def add(v: T): Unit = _list.add(v)
+
+ override def merge(other: NewAccumulator[T, java.util.List[T]]): Unit = other match {
+ case o: ListAccumulator[T] => _list.addAll(o.localValue)
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ override def localValue: java.util.List[T] = java.util.Collections.unmodifiableList(_list)
+
+ private[spark] def setValue(newValue: java.util.List[T]): Unit = {
+ _list.clear()
+ _list.addAll(newValue)
+ }
+}
+
+
+class LegacyAccumulatorWrapper[R, T](
+ initialValue: R,
+ param: org.apache.spark.AccumulableParam[R, T]) extends NewAccumulator[T, R] {
+ private[spark] var _value = initialValue // Current value on driver
+
+ override def isZero(): Boolean = _value == param.zero(initialValue)
+
+ override def copyAndReset(): LegacyAccumulatorWrapper[R, T] = {
+ val acc = new LegacyAccumulatorWrapper(initialValue, param)
+ acc._value = param.zero(initialValue)
+ acc
+ }
+
+ override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
+
+ override def merge(other: NewAccumulator[T, R]): Unit = other match {
+ case o: LegacyAccumulatorWrapper[R, T] => _value = param.addInPlace(_value, o.localValue)
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+ }
+
+ override def localValue: R = _value
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index f322a770bf..865989aee0 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1217,10 +1217,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add"
* values to using the `+=` method. Only the driver can access the accumulator's `value`.
*/
- def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] =
- {
+ def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]): Accumulator[T] = {
val acc = new Accumulator(initialValue, param)
- cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1232,7 +1231,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T])
: Accumulator[T] = {
val acc = new Accumulator(initialValue, param, Some(name))
- cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1245,7 +1244,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T])
: Accumulable[R, T] = {
val acc = new Accumulable(initialValue, param)
- cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1259,7 +1258,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T])
: Accumulable[R, T] = {
val acc = new Accumulable(initialValue, param, Some(name))
- cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1273,7 +1272,101 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(initialValue: R): Accumulable[R, T] = {
val param = new GrowableAccumulableParam[R, T]
val acc = new Accumulable(initialValue, param)
- cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
+ acc
+ }
+
+ /**
+ * Register the given accumulator. Note that accumulators must be registered before use, or it
+ * will throw exception.
+ */
+ def register(acc: NewAccumulator[_, _]): Unit = {
+ acc.register(this)
+ }
+
+ /**
+ * Register the given accumulator with given name. Note that accumulators must be registered
+ * before use, or it will throw exception.
+ */
+ def register(acc: NewAccumulator[_, _], name: String): Unit = {
+ acc.register(this, name = Some(name))
+ }
+
+ /**
+ * Create and register a long accumulator, which starts with 0 and accumulates inputs by `+=`.
+ */
+ def longAccumulator: LongAccumulator = {
+ val acc = new LongAccumulator
+ register(acc)
+ acc
+ }
+
+ /**
+ * Create and register a long accumulator, which starts with 0 and accumulates inputs by `+=`.
+ */
+ def longAccumulator(name: String): LongAccumulator = {
+ val acc = new LongAccumulator
+ register(acc, name)
+ acc
+ }
+
+ /**
+ * Create and register a double accumulator, which starts with 0 and accumulates inputs by `+=`.
+ */
+ def doubleAccumulator: DoubleAccumulator = {
+ val acc = new DoubleAccumulator
+ register(acc)
+ acc
+ }
+
+ /**
+ * Create and register a double accumulator, which starts with 0 and accumulates inputs by `+=`.
+ */
+ def doubleAccumulator(name: String): DoubleAccumulator = {
+ val acc = new DoubleAccumulator
+ register(acc, name)
+ acc
+ }
+
+ /**
+ * Create and register an average accumulator, which accumulates double inputs by recording the
+ * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be
+ * returned if no input is added.
+ */
+ def averageAccumulator: AverageAccumulator = {
+ val acc = new AverageAccumulator
+ register(acc)
+ acc
+ }
+
+ /**
+ * Create and register an average accumulator, which accumulates double inputs by recording the
+ * total sum and total count, and produce the output by sum / total. Note that Double.NaN will be
+ * returned if no input is added.
+ */
+ def averageAccumulator(name: String): AverageAccumulator = {
+ val acc = new AverageAccumulator
+ register(acc, name)
+ acc
+ }
+
+ /**
+ * Create and register a list accumulator, which starts with empty list and accumulates inputs
+ * by adding them into the inner list.
+ */
+ def listAccumulator[T]: ListAccumulator[T] = {
+ val acc = new ListAccumulator[T]
+ register(acc)
+ acc
+ }
+
+ /**
+ * Create and register a list accumulator, which starts with empty list and accumulates inputs
+ * by adding them into the inner list.
+ */
+ def listAccumulator[T](name: String): ListAccumulator[T] = {
+ val acc = new ListAccumulator[T]
+ register(acc, name)
acc
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index e7940bd9ed..9e53257462 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -188,6 +188,6 @@ abstract class TaskContext extends Serializable {
* Register an accumulator that belongs to this task. Accumulators must call this method when
* deserializing in executors.
*/
- private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit
+ private[spark] def registerAccumulator(a: NewAccumulator[_, _]): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 43e555670d..bc3807f5db 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -122,7 +122,7 @@ private[spark] class TaskContextImpl(
override def getMetricsSources(sourceName: String): Seq[Source] =
metricsSystem.getSourcesByName(sourceName)
- private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = {
+ private[spark] override def registerAccumulator(a: NewAccumulator[_, _]): Unit = {
taskMetrics.registerAccumulator(a)
}
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 7487cfe9c5..82ba2d0c27 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -19,10 +19,7 @@ package org.apache.spark
import java.io.{ObjectInputStream, ObjectOutputStream}
-import scala.util.Try
-
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.storage.BlockManagerId
@@ -120,18 +117,10 @@ case class ExceptionFailure(
stackTrace: Array[StackTraceElement],
fullStackTrace: String,
private val exceptionWrapper: Option[ThrowableSerializationWrapper],
- accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo])
+ accumUpdates: Seq[AccumulableInfo] = Seq.empty,
+ private[spark] var accums: Seq[NewAccumulator[_, _]] = Nil)
extends TaskFailedReason {
- @deprecated("use accumUpdates instead", "2.0.0")
- val metrics: Option[TaskMetrics] = {
- if (accumUpdates.nonEmpty) {
- Try(TaskMetrics.fromAccumulatorUpdates(accumUpdates)).toOption
- } else {
- None
- }
- }
-
/**
* `preserveCause` is used to keep the exception itself so it is available to the
* driver. This may be set to `false` in the event that the exception is not in fact
@@ -149,6 +138,11 @@ case class ExceptionFailure(
this(e, accumUpdates, preserveCause = true)
}
+ private[spark] def withAccums(accums: Seq[NewAccumulator[_, _]]): ExceptionFailure = {
+ this.accums = accums
+ this
+ }
+
def exception: Option[Throwable] = exceptionWrapper.flatMap(w => Option(w.exception))
override def toErrorString: String =
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 650f05c309..4d61f7e232 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -353,22 +353,24 @@ private[spark] class Executor(
logError(s"Exception in $taskName (TID $taskId)", t)
// Collect latest accumulator values to report back to the driver
- val accumulatorUpdates: Seq[AccumulableInfo] =
+ val accums: Seq[NewAccumulator[_, _]] =
if (task != null) {
task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.collectAccumulatorUpdates(taskFailed = true)
} else {
- Seq.empty[AccumulableInfo]
+ Seq.empty
}
+ val accUpdates = accums.map(acc => acc.toInfo(Some(acc.localValue), None))
+
val serializedTaskEndReason = {
try {
- ser.serialize(new ExceptionFailure(t, accumulatorUpdates))
+ ser.serialize(new ExceptionFailure(t, accUpdates).withAccums(accums))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
- ser.serialize(new ExceptionFailure(t, accumulatorUpdates, preserveCause = false))
+ ser.serialize(new ExceptionFailure(t, accUpdates, false).withAccums(accums))
}
}
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
@@ -476,14 +478,14 @@ private[spark] class Executor(
/** Reports heartbeat and metrics for active tasks to the driver. */
private def reportHeartBeat(): Unit = {
// list of (task id, accumUpdates) to send back to the driver
- val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulableInfo])]()
+ val accumUpdates = new ArrayBuffer[(Long, Seq[NewAccumulator[_, _]])]()
val curGCTime = computeTotalGcTime()
for (taskRunner <- runningTasks.values().asScala) {
if (taskRunner.task != null) {
taskRunner.task.metrics.mergeShuffleReadMetrics()
taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
- accumUpdates += ((taskRunner.taskId, taskRunner.task.metrics.accumulatorUpdates()))
+ accumUpdates += ((taskRunner.taskId, taskRunner.task.metrics.accumulators()))
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
index 535352e7dd..6f7160ac0d 100644
--- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
import org.apache.spark.annotation.DeveloperApi
@@ -40,20 +40,18 @@ object DataReadMethod extends Enumeration with Serializable {
*/
@DeveloperApi
class InputMetrics private[spark] () extends Serializable {
- import InternalAccumulator._
-
- private[executor] val _bytesRead = TaskMetrics.createLongAccum(input.BYTES_READ)
- private[executor] val _recordsRead = TaskMetrics.createLongAccum(input.RECORDS_READ)
+ private[executor] val _bytesRead = new LongAccumulator
+ private[executor] val _recordsRead = new LongAccumulator
/**
* Total number of bytes read.
*/
- def bytesRead: Long = _bytesRead.localValue
+ def bytesRead: Long = _bytesRead.sum
/**
* Total number of records read.
*/
- def recordsRead: Long = _recordsRead.localValue
+ def recordsRead: Long = _recordsRead.sum
private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v)
private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v)
diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
index 586c98b156..db3924cb69 100644
--- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
import org.apache.spark.annotation.DeveloperApi
@@ -39,20 +39,18 @@ object DataWriteMethod extends Enumeration with Serializable {
*/
@DeveloperApi
class OutputMetrics private[spark] () extends Serializable {
- import InternalAccumulator._
-
- private[executor] val _bytesWritten = TaskMetrics.createLongAccum(output.BYTES_WRITTEN)
- private[executor] val _recordsWritten = TaskMetrics.createLongAccum(output.RECORDS_WRITTEN)
+ private[executor] val _bytesWritten = new LongAccumulator
+ private[executor] val _recordsWritten = new LongAccumulator
/**
* Total number of bytes written.
*/
- def bytesWritten: Long = _bytesWritten.localValue
+ def bytesWritten: Long = _bytesWritten.sum
/**
* Total number of records written.
*/
- def recordsWritten: Long = _recordsWritten.localValue
+ def recordsWritten: Long = _recordsWritten.sum
private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v)
private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v)
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
index f012a74db6..fa962108c3 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
import org.apache.spark.annotation.DeveloperApi
@@ -28,52 +28,44 @@ import org.apache.spark.annotation.DeveloperApi
*/
@DeveloperApi
class ShuffleReadMetrics private[spark] () extends Serializable {
- import InternalAccumulator._
-
- private[executor] val _remoteBlocksFetched =
- TaskMetrics.createIntAccum(shuffleRead.REMOTE_BLOCKS_FETCHED)
- private[executor] val _localBlocksFetched =
- TaskMetrics.createIntAccum(shuffleRead.LOCAL_BLOCKS_FETCHED)
- private[executor] val _remoteBytesRead =
- TaskMetrics.createLongAccum(shuffleRead.REMOTE_BYTES_READ)
- private[executor] val _localBytesRead =
- TaskMetrics.createLongAccum(shuffleRead.LOCAL_BYTES_READ)
- private[executor] val _fetchWaitTime =
- TaskMetrics.createLongAccum(shuffleRead.FETCH_WAIT_TIME)
- private[executor] val _recordsRead =
- TaskMetrics.createLongAccum(shuffleRead.RECORDS_READ)
+ private[executor] val _remoteBlocksFetched = new LongAccumulator
+ private[executor] val _localBlocksFetched = new LongAccumulator
+ private[executor] val _remoteBytesRead = new LongAccumulator
+ private[executor] val _localBytesRead = new LongAccumulator
+ private[executor] val _fetchWaitTime = new LongAccumulator
+ private[executor] val _recordsRead = new LongAccumulator
/**
* Number of remote blocks fetched in this shuffle by this task.
*/
- def remoteBlocksFetched: Int = _remoteBlocksFetched.localValue
+ def remoteBlocksFetched: Long = _remoteBlocksFetched.sum
/**
* Number of local blocks fetched in this shuffle by this task.
*/
- def localBlocksFetched: Int = _localBlocksFetched.localValue
+ def localBlocksFetched: Long = _localBlocksFetched.sum
/**
* Total number of remote bytes read from the shuffle by this task.
*/
- def remoteBytesRead: Long = _remoteBytesRead.localValue
+ def remoteBytesRead: Long = _remoteBytesRead.sum
/**
* Shuffle data that was read from the local disk (as opposed to from a remote executor).
*/
- def localBytesRead: Long = _localBytesRead.localValue
+ def localBytesRead: Long = _localBytesRead.sum
/**
* Time the task spent waiting for remote shuffle blocks. This only includes the time
* blocking on shuffle input data. For instance if block B is being fetched while the task is
* still not finished processing block A, it is not considered to be blocking on block B.
*/
- def fetchWaitTime: Long = _fetchWaitTime.localValue
+ def fetchWaitTime: Long = _fetchWaitTime.sum
/**
* Total number of records read from the shuffle by this task.
*/
- def recordsRead: Long = _recordsRead.localValue
+ def recordsRead: Long = _recordsRead.sum
/**
* Total bytes fetched in the shuffle by this task (both remote and local).
@@ -83,10 +75,10 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
/**
* Number of blocks fetched in this shuffle by this task (remote or local).
*/
- def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched
+ def totalBlocksFetched: Long = remoteBlocksFetched + localBlocksFetched
- private[spark] def incRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched.add(v)
- private[spark] def incLocalBlocksFetched(v: Int): Unit = _localBlocksFetched.add(v)
+ private[spark] def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched.add(v)
+ private[spark] def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched.add(v)
private[spark] def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead.add(v)
private[spark] def incLocalBytesRead(v: Long): Unit = _localBytesRead.add(v)
private[spark] def incFetchWaitTime(v: Long): Unit = _fetchWaitTime.add(v)
@@ -104,12 +96,12 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
* [[TempShuffleReadMetrics]] into `this`.
*/
private[spark] def setMergeValues(metrics: Seq[TempShuffleReadMetrics]): Unit = {
- _remoteBlocksFetched.setValue(_remoteBlocksFetched.zero)
- _localBlocksFetched.setValue(_localBlocksFetched.zero)
- _remoteBytesRead.setValue(_remoteBytesRead.zero)
- _localBytesRead.setValue(_localBytesRead.zero)
- _fetchWaitTime.setValue(_fetchWaitTime.zero)
- _recordsRead.setValue(_recordsRead.zero)
+ _remoteBlocksFetched.setValue(0)
+ _localBlocksFetched.setValue(0)
+ _remoteBytesRead.setValue(0)
+ _localBytesRead.setValue(0)
+ _fetchWaitTime.setValue(0)
+ _recordsRead.setValue(0)
metrics.foreach { metric =>
_remoteBlocksFetched.add(metric.remoteBlocksFetched)
_localBlocksFetched.add(metric.localBlocksFetched)
@@ -127,22 +119,22 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
* last.
*/
private[spark] class TempShuffleReadMetrics {
- private[this] var _remoteBlocksFetched = 0
- private[this] var _localBlocksFetched = 0
+ private[this] var _remoteBlocksFetched = 0L
+ private[this] var _localBlocksFetched = 0L
private[this] var _remoteBytesRead = 0L
private[this] var _localBytesRead = 0L
private[this] var _fetchWaitTime = 0L
private[this] var _recordsRead = 0L
- def incRemoteBlocksFetched(v: Int): Unit = _remoteBlocksFetched += v
- def incLocalBlocksFetched(v: Int): Unit = _localBlocksFetched += v
+ def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v
+ def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v
def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v
def incLocalBytesRead(v: Long): Unit = _localBytesRead += v
def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v
def incRecordsRead(v: Long): Unit = _recordsRead += v
- def remoteBlocksFetched: Int = _remoteBlocksFetched
- def localBlocksFetched: Int = _localBlocksFetched
+ def remoteBlocksFetched: Long = _remoteBlocksFetched
+ def localBlocksFetched: Long = _localBlocksFetched
def remoteBytesRead: Long = _remoteBytesRead
def localBytesRead: Long = _localBytesRead
def fetchWaitTime: Long = _fetchWaitTime
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
index 7326fba841..0e70a4f522 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import org.apache.spark.InternalAccumulator
+import org.apache.spark.LongAccumulator
import org.apache.spark.annotation.DeveloperApi
@@ -28,29 +28,24 @@ import org.apache.spark.annotation.DeveloperApi
*/
@DeveloperApi
class ShuffleWriteMetrics private[spark] () extends Serializable {
- import InternalAccumulator._
-
- private[executor] val _bytesWritten =
- TaskMetrics.createLongAccum(shuffleWrite.BYTES_WRITTEN)
- private[executor] val _recordsWritten =
- TaskMetrics.createLongAccum(shuffleWrite.RECORDS_WRITTEN)
- private[executor] val _writeTime =
- TaskMetrics.createLongAccum(shuffleWrite.WRITE_TIME)
+ private[executor] val _bytesWritten = new LongAccumulator
+ private[executor] val _recordsWritten = new LongAccumulator
+ private[executor] val _writeTime = new LongAccumulator
/**
* Number of bytes written for the shuffle by this task.
*/
- def bytesWritten: Long = _bytesWritten.localValue
+ def bytesWritten: Long = _bytesWritten.sum
/**
* Total number of records written to the shuffle by this task.
*/
- def recordsWritten: Long = _recordsWritten.localValue
+ def recordsWritten: Long = _recordsWritten.sum
/**
* Time the task spent blocking on writes to disk or buffer cache, in nanoseconds.
*/
- def writeTime: Long = _writeTime.localValue
+ def writeTime: Long = _writeTime.sum
private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 8513d053f2..0b64917219 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,10 +17,9 @@
package org.apache.spark.executor
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, LinkedHashMap}
import org.apache.spark._
-import org.apache.spark.AccumulatorParam.{IntAccumulatorParam, LongAccumulatorParam, UpdatedBlockStatusesAccumulatorParam}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.AccumulableInfo
@@ -42,53 +41,51 @@ import org.apache.spark.storage.{BlockId, BlockStatus}
*/
@DeveloperApi
class TaskMetrics private[spark] () extends Serializable {
- import InternalAccumulator._
-
// Each metric is internally represented as an accumulator
- private val _executorDeserializeTime = TaskMetrics.createLongAccum(EXECUTOR_DESERIALIZE_TIME)
- private val _executorRunTime = TaskMetrics.createLongAccum(EXECUTOR_RUN_TIME)
- private val _resultSize = TaskMetrics.createLongAccum(RESULT_SIZE)
- private val _jvmGCTime = TaskMetrics.createLongAccum(JVM_GC_TIME)
- private val _resultSerializationTime = TaskMetrics.createLongAccum(RESULT_SERIALIZATION_TIME)
- private val _memoryBytesSpilled = TaskMetrics.createLongAccum(MEMORY_BYTES_SPILLED)
- private val _diskBytesSpilled = TaskMetrics.createLongAccum(DISK_BYTES_SPILLED)
- private val _peakExecutionMemory = TaskMetrics.createLongAccum(PEAK_EXECUTION_MEMORY)
- private val _updatedBlockStatuses = TaskMetrics.createBlocksAccum(UPDATED_BLOCK_STATUSES)
+ private val _executorDeserializeTime = new LongAccumulator
+ private val _executorRunTime = new LongAccumulator
+ private val _resultSize = new LongAccumulator
+ private val _jvmGCTime = new LongAccumulator
+ private val _resultSerializationTime = new LongAccumulator
+ private val _memoryBytesSpilled = new LongAccumulator
+ private val _diskBytesSpilled = new LongAccumulator
+ private val _peakExecutionMemory = new LongAccumulator
+ private val _updatedBlockStatuses = new BlockStatusesAccumulator
/**
* Time taken on the executor to deserialize this task.
*/
- def executorDeserializeTime: Long = _executorDeserializeTime.localValue
+ def executorDeserializeTime: Long = _executorDeserializeTime.sum
/**
* Time the executor spends actually running the task (including fetching shuffle data).
*/
- def executorRunTime: Long = _executorRunTime.localValue
+ def executorRunTime: Long = _executorRunTime.sum
/**
* The number of bytes this task transmitted back to the driver as the TaskResult.
*/
- def resultSize: Long = _resultSize.localValue
+ def resultSize: Long = _resultSize.sum
/**
* Amount of time the JVM spent in garbage collection while executing this task.
*/
- def jvmGCTime: Long = _jvmGCTime.localValue
+ def jvmGCTime: Long = _jvmGCTime.sum
/**
* Amount of time spent serializing the task result.
*/
- def resultSerializationTime: Long = _resultSerializationTime.localValue
+ def resultSerializationTime: Long = _resultSerializationTime.sum
/**
* The number of in-memory bytes spilled by this task.
*/
- def memoryBytesSpilled: Long = _memoryBytesSpilled.localValue
+ def memoryBytesSpilled: Long = _memoryBytesSpilled.sum
/**
* The number of on-disk bytes spilled by this task.
*/
- def diskBytesSpilled: Long = _diskBytesSpilled.localValue
+ def diskBytesSpilled: Long = _diskBytesSpilled.sum
/**
* Peak memory used by internal data structures created during shuffles, aggregations and
@@ -96,7 +93,7 @@ class TaskMetrics private[spark] () extends Serializable {
* across all such data structures created in this task. For SQL jobs, this only tracks all
* unsafe operators and ExternalSort.
*/
- def peakExecutionMemory: Long = _peakExecutionMemory.localValue
+ def peakExecutionMemory: Long = _peakExecutionMemory.sum
/**
* Storage statuses of any blocks that have been updated as a result of this task.
@@ -114,7 +111,7 @@ class TaskMetrics private[spark] () extends Serializable {
private[spark] def incMemoryBytesSpilled(v: Long): Unit = _memoryBytesSpilled.add(v)
private[spark] def incDiskBytesSpilled(v: Long): Unit = _diskBytesSpilled.add(v)
private[spark] def incPeakExecutionMemory(v: Long): Unit = _peakExecutionMemory.add(v)
- private[spark] def incUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
+ private[spark] def incUpdatedBlockStatuses(v: (BlockId, BlockStatus)): Unit =
_updatedBlockStatuses.add(v)
private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit =
_updatedBlockStatuses.setValue(v)
@@ -175,124 +172,143 @@ class TaskMetrics private[spark] () extends Serializable {
}
// Only used for test
- private[spark] val testAccum =
- sys.props.get("spark.testing").map(_ => TaskMetrics.createLongAccum(TEST_ACCUM))
-
- @transient private[spark] lazy val internalAccums: Seq[Accumulable[_, _]] = {
- val in = inputMetrics
- val out = outputMetrics
- val sr = shuffleReadMetrics
- val sw = shuffleWriteMetrics
- Seq(_executorDeserializeTime, _executorRunTime, _resultSize, _jvmGCTime,
- _resultSerializationTime, _memoryBytesSpilled, _diskBytesSpilled, _peakExecutionMemory,
- _updatedBlockStatuses, sr._remoteBlocksFetched, sr._localBlocksFetched, sr._remoteBytesRead,
- sr._localBytesRead, sr._fetchWaitTime, sr._recordsRead, sw._bytesWritten, sw._recordsWritten,
- sw._writeTime, in._bytesRead, in._recordsRead, out._bytesWritten, out._recordsWritten) ++
- testAccum
- }
+ private[spark] val testAccum = sys.props.get("spark.testing").map(_ => new LongAccumulator)
+
+
+ import InternalAccumulator._
+ @transient private[spark] lazy val nameToAccums = LinkedHashMap(
+ EXECUTOR_DESERIALIZE_TIME -> _executorDeserializeTime,
+ EXECUTOR_RUN_TIME -> _executorRunTime,
+ RESULT_SIZE -> _resultSize,
+ JVM_GC_TIME -> _jvmGCTime,
+ RESULT_SERIALIZATION_TIME -> _resultSerializationTime,
+ MEMORY_BYTES_SPILLED -> _memoryBytesSpilled,
+ DISK_BYTES_SPILLED -> _diskBytesSpilled,
+ PEAK_EXECUTION_MEMORY -> _peakExecutionMemory,
+ UPDATED_BLOCK_STATUSES -> _updatedBlockStatuses,
+ shuffleRead.REMOTE_BLOCKS_FETCHED -> shuffleReadMetrics._remoteBlocksFetched,
+ shuffleRead.LOCAL_BLOCKS_FETCHED -> shuffleReadMetrics._localBlocksFetched,
+ shuffleRead.REMOTE_BYTES_READ -> shuffleReadMetrics._remoteBytesRead,
+ shuffleRead.LOCAL_BYTES_READ -> shuffleReadMetrics._localBytesRead,
+ shuffleRead.FETCH_WAIT_TIME -> shuffleReadMetrics._fetchWaitTime,
+ shuffleRead.RECORDS_READ -> shuffleReadMetrics._recordsRead,
+ shuffleWrite.BYTES_WRITTEN -> shuffleWriteMetrics._bytesWritten,
+ shuffleWrite.RECORDS_WRITTEN -> shuffleWriteMetrics._recordsWritten,
+ shuffleWrite.WRITE_TIME -> shuffleWriteMetrics._writeTime,
+ input.BYTES_READ -> inputMetrics._bytesRead,
+ input.RECORDS_READ -> inputMetrics._recordsRead,
+ output.BYTES_WRITTEN -> outputMetrics._bytesWritten,
+ output.RECORDS_WRITTEN -> outputMetrics._recordsWritten
+ ) ++ testAccum.map(TEST_ACCUM -> _)
+
+ @transient private[spark] lazy val internalAccums: Seq[NewAccumulator[_, _]] =
+ nameToAccums.values.toIndexedSeq
/* ========================== *
| OTHER THINGS |
* ========================== */
- private[spark] def registerForCleanup(sc: SparkContext): Unit = {
- internalAccums.foreach { accum =>
- sc.cleaner.foreach(_.registerAccumulatorForCleanup(accum))
+ private[spark] def register(sc: SparkContext): Unit = {
+ nameToAccums.foreach {
+ case (name, acc) => acc.register(sc, name = Some(name), countFailedValues = true)
}
}
/**
* External accumulators registered with this task.
*/
- @transient private lazy val externalAccums = new ArrayBuffer[Accumulable[_, _]]
+ @transient private lazy val externalAccums = new ArrayBuffer[NewAccumulator[_, _]]
- private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = {
+ private[spark] def registerAccumulator(a: NewAccumulator[_, _]): Unit = {
externalAccums += a
}
- /**
- * Return the latest updates of accumulators in this task.
- *
- * The [[AccumulableInfo.update]] field is always defined and the [[AccumulableInfo.value]]
- * field is always empty, since this represents the partial updates recorded in this task,
- * not the aggregated value across multiple tasks.
- */
- def accumulatorUpdates(): Seq[AccumulableInfo] = {
- (internalAccums ++ externalAccums).map { a => a.toInfo(Some(a.localValue), None) }
- }
+ private[spark] def accumulators(): Seq[NewAccumulator[_, _]] = internalAccums ++ externalAccums
}
-/**
- * Internal subclass of [[TaskMetrics]] which is used only for posting events to listeners.
- * Its purpose is to obviate the need for the driver to reconstruct the original accumulators,
- * which might have been garbage-collected. See SPARK-13407 for more details.
- *
- * Instances of this class should be considered read-only and users should not call `inc*()` or
- * `set*()` methods. While we could override the setter methods to throw
- * UnsupportedOperationException, we choose not to do so because the overrides would quickly become
- * out-of-date when new metrics are added.
- */
-private[spark] class ListenerTaskMetrics(accumUpdates: Seq[AccumulableInfo]) extends TaskMetrics {
-
- override def accumulatorUpdates(): Seq[AccumulableInfo] = accumUpdates
-
- override private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = {
- throw new UnsupportedOperationException("This TaskMetrics is read-only")
- }
-}
private[spark] object TaskMetrics extends Logging {
+ import InternalAccumulator._
/**
* Create an empty task metrics that doesn't register its accumulators.
*/
def empty: TaskMetrics = {
- val metrics = new TaskMetrics
- metrics.internalAccums.foreach(acc => Accumulators.remove(acc.id))
- metrics
+ val tm = new TaskMetrics
+ tm.nameToAccums.foreach { case (name, acc) =>
+ acc.metadata = AccumulatorMetadata(AccumulatorContext.newId(), Some(name), true)
+ }
+ tm
+ }
+
+ def registered: TaskMetrics = {
+ val tm = empty
+ tm.internalAccums.foreach(AccumulatorContext.register)
+ tm
}
/**
- * Create a new accumulator representing an internal task metric.
+ * Construct a [[TaskMetrics]] object from a list of [[AccumulableInfo]], called on driver only.
+ * The returned [[TaskMetrics]] is only used to get some internal metrics, we don't need to take
+ * care of external accumulator info passed in.
*/
- private def newMetric[T](
- initialValue: T,
- name: String,
- param: AccumulatorParam[T]): Accumulator[T] = {
- new Accumulator[T](initialValue, param, Some(name), countFailedValues = true)
+ def fromAccumulatorInfos(infos: Seq[AccumulableInfo]): TaskMetrics = {
+ val tm = new TaskMetrics
+ infos.filter(info => info.name.isDefined && info.update.isDefined).foreach { info =>
+ val name = info.name.get
+ val value = info.update.get
+ if (name == UPDATED_BLOCK_STATUSES) {
+ tm.setUpdatedBlockStatuses(value.asInstanceOf[Seq[(BlockId, BlockStatus)]])
+ } else {
+ tm.nameToAccums.get(name).foreach(
+ _.asInstanceOf[LongAccumulator].setValue(value.asInstanceOf[Long])
+ )
+ }
+ }
+ tm
}
- def createLongAccum(name: String): Accumulator[Long] = {
- newMetric(0L, name, LongAccumulatorParam)
- }
+ /**
+ * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only.
+ */
+ def fromAccumulators(accums: Seq[NewAccumulator[_, _]]): TaskMetrics = {
+ val tm = new TaskMetrics
+ val (internalAccums, externalAccums) =
+ accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get))
+
+ internalAccums.foreach { acc =>
+ val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[NewAccumulator[Any, Any]]
+ tmAcc.metadata = acc.metadata
+ tmAcc.merge(acc.asInstanceOf[NewAccumulator[Any, Any]])
+ }
- def createIntAccum(name: String): Accumulator[Int] = {
- newMetric(0, name, IntAccumulatorParam)
+ tm.externalAccums ++= externalAccums
+ tm
}
+}
+
+
+private[spark] class BlockStatusesAccumulator
+ extends NewAccumulator[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]] {
+ private[this] var _seq = ArrayBuffer.empty[(BlockId, BlockStatus)]
- def createBlocksAccum(name: String): Accumulator[Seq[(BlockId, BlockStatus)]] = {
- newMetric(Nil, name, UpdatedBlockStatusesAccumulatorParam)
+ override def isZero(): Boolean = _seq.isEmpty
+
+ override def copyAndReset(): BlockStatusesAccumulator = new BlockStatusesAccumulator
+
+ override def add(v: (BlockId, BlockStatus)): Unit = _seq += v
+
+ override def merge(other: NewAccumulator[(BlockId, BlockStatus), Seq[(BlockId, BlockStatus)]])
+ : Unit = other match {
+ case o: BlockStatusesAccumulator => _seq ++= o.localValue
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
- /**
- * Construct a [[TaskMetrics]] object from a list of accumulator updates, called on driver only.
- *
- * Executors only send accumulator updates back to the driver, not [[TaskMetrics]]. However, we
- * need the latter to post task end events to listeners, so we need to reconstruct the metrics
- * on the driver.
- *
- * This assumes the provided updates contain the initial set of accumulators representing
- * internal task level metrics.
- */
- def fromAccumulatorUpdates(accumUpdates: Seq[AccumulableInfo]): TaskMetrics = {
- val definedAccumUpdates = accumUpdates.filter(_.update.isDefined)
- val metrics = new ListenerTaskMetrics(definedAccumUpdates)
- // We don't register this [[ListenerTaskMetrics]] for cleanup, and this is only used to post
- // event, we should un-register all accumulators immediately.
- metrics.internalAccums.foreach(acc => Accumulators.remove(acc.id))
- definedAccumUpdates.filter(_.internal).foreach { accum =>
- metrics.internalAccums.find(_.name == accum.name).foreach(_.setValueAny(accum.update.get))
- }
- metrics
+ override def localValue: Seq[(BlockId, BlockStatus)] = _seq
+
+ def setValue(newValue: Seq[(BlockId, BlockStatus)]): Unit = {
+ _seq.clear()
+ _seq ++= newValue
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b7fb608ea5..a96d5f6fbf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -209,7 +209,7 @@ class DAGScheduler(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Seq[AccumulableInfo],
+ accumUpdates: Seq[NewAccumulator[_, _]],
taskInfo: TaskInfo): Unit = {
eventProcessLoop.post(
CompletionEvent(task, reason, result, accumUpdates, taskInfo))
@@ -1088,21 +1088,19 @@ class DAGScheduler(
val task = event.task
val stage = stageIdToStage(task.stageId)
try {
- event.accumUpdates.foreach { ainfo =>
- assert(ainfo.update.isDefined, "accumulator from task should have a partial value")
- val id = ainfo.id
- val partialValue = ainfo.update.get
+ event.accumUpdates.foreach { updates =>
+ val id = updates.id
// Find the corresponding accumulator on the driver and update it
- val acc: Accumulable[Any, Any] = Accumulators.get(id) match {
- case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
+ val acc: NewAccumulator[Any, Any] = AccumulatorContext.get(id) match {
+ case Some(accum) => accum.asInstanceOf[NewAccumulator[Any, Any]]
case None =>
throw new SparkException(s"attempted to access non-existent accumulator $id")
}
- acc ++= partialValue
+ acc.merge(updates.asInstanceOf[NewAccumulator[Any, Any]])
// To avoid UI cruft, ignore cases where value wasn't updated
- if (acc.name.isDefined && partialValue != acc.zero) {
+ if (acc.name.isDefined && !updates.isZero()) {
stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value))
- event.taskInfo.accumulables += acc.toInfo(Some(partialValue), Some(acc.value))
+ event.taskInfo.accumulables += acc.toInfo(Some(updates.value), Some(acc.value))
}
}
} catch {
@@ -1131,7 +1129,7 @@ class DAGScheduler(
val taskMetrics: TaskMetrics =
if (event.accumUpdates.nonEmpty) {
try {
- TaskMetrics.fromAccumulatorUpdates(event.accumUpdates)
+ TaskMetrics.fromAccumulators(event.accumUpdates)
} catch {
case NonFatal(e) =>
logError(s"Error when attempting to reconstruct metrics for task $taskId", e)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index a3845c6acd..e57a2246d8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -71,7 +71,7 @@ private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Seq[AccumulableInfo],
+ accumUpdates: Seq[NewAccumulator[_, _]],
taskInfo: TaskInfo)
extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 080ea6c33a..7618dfeeed 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -21,18 +21,15 @@ import java.util.Properties
import javax.annotation.Nullable
import scala.collection.Map
-import scala.collection.mutable
import com.fasterxml.jackson.annotation.JsonTypeInfo
import org.apache.spark.{SparkConf, TaskEndReason}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo}
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{Distribution, Utils}
@DeveloperApi
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 02185bf631..2f972b064b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -112,7 +112,7 @@ private[scheduler] abstract class Stage(
numPartitionsToCompute: Int,
taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = {
val metrics = new TaskMetrics
- metrics.registerForCleanup(rdd.sparkContext)
+ metrics.register(rdd.sparkContext)
_latestInfo = StageInfo.fromStage(
this, nextAttemptId, Some(numPartitionsToCompute), metrics, taskLocalityPreferences)
nextAttemptId += 1
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index eb10f3e69b..e7ca6efd84 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -23,7 +23,7 @@ import java.util.Properties
import scala.collection.mutable.HashMap
-import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl}
+import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
@@ -52,7 +52,7 @@ private[spark] abstract class Task[T](
val stageAttemptId: Int,
val partitionId: Int,
// The default value is only used in tests.
- val metrics: TaskMetrics = TaskMetrics.empty,
+ val metrics: TaskMetrics = TaskMetrics.registered,
@transient var localProperties: Properties = new Properties) extends Serializable {
/**
@@ -153,11 +153,11 @@ private[spark] abstract class Task[T](
* Collect the latest values of accumulators used in this task. If the task failed,
* filter out the accumulators whose values should not be included on failures.
*/
- def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulableInfo] = {
+ def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[NewAccumulator[_, _]] = {
if (context != null) {
- context.taskMetrics.accumulatorUpdates().filter { a => !taskFailed || a.countFailedValues }
+ context.taskMetrics.accumulators().filter { a => !taskFailed || a.countFailedValues }
} else {
- Seq.empty[AccumulableInfo]
+ Seq.empty
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 03135e63d7..b472c5511b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.SparkEnv
+import org.apache.spark.{NewAccumulator, SparkEnv}
import org.apache.spark.storage.BlockId
import org.apache.spark.util.Utils
@@ -36,7 +36,7 @@ private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark] class DirectTaskResult[T](
var valueBytes: ByteBuffer,
- var accumUpdates: Seq[AccumulableInfo])
+ var accumUpdates: Seq[NewAccumulator[_, _]])
extends TaskResult[T] with Externalizable {
private var valueObjectDeserialized = false
@@ -61,9 +61,9 @@ private[spark] class DirectTaskResult[T](
if (numUpdates == 0) {
accumUpdates = null
} else {
- val _accumUpdates = new ArrayBuffer[AccumulableInfo]
+ val _accumUpdates = new ArrayBuffer[NewAccumulator[_, _]]
for (i <- 0 until numUpdates) {
- _accumUpdates += in.readObject.asInstanceOf[AccumulableInfo]
+ _accumUpdates += in.readObject.asInstanceOf[NewAccumulator[_, _]]
}
accumUpdates = _accumUpdates
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index ae7ef46abb..b438c285fd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -93,9 +93,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
// we would have to serialize the result again after updating the size.
result.accumUpdates = result.accumUpdates.map { a =>
if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
- assert(a.update == Some(0L),
- "task result size should not have been set on the executors")
- a.copy(update = Some(size.toLong))
+ val acc = a.asInstanceOf[LongAccumulator]
+ assert(acc.sum == 0L, "task result size should not have been set on the executors")
+ acc.setValue(size.toLong)
+ acc
} else {
a
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 647d44a0f0..75a0c56311 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -17,6 +17,7 @@
package org.apache.spark.scheduler
+import org.apache.spark.NewAccumulator
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.BlockManagerId
@@ -66,7 +67,7 @@ private[spark] trait TaskScheduler {
*/
def executorHeartbeatReceived(
execId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
blockManagerId: BlockManagerId): Boolean
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index f31ec2af4e..776a3226cc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -384,13 +384,14 @@ private[spark] class TaskSchedulerImpl(
*/
override def executorHeartbeatReceived(
execId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
blockManagerId: BlockManagerId): Boolean = {
// (taskId, stageId, stageAttemptId, accumUpdates)
val accumUpdatesWithTaskIds: Array[(Long, Int, Int, Seq[AccumulableInfo])] = synchronized {
accumUpdates.flatMap { case (id, updates) =>
taskIdToTaskSetManager.get(id).map { taskSetMgr =>
- (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, updates)
+ (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId,
+ updates.map(acc => acc.toInfo(Some(acc.value), None)))
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 6e08cdd87a..b79f643a74 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -647,7 +647,7 @@ private[spark] class TaskSetManager(
info.markFailed()
val index = info.index
copiesRunning(index) -= 1
- var accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]
+ var accumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty
val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " +
reason.asInstanceOf[TaskFailedReason].toErrorString
val failureException: Option[Throwable] = reason match {
@@ -663,7 +663,7 @@ private[spark] class TaskSetManager(
case ef: ExceptionFailure =>
// ExceptionFailure's might have accumulator updates
- accumUpdates = ef.accumUpdates
+ accumUpdates = ef.accums
if (ef.className == classOf[NotSerializableException].getName) {
// If the task result wasn't serializable, there's no point in trying to re-execute it.
logError("Task %s in stage %s (TID %d) had a not serializable result: %s; not retrying"
@@ -788,7 +788,7 @@ private[spark] class TaskSetManager(
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.dagScheduler.taskEnded(
- tasks(index), Resubmitted, null, Seq.empty[AccumulableInfo], info)
+ tasks(index), Resubmitted, null, Seq.empty, info)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
index 8daca6c390..c04b483831 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -266,7 +266,13 @@ private[spark] object SerializationDebugger extends Logging {
(o, desc)
} else {
// write place
- findObjectAndDescriptor(desc.invokeWriteReplace(o))
+ val replaced = desc.invokeWriteReplace(o)
+ // `writeReplace` may return the same object.
+ if (replaced eq o) {
+ (o, desc)
+ } else {
+ findObjectAndDescriptor(replaced)
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
index ff28796a60..32e332a9ad 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
@@ -186,8 +186,8 @@ class OutputMetrics private[spark](
val recordsWritten: Long)
class ShuffleReadMetrics private[spark](
- val remoteBlocksFetched: Int,
- val localBlocksFetched: Int,
+ val remoteBlocksFetched: Long,
+ val localBlocksFetched: Long,
val fetchWaitTime: Long,
val remoteBytesRead: Long,
val localBytesRead: Long,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1c4921666f..f2d06c7ea8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -801,7 +801,7 @@ private[spark] class BlockManager(
reportBlockStatus(blockId, info, putBlockStatus)
}
Option(TaskContext.get()).foreach { c =>
- c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus)))
+ c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus)
}
}
logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
@@ -958,7 +958,7 @@ private[spark] class BlockManager(
reportBlockStatus(blockId, info, putBlockStatus)
}
Option(TaskContext.get()).foreach { c =>
- c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, putBlockStatus)))
+ c.taskMetrics().incUpdatedBlockStatuses(blockId -> putBlockStatus)
}
logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
if (level.replication > 1) {
@@ -1257,7 +1257,7 @@ private[spark] class BlockManager(
}
if (blockIsUpdated) {
Option(TaskContext.get()).foreach { c =>
- c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, status)))
+ c.taskMetrics().incUpdatedBlockStatuses(blockId -> status)
}
}
status.storageLevel
@@ -1311,7 +1311,7 @@ private[spark] class BlockManager(
reportBlockStatus(blockId, info, removeBlockStatus)
}
Option(TaskContext.get()).foreach { c =>
- c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, removeBlockStatus)))
+ c.taskMetrics().incUpdatedBlockStatuses(blockId -> removeBlockStatus)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index 9ab7d96e29..945830c8bf 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -375,26 +375,21 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
execSummary.taskTime += info.duration
stageData.numActiveTasks -= 1
- val (errorMessage, accums): (Option[String], Seq[AccumulableInfo]) =
+ val errorMessage: Option[String] =
taskEnd.reason match {
case org.apache.spark.Success =>
stageData.completedIndices.add(info.index)
stageData.numCompleteTasks += 1
- (None, taskEnd.taskMetrics.accumulatorUpdates())
+ None
case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates
stageData.numFailedTasks += 1
- (Some(e.toErrorString), e.accumUpdates)
+ Some(e.toErrorString)
case e: TaskFailedReason => // All other failure cases
stageData.numFailedTasks += 1
- (Some(e.toErrorString), Seq.empty[AccumulableInfo])
+ Some(e.toErrorString)
}
- val taskMetrics =
- if (accums.nonEmpty) {
- Some(TaskMetrics.fromAccumulatorUpdates(accums))
- } else {
- None
- }
+ val taskMetrics = Option(taskEnd.taskMetrics)
taskMetrics.foreach { m =>
val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.metrics)
updateAggregateMetrics(stageData, info.executorId, m, oldMetrics)
@@ -503,7 +498,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
new StageUIData
})
val taskData = stageData.taskData.get(taskId)
- val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates)
+ val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates)
taskData.foreach { t =>
if (!t.taskInfo.finished) {
updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics)
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index a613fbc5cc..aeab71d9df 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -840,7 +840,9 @@ private[spark] object JsonProtocol {
// Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x
val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates")
.map(_.extract[List[JValue]].map(accumulableInfoFromJson))
- .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulatorUpdates())
+ .getOrElse(taskMetricsFromJson(json \ "Metrics").accumulators().map(acc => {
+ acc.toInfo(Some(acc.localValue), None)
+ }))
ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates)
case `taskResultLost` => TaskResultLost
case `taskKilled` => TaskKilled
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 6063476936..5f97e58845 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -28,17 +28,17 @@ import scala.util.control.NonFatal
import org.scalatest.Matchers
import org.scalatest.exceptions.TestFailedException
-import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.AccumulatorParam.{ListAccumulatorParam, StringAccumulatorParam}
import org.apache.spark.scheduler._
import org.apache.spark.serializer.JavaSerializer
class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContext {
- import AccumulatorParam._
+ import AccumulatorSuite.createLongAccum
override def afterEach(): Unit = {
try {
- Accumulators.clear()
+ AccumulatorContext.clear()
} finally {
super.afterEach()
}
@@ -59,9 +59,30 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
}
}
+ test("accumulator serialization") {
+ val ser = new JavaSerializer(new SparkConf).newInstance()
+ val acc = createLongAccum("x")
+ acc.add(5)
+ assert(acc.value == 5)
+ assert(acc.isAtDriverSide)
+
+ // serialize and de-serialize it, to simulate sending accumulator to executor.
+ val acc2 = ser.deserialize[LongAccumulator](ser.serialize(acc))
+ // value is reset on the executors
+ assert(acc2.localValue == 0)
+ assert(!acc2.isAtDriverSide)
+
+ acc2.add(10)
+ // serialize and de-serialize it again, to simulate sending accumulator back to driver.
+ val acc3 = ser.deserialize[LongAccumulator](ser.serialize(acc2))
+ // value is not reset on the driver
+ assert(acc3.value == 10)
+ assert(acc3.isAtDriverSide)
+ }
+
test ("basic accumulation") {
sc = new SparkContext("local", "test")
- val acc : Accumulator[Int] = sc.accumulator(0)
+ val acc: Accumulator[Int] = sc.accumulator(0)
val d = sc.parallelize(1 to 20)
d.foreach{x => acc += x}
@@ -75,7 +96,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
test("value not assignable from tasks") {
sc = new SparkContext("local", "test")
- val acc : Accumulator[Int] = sc.accumulator(0)
+ val acc: Accumulator[Int] = sc.accumulator(0)
val d = sc.parallelize(1 to 20)
an [Exception] should be thrownBy {d.foreach{x => acc.value = x}}
@@ -169,14 +190,13 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
System.gc()
assert(ref.get.isEmpty)
- Accumulators.remove(accId)
- assert(!Accumulators.originals.get(accId).isDefined)
+ AccumulatorContext.remove(accId)
+ assert(!AccumulatorContext.originals.containsKey(accId))
}
test("get accum") {
- sc = new SparkContext("local", "test")
// Don't register with SparkContext for cleanup
- var acc = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true)
+ var acc = createLongAccum("a")
val accId = acc.id
val ref = WeakReference(acc)
assert(ref.get.isDefined)
@@ -188,44 +208,16 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
// Getting a garbage collected accum should throw error
intercept[IllegalAccessError] {
- Accumulators.get(accId)
+ AccumulatorContext.get(accId)
}
// Getting a normal accumulator. Note: this has to be separate because referencing an
// accumulator above in an `assert` would keep it from being garbage collected.
- val acc2 = new Accumulable[Long, Long](0L, LongAccumulatorParam, None, true)
- assert(Accumulators.get(acc2.id) === Some(acc2))
+ val acc2 = createLongAccum("b")
+ assert(AccumulatorContext.get(acc2.id) === Some(acc2))
// Getting an accumulator that does not exist should return None
- assert(Accumulators.get(100000).isEmpty)
- }
-
- test("copy") {
- val acc1 = new Accumulable[Long, Long](456L, LongAccumulatorParam, Some("x"), false)
- val acc2 = acc1.copy()
- assert(acc1.id === acc2.id)
- assert(acc1.value === acc2.value)
- assert(acc1.name === acc2.name)
- assert(acc1.countFailedValues === acc2.countFailedValues)
- assert(acc1 !== acc2)
- // Modifying one does not affect the other
- acc1.add(44L)
- assert(acc1.value === 500L)
- assert(acc2.value === 456L)
- acc2.add(144L)
- assert(acc1.value === 500L)
- assert(acc2.value === 600L)
- }
-
- test("register multiple accums with same ID") {
- val acc1 = new Accumulable[Int, Int](0, IntAccumulatorParam, None, true)
- // `copy` will create a new Accumulable and register it.
- val acc2 = acc1.copy()
- assert(acc1 !== acc2)
- assert(acc1.id === acc2.id)
- // The second one does not override the first one
- assert(Accumulators.originals.size === 1)
- assert(Accumulators.get(acc1.id) === Some(acc1))
+ assert(AccumulatorContext.get(100000).isEmpty)
}
test("string accumulator param") {
@@ -257,38 +249,33 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
acc.setValue(Seq(9, 10))
assert(acc.value === Seq(9, 10))
}
-
- test("value is reset on the executors") {
- val acc1 = new Accumulator(0, IntAccumulatorParam, Some("thing"))
- val acc2 = new Accumulator(0L, LongAccumulatorParam, Some("thing2"))
- val externalAccums = Seq(acc1, acc2)
- val taskMetrics = new TaskMetrics
- // Set some values; these should not be observed later on the "executors"
- acc1.setValue(10)
- acc2.setValue(20L)
- taskMetrics.testAccum.get.setValue(30L)
- // Simulate the task being serialized and sent to the executors.
- val dummyTask = new DummyTask(taskMetrics, externalAccums)
- val serInstance = new JavaSerializer(new SparkConf).newInstance()
- val taskSer = Task.serializeWithDependencies(
- dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance)
- // Now we're on the executors.
- // Deserialize the task and assert that its accumulators are zero'ed out.
- val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer)
- val taskDeser = serInstance.deserialize[DummyTask](
- taskBytes, Thread.currentThread.getContextClassLoader)
- // Assert that executors see only zeros
- taskDeser.externalAccums.foreach { a => assert(a.localValue == a.zero) }
- taskDeser.metrics.internalAccums.foreach { a => assert(a.localValue == a.zero) }
- }
-
}
private[spark] object AccumulatorSuite {
-
import InternalAccumulator._
/**
+ * Create a long accumulator and register it to [[AccumulatorContext]].
+ */
+ def createLongAccum(
+ name: String,
+ countFailedValues: Boolean = false,
+ initValue: Long = 0,
+ id: Long = AccumulatorContext.newId()): LongAccumulator = {
+ val acc = new LongAccumulator
+ acc.setValue(initValue)
+ acc.metadata = AccumulatorMetadata(id, Some(name), countFailedValues)
+ AccumulatorContext.register(acc)
+ acc
+ }
+
+ /**
+ * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the
+ * info as an accumulator update.
+ */
+ def makeInfo(a: NewAccumulator[_, _]): AccumulableInfo = a.toInfo(Some(a.localValue), None)
+
+ /**
* Run one or more Spark jobs and verify that in at least one job the peak execution memory
* accumulator is updated afterwards.
*/
@@ -340,7 +327,6 @@ private class SaveInfoListener extends SparkListener {
if (jobCompletionCallback != null) {
jobCompletionSem.acquire()
if (exception != null) {
- exception = null
throw exception
}
}
@@ -377,13 +363,3 @@ private class SaveInfoListener extends SparkListener {
(taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo
}
}
-
-
-/**
- * A dummy [[Task]] that contains internal and external [[Accumulator]]s.
- */
-private[spark] class DummyTask(
- metrics: TaskMetrics,
- val externalAccums: Seq[Accumulator[_]]) extends Task[Int](0, 0, 0, metrics) {
- override def runTask(c: TaskContext): Int = 1
-}
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index 4d2b3e7f3b..1adc90ab1e 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -211,10 +211,10 @@ class HeartbeatReceiverSuite
private def triggerHeartbeat(
executorId: String,
executorShouldReregister: Boolean): Unit = {
- val metrics = new TaskMetrics
+ val metrics = TaskMetrics.empty
val blockManagerId = BlockManagerId(executorId, "localhost", 12345)
val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
- Heartbeat(executorId, Array(1L -> metrics.accumulatorUpdates()), blockManagerId))
+ Heartbeat(executorId, Array(1L -> metrics.accumulators()), blockManagerId))
if (executorShouldReregister) {
assert(response.reregisterBlockManager)
} else {
@@ -222,7 +222,7 @@ class HeartbeatReceiverSuite
// Additionally verify that the scheduler callback is called with the correct parameters
verify(scheduler).executorHeartbeatReceived(
Matchers.eq(executorId),
- Matchers.eq(Array(1L -> metrics.accumulatorUpdates())),
+ Matchers.eq(Array(1L -> metrics.accumulators())),
Matchers.eq(blockManagerId))
}
}
diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
index b074b95424..e4474bb813 100644
--- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.executor.TaskMetrics
@@ -29,7 +30,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
override def afterEach(): Unit = {
try {
- Accumulators.clear()
+ AccumulatorContext.clear()
} finally {
super.afterEach()
}
@@ -37,9 +38,8 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
test("internal accumulators in TaskContext") {
val taskContext = TaskContext.empty()
- val accumUpdates = taskContext.taskMetrics.accumulatorUpdates()
+ val accumUpdates = taskContext.taskMetrics.accumulators()
assert(accumUpdates.size > 0)
- assert(accumUpdates.forall(_.internal))
val testAccum = taskContext.taskMetrics.testAccum.get
assert(accumUpdates.exists(_.id == testAccum.id))
}
@@ -51,7 +51,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
sc.addSparkListener(listener)
// Have each task add 1 to the internal accumulator
val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter =>
- TaskContext.get().taskMetrics().testAccum.get += 1
+ TaskContext.get().taskMetrics().testAccum.get.add(1)
iter
}
// Register asserts in job completion callback to avoid flakiness
@@ -87,17 +87,17 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
val rdd = sc.parallelize(1 to 100, numPartitions)
.map { i => (i, i) }
.mapPartitions { iter =>
- TaskContext.get().taskMetrics().testAccum.get += 1
+ TaskContext.get().taskMetrics().testAccum.get.add(1)
iter
}
.reduceByKey { case (x, y) => x + y }
.mapPartitions { iter =>
- TaskContext.get().taskMetrics().testAccum.get += 10
+ TaskContext.get().taskMetrics().testAccum.get.add(10)
iter
}
.repartition(numPartitions * 2)
.mapPartitions { iter =>
- TaskContext.get().taskMetrics().testAccum.get += 100
+ TaskContext.get().taskMetrics().testAccum.get.add(100)
iter
}
// Register asserts in job completion callback to avoid flakiness
@@ -127,7 +127,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
// This should retry both stages in the scheduler. Note that we only want to fail the
// first stage attempt because we want the stage to eventually succeed.
val x = sc.parallelize(1 to 100, numPartitions)
- .mapPartitions { iter => TaskContext.get().taskMetrics().testAccum.get += 1; iter }
+ .mapPartitions { iter => TaskContext.get().taskMetrics().testAccum.get.add(1); iter }
.groupBy(identity)
val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId
val rdd = x.mapPartitionsWithIndex { case (i, iter) =>
@@ -183,18 +183,18 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
private val myCleaner = new SaveAccumContextCleaner(this)
override def cleaner: Option[ContextCleaner] = Some(myCleaner)
}
- assert(Accumulators.originals.isEmpty)
+ assert(AccumulatorContext.originals.isEmpty)
sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count()
val numInternalAccums = TaskMetrics.empty.internalAccums.length
// We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage
- assert(Accumulators.originals.size === numInternalAccums * 2)
+ assert(AccumulatorContext.originals.size === numInternalAccums * 2)
val accumsRegistered = sc.cleaner match {
case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup
case _ => Seq.empty[Long]
}
// Make sure the same set of accumulators is registered for cleanup
assert(accumsRegistered.size === numInternalAccums * 2)
- assert(accumsRegistered.toSet === Accumulators.originals.keys.toSet)
+ assert(accumsRegistered.toSet === AccumulatorContext.originals.keySet().asScala)
}
/**
@@ -212,7 +212,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
private class SaveAccumContextCleaner(sc: SparkContext) extends ContextCleaner(sc) {
private val accumsRegistered = new ArrayBuffer[Long]
- override def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+ override def registerAccumulatorForCleanup(a: NewAccumulator[_, _]): Unit = {
accumsRegistered += a.id
super.registerAccumulatorForCleanup(a)
}
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index 3228752b96..4aae2c9b4a 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -34,7 +34,7 @@ private[spark] abstract class SparkFunSuite
protected override def afterAll(): Unit = {
try {
// Avoid leaking map entries in tests that use accumulators without SparkContext
- Accumulators.clear()
+ AccumulatorContext.clear()
} finally {
super.afterAll()
}
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
index ee70419727..94f6e1a3a7 100644
--- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -20,14 +20,11 @@ package org.apache.spark.executor
import org.scalatest.Assertions
import org.apache.spark._
-import org.apache.spark.scheduler.AccumulableInfo
-import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel, TestBlockId}
+import org.apache.spark.storage.{BlockStatus, StorageLevel, TestBlockId}
class TaskMetricsSuite extends SparkFunSuite {
- import AccumulatorParam._
import StorageLevel._
- import TaskMetricsSuite._
test("mutating values") {
val tm = new TaskMetrics
@@ -59,8 +56,8 @@ class TaskMetricsSuite extends SparkFunSuite {
tm.incPeakExecutionMemory(8L)
val block1 = (TestBlockId("a"), BlockStatus(MEMORY_ONLY, 1L, 2L))
val block2 = (TestBlockId("b"), BlockStatus(MEMORY_ONLY, 3L, 4L))
- tm.incUpdatedBlockStatuses(Seq(block1))
- tm.incUpdatedBlockStatuses(Seq(block2))
+ tm.incUpdatedBlockStatuses(block1)
+ tm.incUpdatedBlockStatuses(block2)
// assert new values exist
assert(tm.executorDeserializeTime == 1L)
assert(tm.executorRunTime == 2L)
@@ -194,18 +191,19 @@ class TaskMetricsSuite extends SparkFunSuite {
}
test("additional accumulables") {
- val tm = new TaskMetrics
- val acc1 = new Accumulator(0, IntAccumulatorParam, Some("a"))
- val acc2 = new Accumulator(0, IntAccumulatorParam, Some("b"))
- val acc3 = new Accumulator(0, IntAccumulatorParam, Some("c"))
- val acc4 = new Accumulator(0, IntAccumulatorParam, Some("d"), countFailedValues = true)
+ val tm = TaskMetrics.empty
+ val acc1 = AccumulatorSuite.createLongAccum("a")
+ val acc2 = AccumulatorSuite.createLongAccum("b")
+ val acc3 = AccumulatorSuite.createLongAccum("c")
+ val acc4 = AccumulatorSuite.createLongAccum("d", true)
tm.registerAccumulator(acc1)
tm.registerAccumulator(acc2)
tm.registerAccumulator(acc3)
tm.registerAccumulator(acc4)
- acc1 += 1
- acc2 += 2
- val newUpdates = tm.accumulatorUpdates().map { a => (a.id, a) }.toMap
+ acc1.add(1)
+ acc2.add(2)
+ val newUpdates = tm.accumulators()
+ .map(a => (a.id, a.asInstanceOf[NewAccumulator[Any, Any]])).toMap
assert(newUpdates.contains(acc1.id))
assert(newUpdates.contains(acc2.id))
assert(newUpdates.contains(acc3.id))
@@ -214,46 +212,14 @@ class TaskMetricsSuite extends SparkFunSuite {
assert(newUpdates(acc2.id).name === Some("b"))
assert(newUpdates(acc3.id).name === Some("c"))
assert(newUpdates(acc4.id).name === Some("d"))
- assert(newUpdates(acc1.id).update === Some(1))
- assert(newUpdates(acc2.id).update === Some(2))
- assert(newUpdates(acc3.id).update === Some(0))
- assert(newUpdates(acc4.id).update === Some(0))
+ assert(newUpdates(acc1.id).value === 1)
+ assert(newUpdates(acc2.id).value === 2)
+ assert(newUpdates(acc3.id).value === 0)
+ assert(newUpdates(acc4.id).value === 0)
assert(!newUpdates(acc3.id).countFailedValues)
assert(newUpdates(acc4.id).countFailedValues)
- assert(newUpdates.values.map(_.update).forall(_.isDefined))
- assert(newUpdates.values.map(_.value).forall(_.isEmpty))
assert(newUpdates.size === tm.internalAccums.size + 4)
}
-
- test("from accumulator updates") {
- val accumUpdates1 = TaskMetrics.empty.internalAccums.map { a =>
- AccumulableInfo(a.id, a.name, Some(3L), None, true, a.countFailedValues)
- }
- val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1)
- assertUpdatesEquals(metrics1.accumulatorUpdates(), accumUpdates1)
- // Test this with additional accumulators to ensure that we do not crash when handling
- // updates from unregistered accumulators. In practice, all accumulators created
- // on the driver, internal or not, should be registered with `Accumulators` at some point.
- val param = IntAccumulatorParam
- val registeredAccums = Seq(
- new Accumulator(0, param, Some("a"), countFailedValues = true),
- new Accumulator(0, param, Some("b"), countFailedValues = false))
- val unregisteredAccums = Seq(
- new Accumulator(0, param, Some("c"), countFailedValues = true),
- new Accumulator(0, param, Some("d"), countFailedValues = false))
- registeredAccums.foreach(Accumulators.register)
- registeredAccums.foreach(a => assert(Accumulators.originals.contains(a.id)))
- unregisteredAccums.foreach(a => Accumulators.remove(a.id))
- unregisteredAccums.foreach(a => assert(!Accumulators.originals.contains(a.id)))
- // set some values in these accums
- registeredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) }
- unregisteredAccums.zipWithIndex.foreach { case (a, i) => a.setValue(i) }
- val registeredAccumInfos = registeredAccums.map(makeInfo)
- val unregisteredAccumInfos = unregisteredAccums.map(makeInfo)
- val accumUpdates2 = accumUpdates1 ++ registeredAccumInfos ++ unregisteredAccumInfos
- // Simply checking that this does not crash:
- TaskMetrics.fromAccumulatorUpdates(accumUpdates2)
- }
}
@@ -264,21 +230,14 @@ private[spark] object TaskMetricsSuite extends Assertions {
* Note: this does NOT check accumulator ID equality.
*/
def assertUpdatesEquals(
- updates1: Seq[AccumulableInfo],
- updates2: Seq[AccumulableInfo]): Unit = {
+ updates1: Seq[NewAccumulator[_, _]],
+ updates2: Seq[NewAccumulator[_, _]]): Unit = {
assert(updates1.size === updates2.size)
- updates1.zip(updates2).foreach { case (info1, info2) =>
+ updates1.zip(updates2).foreach { case (acc1, acc2) =>
// do not assert ID equals here
- assert(info1.name === info2.name)
- assert(info1.update === info2.update)
- assert(info1.value === info2.value)
- assert(info1.countFailedValues === info2.countFailedValues)
+ assert(acc1.name === acc2.name)
+ assert(acc1.countFailedValues === acc2.countFailedValues)
+ assert(acc1.value == acc2.value)
}
}
-
- /**
- * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the
- * info as an accumulator update.
- */
- def makeInfo(a: Accumulable[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index b76c0a4bd1..9912d1f3bc 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -112,7 +112,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
override def stop() = {}
override def executorHeartbeatReceived(
execId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
blockManagerId: BlockManagerId): Boolean = true
override def submitTasks(taskSet: TaskSet) = {
// normally done by TaskSetManager
@@ -277,8 +277,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
taskSet.tasks(i),
result._1,
result._2,
- Seq(new AccumulableInfo(
- accumId, Some(""), Some(1), None, internal = false, countFailedValues = false))))
+ Seq(AccumulatorSuite.createLongAccum("", initValue = 1, id = accumId))))
}
}
}
@@ -484,7 +483,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
override def defaultParallelism(): Int = 2
override def executorHeartbeatReceived(
execId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
blockManagerId: BlockManagerId): Boolean = true
override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
override def applicationAttemptId(): Option[String] = None
@@ -997,10 +996,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
// complete two tasks
runEvent(makeCompletionEvent(
taskSets(0).tasks(0), Success, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(0)))
+ Seq.empty, createFakeTaskInfoWithId(0)))
runEvent(makeCompletionEvent(
taskSets(0).tasks(1), Success, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(1)))
+ Seq.empty, createFakeTaskInfoWithId(1)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
// verify stage exists
assert(scheduler.stageIdToStage.contains(0))
@@ -1009,10 +1008,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
// finish other 2 tasks
runEvent(makeCompletionEvent(
taskSets(0).tasks(2), Success, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(2)))
+ Seq.empty, createFakeTaskInfoWithId(2)))
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), Success, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(3)))
+ Seq.empty, createFakeTaskInfoWithId(3)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(sparkListener.endedTasks.size == 4)
@@ -1023,14 +1022,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
// with a speculative task and make sure the event is sent out
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), Success, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(5)))
+ Seq.empty, createFakeTaskInfoWithId(5)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(sparkListener.endedTasks.size == 5)
// make sure non successful tasks also send out event
runEvent(makeCompletionEvent(
taskSets(0).tasks(3), UnknownReason, 42,
- Seq.empty[AccumulableInfo], createFakeTaskInfoWithId(6)))
+ Seq.empty, createFakeTaskInfoWithId(6)))
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
assert(sparkListener.endedTasks.size == 6)
}
@@ -1613,37 +1612,43 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
test("accumulator not calculated for resubmitted result stage") {
// just for register
- val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)
+ val accum = AccumulatorSuite.createLongAccum("a")
val finalRdd = new MyRDD(sc, 1, Nil)
submit(finalRdd, Array(0))
completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42)))
completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42)))
assert(results === Map(0 -> 42))
- val accVal = Accumulators.originals(accum.id).get.get.value
-
- assert(accVal === 1)
-
+ assert(accum.value === 1)
assertDataStructuresEmpty()
}
test("accumulators are updated on exception failures") {
- val acc1 = sc.accumulator(0L, "ingenieur")
- val acc2 = sc.accumulator(0L, "boulanger")
- val acc3 = sc.accumulator(0L, "agriculteur")
- assert(Accumulators.get(acc1.id).isDefined)
- assert(Accumulators.get(acc2.id).isDefined)
- assert(Accumulators.get(acc3.id).isDefined)
- val accInfo1 = acc1.toInfo(Some(15L), None)
- val accInfo2 = acc2.toInfo(Some(13L), None)
- val accInfo3 = acc3.toInfo(Some(18L), None)
- val accumUpdates = Seq(accInfo1, accInfo2, accInfo3)
- val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates)
+ val acc1 = AccumulatorSuite.createLongAccum("ingenieur")
+ val acc2 = AccumulatorSuite.createLongAccum("boulanger")
+ val acc3 = AccumulatorSuite.createLongAccum("agriculteur")
+ assert(AccumulatorContext.get(acc1.id).isDefined)
+ assert(AccumulatorContext.get(acc2.id).isDefined)
+ assert(AccumulatorContext.get(acc3.id).isDefined)
+ val accUpdate1 = new LongAccumulator
+ accUpdate1.metadata = acc1.metadata
+ accUpdate1.setValue(15)
+ val accUpdate2 = new LongAccumulator
+ accUpdate2.metadata = acc2.metadata
+ accUpdate2.setValue(13)
+ val accUpdate3 = new LongAccumulator
+ accUpdate3.metadata = acc3.metadata
+ accUpdate3.setValue(18)
+ val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3)
+ val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo)
+ val exceptionFailure = new ExceptionFailure(
+ new SparkException("fondue?"),
+ accumInfo).copy(accums = accumUpdates)
submit(new MyRDD(sc, 1, Nil), Array(0))
runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result"))
- assert(Accumulators.get(acc1.id).get.value === 15L)
- assert(Accumulators.get(acc2.id).get.value === 13L)
- assert(Accumulators.get(acc3.id).get.value === 18L)
+ assert(AccumulatorContext.get(acc1.id).get.value === 15L)
+ assert(AccumulatorContext.get(acc2.id).get.value === 13L)
+ assert(AccumulatorContext.get(acc3.id).get.value === 18L)
}
test("reduce tasks should be placed locally with map output") {
@@ -2007,12 +2012,12 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
task: Task[_],
reason: TaskEndReason,
result: Any,
- extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo],
+ extraAccumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty,
taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = {
val accumUpdates = reason match {
- case Success => task.metrics.accumulatorUpdates()
- case ef: ExceptionFailure => ef.accumUpdates
- case _ => Seq.empty[AccumulableInfo]
+ case Success => task.metrics.accumulators()
+ case ef: ExceptionFailure => ef.accums
+ case _ => Seq.empty
}
CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index 9971d48a52..16027d944f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -17,12 +17,11 @@
package org.apache.spark.scheduler
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{LocalSparkContext, NewAccumulator, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.BlockManagerId
-class ExternalClusterManagerSuite extends SparkFunSuite with LocalSparkContext
-{
+class ExternalClusterManagerSuite extends SparkFunSuite with LocalSparkContext {
test("launch of backend and scheduler") {
val conf = new SparkConf().setMaster("myclusterManager").
setAppName("testcm").set("spark.driver.allowMultipleContexts", "true")
@@ -68,6 +67,6 @@ private class DummyTaskScheduler extends TaskScheduler {
override def applicationAttemptId(): Option[String] = None
def executorHeartbeatReceived(
execId: String,
- accumUpdates: Array[(Long, Seq[AccumulableInfo])],
+ accumUpdates: Array[(Long, Seq[NewAccumulator[_, _]])],
blockManagerId: BlockManagerId): Boolean = true
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index d55f6f60ec..9aca4dbc23 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -162,18 +162,17 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}.count()
// The one that counts failed values should be 4x the one that didn't,
// since we ran each task 4 times
- assert(Accumulators.get(acc1.id).get.value === 40L)
- assert(Accumulators.get(acc2.id).get.value === 10L)
+ assert(AccumulatorContext.get(acc1.id).get.value === 40L)
+ assert(AccumulatorContext.get(acc2.id).get.value === 10L)
}
test("failed tasks collect only accumulators whose values count during failures") {
sc = new SparkContext("local", "test")
- val param = AccumulatorParam.LongAccumulatorParam
- val acc1 = new Accumulator(0L, param, Some("x"), countFailedValues = true)
- val acc2 = new Accumulator(0L, param, Some("y"), countFailedValues = false)
+ val acc1 = AccumulatorSuite.createLongAccum("x", true)
+ val acc2 = AccumulatorSuite.createLongAccum("y", false)
// Create a dummy task. We won't end up running this; we just want to collect
// accumulator updates from it.
- val taskMetrics = new TaskMetrics
+ val taskMetrics = TaskMetrics.empty
val task = new Task[Int](0, 0, 0) {
context = new TaskContextImpl(0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
@@ -186,12 +185,11 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
// First, simulate task success. This should give us all the accumulators.
val accumUpdates1 = task.collectAccumulatorUpdates(taskFailed = false)
- val accumUpdates2 = (taskMetrics.internalAccums ++ Seq(acc1, acc2))
- .map(TaskMetricsSuite.makeInfo)
+ val accumUpdates2 = taskMetrics.internalAccums ++ Seq(acc1, acc2)
TaskMetricsSuite.assertUpdatesEquals(accumUpdates1, accumUpdates2)
// Now, simulate task failures. This should give us only the accums that count failed values.
val accumUpdates3 = task.collectAccumulatorUpdates(taskFailed = true)
- val accumUpdates4 = (taskMetrics.internalAccums ++ Seq(acc1)).map(TaskMetricsSuite.makeInfo)
+ val accumUpdates4 = taskMetrics.internalAccums ++ Seq(acc1)
TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index b5385c11a9..9e472f900b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -241,8 +241,8 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
assert(resultGetter.taskResults.size === 1)
val resBefore = resultGetter.taskResults.head
val resAfter = captor.getValue
- val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
- val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).flatMap(_.update)
+ val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
+ val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
assert(resSizeBefore.exists(_ == 0L))
assert(resSizeAfter.exists(_.toString.toLong > 0L))
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ecf4b76da5..339fc4254d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -37,7 +37,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Seq[AccumulableInfo],
+ accumUpdates: Seq[NewAccumulator[_, _]],
taskInfo: TaskInfo) {
taskScheduler.endedTasks(taskInfo.index) = reason
}
@@ -166,8 +166,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val taskSet = FakeTask.createTaskSet(1)
val clock = new ManualClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
- val accumUpdates =
- taskSet.tasks.head.metrics.internalAccums.map { a => a.toInfo(Some(0L), None) }
+ val accumUpdates = taskSet.tasks.head.metrics.internalAccums
// Offer a host with NO_PREF as the constraint,
// we should get a nopref task immediately since that's what we only have
@@ -185,8 +184,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
val taskSet = FakeTask.createTaskSet(3)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
- val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task =>
- task.metrics.internalAccums.map { a => a.toInfo(Some(0L), None) }
+ val accumUpdatesByTask: Array[Seq[NewAccumulator[_, _]]] = taskSet.tasks.map { task =>
+ task.metrics.internalAccums
}
// First three offers should all find tasks
@@ -792,7 +791,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
private def createTaskResult(
id: Int,
- accumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo]): DirectTaskResult[Int] = {
+ accumUpdates: Seq[NewAccumulator[_, _]] = Seq.empty): DirectTaskResult[Int] = {
val valueSer = SparkEnv.get.serializer.newInstance()
new DirectTaskResult[Int](valueSer.serialize(id), accumUpdates)
}
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 221124829f..ce7d51d1c3 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -183,7 +183,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
test("test executor id to summary") {
val conf = new SparkConf()
val listener = new JobProgressListener(conf)
- val taskMetrics = new TaskMetrics()
+ val taskMetrics = TaskMetrics.empty
val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
assert(listener.stageIdToData.size === 0)
@@ -230,7 +230,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
test("test task success vs failure counting for different task end reasons") {
val conf = new SparkConf()
val listener = new JobProgressListener(conf)
- val metrics = new TaskMetrics()
+ val metrics = TaskMetrics.empty
val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
val task = new ShuffleMapTask(0)
@@ -269,7 +269,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
val execId = "exe-1"
def makeTaskMetrics(base: Int): TaskMetrics = {
- val taskMetrics = new TaskMetrics
+ val taskMetrics = TaskMetrics.empty
val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics
val inputMetrics = taskMetrics.inputMetrics
@@ -300,9 +300,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
listener.onTaskStart(SparkListenerTaskStart(1, 0, makeTaskInfo(1237L)))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate(execId, Array(
- (1234L, 0, 0, makeTaskMetrics(0).accumulatorUpdates()),
- (1235L, 0, 0, makeTaskMetrics(100).accumulatorUpdates()),
- (1236L, 1, 0, makeTaskMetrics(200).accumulatorUpdates()))))
+ (1234L, 0, 0, makeTaskMetrics(0).accumulators().map(AccumulatorSuite.makeInfo)),
+ (1235L, 0, 0, makeTaskMetrics(100).accumulators().map(AccumulatorSuite.makeInfo)),
+ (1236L, 1, 0, makeTaskMetrics(200).accumulators().map(AccumulatorSuite.makeInfo)))))
var stage0Data = listener.stageIdToData.get((0, 0)).get
var stage1Data = listener.stageIdToData.get((1, 0)).get
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index d3b6cdfe86..6fda7378e6 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -85,7 +85,8 @@ class JsonProtocolSuite extends SparkFunSuite {
// Use custom accum ID for determinism
val accumUpdates =
makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, hasHadoopInput = true, hasOutput = true)
- .accumulatorUpdates().zipWithIndex.map { case (a, i) => a.copy(id = i) }
+ .accumulators().map(AccumulatorSuite.makeInfo)
+ .zipWithIndex.map { case (a, i) => a.copy(id = i) }
SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates)))
}
@@ -385,7 +386,7 @@ class JsonProtocolSuite extends SparkFunSuite {
// "Task Metrics" field, if it exists.
val tm = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = true, hasOutput = true)
val tmJson = JsonProtocol.taskMetricsToJson(tm)
- val accumUpdates = tm.accumulatorUpdates()
+ val accumUpdates = tm.accumulators().map(AccumulatorSuite.makeInfo)
val exception = new SparkException("sentimental")
val exceptionFailure = new ExceptionFailure(exception, accumUpdates)
val exceptionFailureJson = JsonProtocol.taskEndReasonToJson(exceptionFailure)
@@ -813,7 +814,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
hasHadoopInput: Boolean,
hasOutput: Boolean,
hasRecords: Boolean = true) = {
- val t = new TaskMetrics
+ val t = TaskMetrics.empty
t.setExecutorDeserializeTime(a)
t.setExecutorRunTime(b)
t.setResultSize(c)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0f8648f890..6fc49a08fe 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -688,6 +688,18 @@ object MimaExcludes {
) ++ Seq(
// [SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable")
+ ) ++ Seq(
+ // SPARK-14654: New accumulator API
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ExceptionFailure$"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.apply"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.metrics"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.copy"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.this"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.remoteBlocksFetched"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.totalBlocksFetched"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.localBlocksFetched"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.remoteBlocksFetched"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.localBlocksFetched")
)
case v if v.startsWith("1.6") =>
Seq(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 520ceaaaea..d6516f26a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -106,7 +106,7 @@ private[sql] case class RDDScanExec(
override val nodeName: String) extends LeafExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
@@ -147,7 +147,7 @@ private[sql] case class RowDataSourceScanExec(
extends DataSourceScanExec with CodegenSupport {
private[sql] override lazy val metrics =
- Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
val outputUnsafeRows = relation match {
case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
@@ -216,7 +216,7 @@ private[sql] case class BatchedDataSourceScanExec(
extends DataSourceScanExec with CodegenSupport {
private[sql] override lazy val metrics =
- Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 7c4756663a..c201822d44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -40,7 +40,7 @@ case class ExpandExec(
extends UnaryExecNode with CodegenSupport {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
// The GroupExpressions can output data with arbitrary partitioning, so set it
// as UNKNOWN partitioning
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 10cfec3330..934bc38dc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -56,7 +56,7 @@ case class GenerateExec(
extends UnaryExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def producedAttributes: AttributeSet = AttributeSet(output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index 4ab447a47b..c5e78b0333 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -31,7 +31,7 @@ private[sql] case class LocalTableScanExec(
rows: Seq[InternalRow]) extends LeafExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
private val unsafeRows: Array[InternalRow] = {
val proj = UnsafeProjection.create(output, output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 861ff3cd15..0bbe970420 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetric}
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.ThreadUtils
@@ -77,7 +77,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Return all metrics containing metrics of this SparkPlan.
*/
- private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty
+ private[sql] def metrics: Map[String, SQLMetric] = Map.empty
/**
* Reset all the metrics.
@@ -89,8 +89,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Return a LongSQLMetric according to the name.
*/
- private[sql] def longMetric(name: String): LongSQLMetric =
- metrics(name).asInstanceOf[LongSQLMetric]
+ private[sql] def longMetric(name: String): SQLMetric = metrics(name)
// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index cb4b1cfeb9..f84070a0c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -55,8 +55,7 @@ private[sql] object SparkPlanInfo {
case _ => plan.children ++ plan.subqueries
}
val metrics = plan.metrics.toSeq.map { case (key, metric) =>
- new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
- Utils.getFormattedClassName(metric.param))
+ new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
}
new SparkPlanInfo(plan.nodeName, plan.simpleString, children.map(fromSparkPlan),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
index 362d0d7a72..484923428f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
@@ -26,7 +26,7 @@ import com.google.common.io.ByteStreams
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.unsafe.Platform
/**
@@ -42,7 +42,7 @@ import org.apache.spark.unsafe.Platform
*/
private[sql] class UnsafeRowSerializer(
numFields: Int,
- dataSize: LongSQLMetric = null) extends Serializer with Serializable {
+ dataSize: SQLMetric = null) extends Serializer with Serializable {
override def newInstance(): SerializerInstance =
new UnsafeRowSerializerInstance(numFields, dataSize)
override private[spark] def supportsRelocationOfSerializedObjects: Boolean = true
@@ -50,7 +50,7 @@ private[sql] class UnsafeRowSerializer(
private class UnsafeRowSerializerInstance(
numFields: Int,
- dataSize: LongSQLMetric) extends SerializerInstance {
+ dataSize: SQLMetric) extends SerializerInstance {
/**
* Serializes a stream of UnsafeRows. Within the stream, each record consists of a record
* length (stored as a 4-byte integer, written high byte first), followed by the record's bytes.
@@ -60,13 +60,10 @@ private class UnsafeRowSerializerInstance(
private[this] val dOut: DataOutputStream =
new DataOutputStream(new BufferedOutputStream(out))
- // LongSQLMetricParam.add() is faster than LongSQLMetric.+=
- val localDataSize = if (dataSize != null) dataSize.localValue else null
-
override def writeValue[T: ClassTag](value: T): SerializationStream = {
val row = value.asInstanceOf[UnsafeRow]
- if (localDataSize != null) {
- localDataSize.add(row.getSizeInBytes)
+ if (dataSize != null) {
+ dataSize.add(row.getSizeInBytes)
}
dOut.writeInt(row.getSizeInBytes)
row.writeToStream(dOut, writeBuffer)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 6a03bd08c5..15b4abe806 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -52,11 +52,7 @@ trait CodegenSupport extends SparkPlan {
* @return name of the variable representing the metric
*/
def metricTerm(ctx: CodegenContext, name: String): String = {
- val metric = ctx.addReferenceObj(name, longMetric(name))
- val value = ctx.freshName("metricValue")
- val cls = classOf[LongSQLMetricValue].getName
- ctx.addMutableState(cls, value, s"$value = ($cls) $metric.localValue();")
- value
+ ctx.addReferenceObj(name, longMetric(name))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
index 3169e0a2fd..2e74d59c5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
@@ -46,7 +46,7 @@ case class SortBasedAggregateExec(
AttributeSet(aggregateBufferAttributes)
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index c35d781d3e..f392b135ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
/**
* An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been
@@ -35,7 +35,7 @@ class SortBasedAggregationIterator(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
- numOutputRows: LongSQLMetric)
+ numOutputRows: SQLMetric)
extends AggregationIterator(
groupingExpressions,
valueAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 16362f756f..d0ba37ee13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.unsafe.KVIterator
@@ -51,7 +51,7 @@ case class TungstenAggregate(
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"peakMemory" -> SQLMetrics.createSizeMetric(sparkContext, "peak memory"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"),
"aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time"))
@@ -309,8 +309,8 @@ case class TungstenAggregate(
def finishAggregate(
hashMap: UnsafeFixedWidthAggregationMap,
sorter: UnsafeKVExternalSorter,
- peakMemory: LongSQLMetricValue,
- spillSize: LongSQLMetricValue): KVIterator[UnsafeRow, UnsafeRow] = {
+ peakMemory: SQLMetric,
+ spillSize: SQLMetric): KVIterator[UnsafeRow, UnsafeRow] = {
// update peak execution memory
val mapMemory = hashMap.getPeakMemoryUsedBytes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 9db5087fe0..243aa15deb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
@@ -86,9 +86,9 @@ class TungstenAggregationIterator(
originalInputAttributes: Seq[Attribute],
inputIter: Iterator[InternalRow],
testFallbackStartsAt: Option[(Int, Int)],
- numOutputRows: LongSQLMetric,
- peakMemory: LongSQLMetric,
- spillSize: LongSQLMetric)
+ numOutputRows: SQLMetric,
+ peakMemory: SQLMetric,
+ spillSize: SQLMetric)
extends AggregationIterator(
groupingExpressions,
originalInputAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 83f527f555..77be613b83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -103,7 +103,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
}
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -229,7 +229,7 @@ case class SampleExec(
override def output: Seq[Attribute] = child.output
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
@@ -322,7 +322,7 @@ case class RangeExec(
extends LeafExecNode with CodegenSupport {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
// output attributes should not affect the results
override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index cb957b9666..577c34ba61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{Accumulable, Accumulator, Accumulators}
+import org.apache.spark.{Accumulable, Accumulator, AccumulatorContext}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -204,7 +204,7 @@ private[sql] case class InMemoryRelation(
Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats)
private[sql] def uncache(blocking: Boolean): Unit = {
- Accumulators.remove(batchStats.id)
+ AccumulatorContext.remove(batchStats.id)
cachedColumnBuffers.unpersist(blocking)
_cachedColumnBuffers = null
}
@@ -217,7 +217,7 @@ private[sql] case class InMemoryTableScanExec(
extends LeafExecNode {
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = attributes
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 573ca195ac..b6ecd3cb06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -38,10 +38,10 @@ case class BroadcastExchangeExec(
child: SparkPlan) extends Exchange {
override private[sql] lazy val metrics = Map(
- "dataSize" -> SQLMetrics.createLongMetric(sparkContext, "data size (bytes)"),
- "collectTime" -> SQLMetrics.createLongMetric(sparkContext, "time to collect (ms)"),
- "buildTime" -> SQLMetrics.createLongMetric(sparkContext, "time to build (ms)"),
- "broadcastTime" -> SQLMetrics.createLongMetric(sparkContext, "time to broadcast (ms)"))
+ "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
+ "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"),
+ "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"),
+ "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"))
override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index b0a6b8f28a..587c603192 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -46,7 +46,7 @@ case class BroadcastHashJoinExec(
extends BinaryExecNode with HashJoin with CodegenSupport {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 51afa0017d..a659bf26e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -35,7 +35,7 @@ case class BroadcastNestedLoopJoinExec(
condition: Option[Expression]) extends BinaryExecNode {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
/** BuildRight means the right relation <=> the broadcast relation. */
private val (streamed, broadcast) = buildSide match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 67f59197ad..8d7ecc442a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -86,7 +86,7 @@ case class CartesianProductExec(
override def output: Seq[Attribute] = left.output ++ right.output
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index d6feedc272..9c173d7bf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.{IntegralType, LongType}
trait HashJoin {
@@ -201,7 +201,7 @@ trait HashJoin {
protected def join(
streamedIter: Iterator[InternalRow],
hashed: HashedRelation,
- numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+ numOutputRows: SQLMetric): Iterator[InternalRow] = {
val joinedIter = joinType match {
case Inner =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index a242a078f6..3ef2fec352 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -40,7 +40,7 @@ case class ShuffledHashJoinExec(
extends BinaryExecNode with HashJoin {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
"buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index a4c5491aff..775f8ac508 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.collection.BitSet
/**
@@ -41,7 +41,7 @@ case class SortMergeJoinExec(
right: SparkPlan) extends BinaryExecNode with CodegenSupport {
override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def output: Seq[Attribute] = {
joinType match {
@@ -734,7 +734,7 @@ private class LeftOuterIterator(
rightNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric)
+ numOutputRows: SQLMetric)
extends OneSideOuterIterator(
smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
@@ -750,7 +750,7 @@ private class RightOuterIterator(
leftNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric)
+ numOutputRows: SQLMetric)
extends OneSideOuterIterator(smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
@@ -778,7 +778,7 @@ private abstract class OneSideOuterIterator(
bufferedSideNullRow: InternalRow,
boundCondition: InternalRow => Boolean,
resultProj: InternalRow => InternalRow,
- numOutputRows: LongSQLMetric) extends RowIterator {
+ numOutputRows: SQLMetric) extends RowIterator {
// A row to store the joined result, reused many times
protected[this] val joinedRow: JoinedRow = new JoinedRow()
@@ -1016,7 +1016,7 @@ private class SortMergeFullOuterJoinScanner(
private class FullOuterIterator(
smjScanner: SortMergeFullOuterJoinScanner,
resultProj: InternalRow => InternalRow,
- numRows: LongSQLMetric) extends RowIterator {
+ numRows: SQLMetric) extends RowIterator {
private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
override def advanceNext(): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
index 2708219ad3..adb81519db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
@@ -27,4 +27,4 @@ import org.apache.spark.annotation.DeveloperApi
class SQLMetricInfo(
val name: String,
val accumulatorId: Long,
- val metricParam: String)
+ val metricType: String)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 5755c00c1f..7bf9225272 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -19,200 +19,106 @@ package org.apache.spark.sql.execution.metric
import java.text.NumberFormat
-import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext}
+import org.apache.spark.{NewAccumulator, SparkContext}
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.util.Utils
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- *
- * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
- */
-private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
- name: String,
- val param: SQLMetricParam[R, T]) extends Accumulable[R, T](param.zero, param, Some(name)) {
- // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
- override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
- new AccumulableInfo(id, Some(name), update, value, true, countFailedValues,
- Some(SQLMetrics.ACCUM_IDENTIFIER))
- }
-
- def reset(): Unit = {
- this.value = param.zero
- }
-}
-
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
-
- /**
- * A function that defines how we aggregate the final accumulator results among all tasks,
- * and represent it in string for a SQL physical operator.
- */
- val stringValue: Seq[T] => String
-
- def zero: R
-}
+class SQLMetric(val metricType: String, initValue: Long = 0L) extends NewAccumulator[Long, Long] {
+ // This is a workaround for SPARK-11013.
+ // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will
+ // update it at the end of task and the value will be at least 0. Then we can filter out the -1
+ // values before calculate max, min, etc.
+ private[this] var _value = initValue
-/**
- * Create a layer for specialized metric. We cannot add `@specialized` to
- * `Accumulable/AccumulableParam` because it will break Java source compatibility.
- */
-private[sql] trait SQLMetricValue[T] extends Serializable {
+ override def copyAndReset(): SQLMetric = new SQLMetric(metricType, initValue)
- def value: T
-
- override def toString: String = value.toString
-}
-
-/**
- * A wrapper of Long to avoid boxing and unboxing when using Accumulator
- */
-private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {
-
- def add(incr: Long): LongSQLMetricValue = {
- _value += incr
- this
+ override def merge(other: NewAccumulator[Long, Long]): Unit = other match {
+ case o: SQLMetric => _value += o.localValue
+ case _ => throw new UnsupportedOperationException(
+ s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
}
- // Although there is a boxing here, it's fine because it's only called in SQLListener
- override def value: Long = _value
-
- // Needed for SQLListenerSuite
- override def equals(other: Any): Boolean = other match {
- case o: LongSQLMetricValue => value == o.value
- case _ => false
- }
+ override def isZero(): Boolean = _value == initValue
- override def hashCode(): Int = _value.hashCode()
-}
+ override def add(v: Long): Unit = _value += v
-/**
- * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
- * `+=` and `add`.
- */
-private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam)
- extends SQLMetric[LongSQLMetricValue, Long](name, param) {
+ def +=(v: Long): Unit = _value += v
- override def +=(term: Long): Unit = {
- localValue.add(term)
- }
+ override def localValue: Long = _value
- override def add(term: Long): Unit = {
- localValue.add(term)
+ // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
+ private[spark] override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+ new AccumulableInfo(id, name, update, value, true, true, Some(SQLMetrics.ACCUM_IDENTIFIER))
}
-}
-
-private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long)
- extends SQLMetricParam[LongSQLMetricValue, Long] {
-
- override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
- override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
- r1.add(r2.value)
-
- override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
-
- override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue)
+ def reset(): Unit = _value = initValue
}
-private object LongSQLMetricParam
- extends LongSQLMetricParam(x => NumberFormat.getInstance().format(x.sum), 0L)
-
-private object StatisticsBytesSQLMetricParam extends LongSQLMetricParam(
- (values: Seq[Long]) => {
- // This is a workaround for SPARK-11013.
- // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
- // it at the end of task and the value will be at least 0.
- val validValues = values.filter(_ >= 0)
- val Seq(sum, min, med, max) = {
- val metric = if (validValues.length == 0) {
- Seq.fill(4)(0L)
- } else {
- val sorted = validValues.sorted
- Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
- }
- metric.map(Utils.bytesToString)
- }
- s"\n$sum ($min, $med, $max)"
- }, -1L)
-
-private object StatisticsTimingSQLMetricParam extends LongSQLMetricParam(
- (values: Seq[Long]) => {
- // This is a workaround for SPARK-11013.
- // We use -1 as initial value of the accumulator, if the accumulator is valid, we will update
- // it at the end of task and the value will be at least 0.
- val validValues = values.filter(_ >= 0)
- val Seq(sum, min, med, max) = {
- val metric = if (validValues.length == 0) {
- Seq.fill(4)(0L)
- } else {
- val sorted = validValues.sorted
- Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
- }
- metric.map(Utils.msDurationToString)
- }
- s"\n$sum ($min, $med, $max)"
- }, -1L)
private[sql] object SQLMetrics {
-
// Identifier for distinguishing SQL metrics from other accumulators
private[sql] val ACCUM_IDENTIFIER = "sql"
- private def createLongMetric(
- sc: SparkContext,
- name: String,
- param: LongSQLMetricParam): LongSQLMetric = {
- val acc = new LongSQLMetric(name, param)
- // This is an internal accumulator so we need to register it explicitly.
- Accumulators.register(acc)
- sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
- acc
- }
+ private[sql] val SUM_METRIC = "sum"
+ private[sql] val SIZE_METRIC = "size"
+ private[sql] val TIMING_METRIC = "timing"
- def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
- createLongMetric(sc, name, LongSQLMetricParam)
+ def createMetric(sc: SparkContext, name: String): SQLMetric = {
+ val acc = new SQLMetric(SUM_METRIC)
+ acc.register(sc, name = Some(name), countFailedValues = true)
+ acc
}
/**
* Create a metric to report the size information (including total, min, med, max) like data size,
* spill size, etc.
*/
- def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ def createSizeMetric(sc: SparkContext, name: String): SQLMetric = {
// The final result of this metric in physical operator UI may looks like:
// data size total (min, med, max):
// 100GB (100MB, 1GB, 10GB)
- createLongMetric(sc, s"$name total (min, med, max)", StatisticsBytesSQLMetricParam)
+ val acc = new SQLMetric(SIZE_METRIC, -1)
+ acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+ acc
}
- def createTimingMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ def createTimingMetric(sc: SparkContext, name: String): SQLMetric = {
// The final result of this metric in physical operator UI may looks like:
// duration(min, med, max):
// 5s (800ms, 1s, 2s)
- createLongMetric(sc, s"$name total (min, med, max)", StatisticsTimingSQLMetricParam)
- }
-
- def getMetricParam(metricParamName: String): SQLMetricParam[SQLMetricValue[Any], Any] = {
- val longSQLMetricParam = Utils.getFormattedClassName(LongSQLMetricParam)
- val bytesSQLMetricParam = Utils.getFormattedClassName(StatisticsBytesSQLMetricParam)
- val timingsSQLMetricParam = Utils.getFormattedClassName(StatisticsTimingSQLMetricParam)
- val metricParam = metricParamName match {
- case `longSQLMetricParam` => LongSQLMetricParam
- case `bytesSQLMetricParam` => StatisticsBytesSQLMetricParam
- case `timingsSQLMetricParam` => StatisticsTimingSQLMetricParam
- }
- metricParam.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]]
+ val acc = new SQLMetric(TIMING_METRIC, -1)
+ acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = true)
+ acc
}
/**
- * A metric that its value will be ignored. Use this one when we need a metric parameter but don't
- * care about the value.
+ * A function that defines how we aggregate the final accumulator results among all tasks,
+ * and represent it in string for a SQL physical operator.
*/
- val nullLongMetric = new LongSQLMetric("null", LongSQLMetricParam)
+ def stringValue(metricsType: String, values: Seq[Long]): String = {
+ if (metricsType == SUM_METRIC) {
+ NumberFormat.getInstance().format(values.sum)
+ } else {
+ val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
+ Utils.bytesToString
+ } else if (metricsType == TIMING_METRIC) {
+ Utils.msDurationToString
+ } else {
+ throw new IllegalStateException("unexpected metrics type: " + metricsType)
+ }
+
+ val validValues = values.filter(_ >= 0)
+ val Seq(sum, min, med, max) = {
+ val metric = if (validValues.length == 0) {
+ Seq.fill(4)(0L)
+ } else {
+ val sorted = validValues.sorted
+ Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+ }
+ metric.map(strFormat)
+ }
+ s"\n$sum ($min, $med, $max)"
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 5ae9e916ad..9118593c0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -164,7 +164,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
taskEnd.taskInfo.taskId,
taskEnd.stageId,
taskEnd.stageAttemptId,
- taskEnd.taskMetrics.accumulatorUpdates(),
+ taskEnd.taskMetrics.accumulators().map(a => a.toInfo(Some(a.localValue), None)),
finishTask = true)
}
}
@@ -296,7 +296,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
}
}.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
- executionUIData.accumulatorMetrics(accumulatorId).metricParam)
+ executionUIData.accumulatorMetrics(accumulatorId).metricType)
case None =>
// This execution has been dropped
Map.empty
@@ -305,11 +305,11 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
private def mergeAccumulatorUpdates(
accumulatorUpdates: Seq[(Long, Any)],
- paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = {
+ metricTypeFunc: Long => String): Map[Long, String] = {
accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
- val param = paramFunc(accumulatorId)
- (accumulatorId,
- param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value)))
+ val metricType = metricTypeFunc(accumulatorId)
+ accumulatorId ->
+ SQLMetrics.stringValue(metricType, values.map(_._2.asInstanceOf[Long]))
}
}
@@ -337,7 +337,7 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
// Filter out accumulators that are not SQL metrics
// For now we assume all SQL metrics are Long's that have been JSON serialized as String's
if (a.metadata == Some(SQLMetrics.ACCUM_IDENTIFIER)) {
- val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L))
+ val newValue = a.update.map(_.toString.toLong).getOrElse(0L)
Some(a.copy(update = Some(newValue)))
} else {
None
@@ -403,7 +403,7 @@ private[ui] class SQLExecutionUIData(
private[ui] case class SQLPlanMetric(
name: String,
accumulatorId: Long,
- metricParam: SQLMetricParam[SQLMetricValue[Any], Any])
+ metricType: String)
/**
* Store all accumulatorUpdates for all tasks in a Spark stage.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 1959f1e368..8f5681bfc7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -80,8 +80,7 @@ private[sql] object SparkPlanGraph {
planInfo.nodeName match {
case "WholeStageCodegen" =>
val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId,
- SQLMetrics.getMetricParam(metric.metricParam))
+ SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
}
val cluster = new SparkPlanGraphCluster(
@@ -106,8 +105,7 @@ private[sql] object SparkPlanGraph {
edges += SparkPlanGraphEdge(node.id, parent.id)
case name =>
val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId,
- SQLMetrics.getMetricParam(metric.metricParam))
+ SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
}
val node = new SparkPlanGraphNode(
nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 4aea21e52a..0e6356b578 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -22,7 +22,7 @@ import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.Accumulators
+import org.apache.spark.AccumulatorContext
import org.apache.spark.sql.execution.RDDScanExec
import org.apache.spark.sql.execution.columnar._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
@@ -333,11 +333,11 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
sql("SELECT * FROM t1").count()
sql("SELECT * FROM t2").count()
- Accumulators.synchronized {
- val accsSize = Accumulators.originals.size
+ AccumulatorContext.synchronized {
+ val accsSize = AccumulatorContext.originals.size
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
- assert((accsSize - 2) == Accumulators.originals.size)
+ assert((accsSize - 2) == AccumulatorContext.originals.size)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 1859c6e7ad..8de4d8bbd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -37,8 +37,8 @@ import org.apache.spark.util.{JsonProtocol, Utils}
class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
- test("LongSQLMetric should not box Long") {
- val l = SQLMetrics.createLongMetric(sparkContext, "long")
+ test("SQLMetric should not box Long") {
+ val l = SQLMetrics.createMetric(sparkContext, "long")
val f = () => {
l += 1L
l.add(1L)
@@ -300,12 +300,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}
test("metrics can be loaded by history server") {
- val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam)
+ val metric = SQLMetrics.createMetric(sparkContext, "zanzibar")
metric += 10L
val metricInfo = metric.toInfo(Some(metric.localValue), None)
metricInfo.update match {
- case Some(v: LongSQLMetricValue) => assert(v.value === 10L)
- case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}")
+ case Some(v: Long) => assert(v === 10L)
+ case Some(v) => fail(s"metric value was not a Long: ${v.getClass.getName}")
case _ => fail("metric update is missing")
}
assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 09bd7f6e8f..8572ed16aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -21,18 +21,19 @@ import java.util.Properties
import org.mockito.Mockito.{mock, when}
-import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
+import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.ui.SparkUI
class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
+ import org.apache.spark.AccumulatorSuite.makeInfo
private def createTestDataFrame: DataFrame = {
Seq(
@@ -72,9 +73,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = {
val metrics = mock(classOf[TaskMetrics])
- when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) =>
- new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)),
- value = None, internal = true, countFailedValues = true)
+ when(metrics.accumulators()).thenReturn(accumulatorUpdates.map { case (id, update) =>
+ val acc = new LongAccumulator
+ acc.metadata = AccumulatorMetadata(id, Some(""), true)
+ acc.setValue(update)
+ acc
}.toSeq)
metrics
}
@@ -130,16 +133,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+ (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates())
+ (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 0, 0,
+ createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
@@ -149,8 +153,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+ (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
@@ -189,8 +193,8 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, accum updates)
- (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
- (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
+ (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo)),
+ (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulators().map(makeInfo))
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
@@ -358,7 +362,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
val stageSubmitted = SparkListenerStageSubmitted(stageInfo)
// This task has both accumulators that are SQL metrics and accumulators that are not.
// The listener should only track the ones that are actually SQL metrics.
- val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella")
+ val sqlMetric = SQLMetrics.createMetric(sparkContext, "beach umbrella")
val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None)
val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index eb25ea0629..8a0578c1ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -96,7 +96,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
case w: WholeStageCodegenExec => w.child.longMetric("numOutputRows")
case other => other.longMetric("numOutputRows")
}
- metrics += metric.value.value
+ metrics += metric.value
}
}
sqlContext.listenerManager.register(listener)
@@ -126,9 +126,9 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
- metrics += qe.executedPlan.longMetric("dataSize").value.value
+ metrics += qe.executedPlan.longMetric("dataSize").value
val bottomAgg = qe.executedPlan.children(0).children(0)
- metrics += bottomAgg.longMetric("dataSize").value.value
+ metrics += bottomAgg.longMetric("dataSize").value
}
}
sqlContext.listenerManager.register(listener)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 007c3384e5..b52b96a804 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -55,7 +55,7 @@ case class HiveTableScanExec(
"Partition pruning predicates only supported for partitioned tables.")
private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def producedAttributes: AttributeSet = outputSet ++
AttributeSet(partitionPruningPred.flatMap(_.references))