Diffstat (limited to 'core/src/main/scala/org/apache/spark/executor/Executor.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala  45
1 file changed, 19 insertions(+), 26 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 030ae41db4..51c000ea5c 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -31,7 +31,7 @@ import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.rpc.RpcTimeout
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
+import org.apache.spark.scheduler.{AccumulableInfo, DirectTaskResult, IndirectTaskResult, Task}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
@@ -210,7 +210,7 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
var threwException = true
- val (value, accumUpdates) = try {
+ val value = try {
val res = task.run(
taskAttemptId = taskId,
attemptNumber = attemptNumber,
@@ -249,10 +249,11 @@ private[spark] class Executor(
m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
m.setJvmGCTime(computeTotalGcTime() - startGCTime)
m.setResultSerializationTime(afterSerialization - beforeSerialization)
- m.updateAccumulators()
}
- val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
+ // Note: accumulator updates must be collected after TaskMetrics is updated
+ val accumUpdates = task.collectAccumulatorUpdates()
+ val directResult = new DirectTaskResult(valueBytes, accumUpdates)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
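Note on the hunk above: the accumulator snapshot is only meaningful once the TaskMetrics setters have run, so collectAccumulatorUpdates() is called after the metrics block and only the returned snapshot goes into the DirectTaskResult. A minimal Scala sketch of that ordering, using hypothetical names (MetricsLike, AccumSnapshot) rather than the real Spark types:

    // Hypothetical, simplified model of the pattern in the hunk above:
    // write the metric fields first, then snapshot them as accumulator updates,
    // and ship only the immutable snapshot back with the task result.
    final case class AccumSnapshot(name: String, update: Long)

    class MetricsLike {
      private var executorRunTime = 0L
      private var jvmGCTime = 0L
      def setExecutorRunTime(ms: Long): Unit = executorRunTime = ms
      def setJvmGCTime(ms: Long): Unit = jvmGCTime = ms
      def accumulatorUpdates(): Seq[AccumSnapshot] = Seq(
        AccumSnapshot("internal.metrics.executorRunTime", executorRunTime),
        AccumSnapshot("internal.metrics.jvmGCTime", jvmGCTime))
    }

    // Called only after the setters above have run, so the snapshot holds final values.
    def collectForResult(m: MetricsLike): Seq[AccumSnapshot] = m.accumulatorUpdates()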
@@ -297,21 +298,25 @@ private[spark] class Executor(
// the default uncaught exception handler, which will terminate the Executor.
logError(s"Exception in $taskName (TID $taskId)", t)
- val metrics: Option[TaskMetrics] = Option(task).flatMap { task =>
- task.metrics.map { m =>
+ // Collect latest accumulator values to report back to the driver
+ val accumulatorUpdates: Seq[AccumulableInfo] =
+ if (task != null) {
+ task.metrics.foreach { m =>
m.setExecutorRunTime(System.currentTimeMillis() - taskStart)
m.setJvmGCTime(computeTotalGcTime() - startGCTime)
- m.updateAccumulators()
- m
}
+ task.collectAccumulatorUpdates(taskFailed = true)
+ } else {
+ Seq.empty[AccumulableInfo]
}
+
val serializedTaskEndReason = {
try {
- ser.serialize(new ExceptionFailure(t, metrics))
+ ser.serialize(new ExceptionFailure(t, accumulatorUpdates))
} catch {
case _: NotSerializableException =>
// t is not serializable so just send the stacktrace
- ser.serialize(new ExceptionFailure(t, metrics, false))
+ ser.serialize(new ExceptionFailure(t, accumulatorUpdates, preserveCause = false))
}
}
execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
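The try/catch around ser.serialize above is the usual serialize-with-fallback pattern: if the user exception (or something it references) is not serializable, a reduced ExceptionFailure without the cause is sent instead. A standalone sketch of the same pattern using plain JDK serialization; the names here are illustrative, not the Spark SerializerInstance API:

    import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

    // Serialize the full payload; if something inside it is not serializable,
    // retry with a reduced payload (e.g. just the stack trace text).
    def serializeOrFallback(full: Serializable, reduced: Serializable): Array[Byte] = {
      def toBytes(o: Serializable): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        val oos = new ObjectOutputStream(bos)
        try oos.writeObject(o) finally oos.close()
        bos.toByteArray
      }
      try toBytes(full)
      catch { case _: NotSerializableException => toBytes(reduced) }
    }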
@@ -418,33 +423,21 @@ private[spark] class Executor(
/** Reports heartbeat and metrics for active tasks to the driver. */
private def reportHeartBeat(): Unit = {
- // list of (task id, metrics) to send back to the driver
- val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
+ // list of (task id, accumUpdates) to send back to the driver
+ val accumUpdates = new ArrayBuffer[(Long, Seq[AccumulableInfo])]()
val curGCTime = computeTotalGcTime()
for (taskRunner <- runningTasks.values().asScala) {
if (taskRunner.task != null) {
taskRunner.task.metrics.foreach { metrics =>
metrics.mergeShuffleReadMetrics()
- metrics.updateInputMetrics()
metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
- metrics.updateAccumulators()
-
- if (isLocal) {
- // JobProgressListener will hold an reference of it during
- // onExecutorMetricsUpdate(), then JobProgressListener can not see
- // the changes of metrics any more, so make a deep copy of it
- val copiedMetrics = Utils.deserialize[TaskMetrics](Utils.serialize(metrics))
- tasksMetrics += ((taskRunner.taskId, copiedMetrics))
- } else {
- // It will be copied by serialization
- tasksMetrics += ((taskRunner.taskId, metrics))
- }
+ accumUpdates += ((taskRunner.taskId, metrics.accumulatorUpdates()))
}
}
}
- val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId)
+ val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId)
try {
val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](
message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s"))
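One visible effect of this last hunk is that the heartbeat no longer ships the mutable TaskMetrics objects (so the local-mode deep copy disappears); each heartbeat instead carries a freshly built list of (taskId, accumulator updates) pairs. A rough, hypothetical sketch of such a loop, with invented names (HeartbeatMsg, snapshotFor) standing in for Spark's Heartbeat message and metrics API:

    import java.util.concurrent.{Executors, TimeUnit}
    import scala.collection.mutable.ArrayBuffer

    final case class HeartbeatMsg(executorId: String, updates: Array[(Long, Seq[(String, Long)])])

    // Every interval: snapshot each running task's accumulator-style values and send the batch.
    class HeartbeatReporter(
        executorId: String,
        runningTaskIds: () => Seq[Long],
        snapshotFor: Long => Seq[(String, Long)],
        send: HeartbeatMsg => Unit) {

      private val pool = Executors.newSingleThreadScheduledExecutor()

      def start(intervalMs: Long): Unit = {
        val tick = new Runnable { def run(): Unit = report() }
        pool.scheduleAtFixedRate(tick, intervalMs, intervalMs, TimeUnit.MILLISECONDS)
      }

      private def report(): Unit = {
        val updates = new ArrayBuffer[(Long, Seq[(String, Long)])]()
        for (taskId <- runningTaskIds()) {
          // Each entry is a fresh snapshot, so nothing mutable is shared with the receiver.
          updates += ((taskId, snapshotFor(taskId)))
        }
        send(HeartbeatMsg(executorId, updates.toArray))
      }
    }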