Diffstat (limited to 'core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 102
1 file changed, 61 insertions(+), 41 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 6b01a10fc1..897479b500 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -208,11 +208,10 @@ class DAGScheduler(
task: Task[_],
reason: TaskEndReason,
result: Any,
- accumUpdates: Map[Long, Any],
- taskInfo: TaskInfo,
- taskMetrics: TaskMetrics): Unit = {
+ accumUpdates: Seq[AccumulableInfo],
+ taskInfo: TaskInfo): Unit = {
eventProcessLoop.post(
- CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
+ CompletionEvent(task, reason, result, accumUpdates, taskInfo))
}
/**
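
This first hunk changes taskEnded to carry a flat sequence of per-accumulator updates instead of a whole TaskMetrics object. A minimal standalone sketch of the new payload shape follows; the simplified AccumulableInfo below is a stand-in for the real Spark class, mirroring the six fields this diff passes to its constructor, and the accumulator name is purely illustrative:

    // Stand-in for org.apache.spark.scheduler.AccumulableInfo, with the same
    // fields this diff constructs.
    case class AccumulableInfo(
        id: Long,
        name: Option[String],
        update: Option[Any],        // partial value reported by one task
        value: Option[Any],         // total value held on the driver
        internal: Boolean,
        countFailedValues: Boolean)

    // Under the new signature a task reports a sequence of these updates;
    // task metrics themselves now travel as (internal) accumulators.
    val accumUpdates: Seq[AccumulableInfo] = Seq(
      AccumulableInfo(id = 1L, name = Some("records.read"), update = Some(42L),
        value = None, internal = true, countFailedValues = false))
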
@@ -222,9 +221,10 @@ class DAGScheduler(
*/
def executorHeartbeatReceived(
execId: String,
- taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics)
+ // (taskId, stageId, stageAttemptId, accumUpdates)
+ accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])],
blockManagerId: BlockManagerId): Boolean = {
- listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
+ listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates))
blockManagerMaster.driverEndpoint.askWithRetry[Boolean](
BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat"))
}
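
executorHeartbeatReceived now forwards the same accumulator updates, keyed per task attempt. Reusing the AccumulableInfo stand-in from the sketch above, one heartbeat entry would look roughly like this (values illustrative):

    // (taskId, stageId, stageAttemptId, accumUpdates) -- the tuple layout the
    // new comment in the hunk spells out.
    val heartbeatPayload: Array[(Long, Int, Int, Seq[AccumulableInfo])] =
      Array((17L, 0, 0, accumUpdates))
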
@@ -1074,39 +1074,43 @@ class DAGScheduler(
}
}
- /** Merge updates from a task to our local accumulator values */
+ /**
+ * Merge local values from a task into the corresponding accumulators previously registered
+ * here on the driver.
+ *
+ * Although accumulators themselves are not thread-safe, this method is called only from one
+ * thread, the one that runs the scheduling loop. This means we only handle one task
+ * completion event at a time so we don't need to worry about locking the accumulators.
+ * This still doesn't stop the caller from updating the accumulator outside the scheduler,
+ * but that's not our problem since there's nothing we can do about that.
+ */
private def updateAccumulators(event: CompletionEvent): Unit = {
val task = event.task
val stage = stageIdToStage(task.stageId)
- if (event.accumUpdates != null) {
- try {
- Accumulators.add(event.accumUpdates)
-
- event.accumUpdates.foreach { case (id, partialValue) =>
- // In this instance, although the reference in Accumulators.originals is a WeakRef,
- // it's guaranteed to exist since the event.accumUpdates Map exists
-
- val acc = Accumulators.originals(id).get match {
- case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
- case None => throw new NullPointerException("Non-existent reference to Accumulator")
- }
-
- // To avoid UI cruft, ignore cases where value wasn't updated
- if (acc.name.isDefined && partialValue != acc.zero) {
- val name = acc.name.get
- val value = s"${acc.value}"
- stage.latestInfo.accumulables(id) =
- new AccumulableInfo(id, name, None, value, acc.isInternal)
- event.taskInfo.accumulables +=
- new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal)
- }
+ try {
+ event.accumUpdates.foreach { ainfo =>
+ assert(ainfo.update.isDefined, "accumulator from task should have a partial value")
+ val id = ainfo.id
+ val partialValue = ainfo.update.get
+ // Find the corresponding accumulator on the driver and update it
+ val acc: Accumulable[Any, Any] = Accumulators.get(id) match {
+ case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
+ case None =>
+ throw new SparkException(s"attempted to access non-existent accumulator $id")
+ }
+ acc ++= partialValue
+ // To avoid UI cruft, ignore cases where value wasn't updated
+ if (acc.name.isDefined && partialValue != acc.zero) {
+ val name = acc.name
+ stage.latestInfo.accumulables(id) = new AccumulableInfo(
+ id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues)
+ event.taskInfo.accumulables += new AccumulableInfo(
+ id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues)
}
- } catch {
- // If we see an exception during accumulator update, just log the
- // error and move on.
- case e: Exception =>
- logError(s"Failed to update accumulators for $task", e)
}
+ } catch {
+ case NonFatal(e) =>
+ logError(s"Failed to update accumulators for task ${task.partitionId}", e)
}
}
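
The rewritten updateAccumulators looks each update's id up in the driver-side registry and folds the partial value in with ++=. A self-contained sketch of that merge, reusing the AccumulableInfo stand-in above and assuming a plain Map in place of Accumulators.get plus an additive merge function:

    import scala.util.control.NonFatal

    // Simplified driver-side accumulator: 'zero' and '++=' mirror the parts
    // of the Accumulable API this hunk relies on.
    class Accumulable[T](val id: Long, val name: Option[String], val zero: T,
        var value: T, merge: (T, T) => T) {
      def ++=(partial: T): Unit = { value = merge(value, partial) }
    }

    // Stand-in for Accumulators.get: a registry of driver-side accumulators.
    val registry: Map[Long, Accumulable[Any]] = Map(
      1L -> new Accumulable[Any](1L, Some("records.read"), 0L, 0L,
        (a, b) => a.asInstanceOf[Long] + b.asInstanceOf[Long]))

    def updateAccumulators(updates: Seq[AccumulableInfo]): Unit =
      try {
        updates.foreach { info =>
          assert(info.update.isDefined, "accumulator from task should have a partial value")
          val acc = registry.getOrElse(info.id, throw new RuntimeException(
            s"attempted to access non-existent accumulator ${info.id}"))
          // Fold the task's partial value into the driver-side total.
          acc ++= info.update.get
        }
      } catch {
        case NonFatal(e) => println(s"Failed to update accumulators: $e")
      }

Because the scheduling loop handles one completion event at a time, as the new doc comment notes, this merge needs no locking.
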
@@ -1116,6 +1120,7 @@ class DAGScheduler(
*/
private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
+ val taskId = event.taskInfo.id
val stageId = task.stageId
val taskType = Utils.getFormattedClassName(task)
@@ -1125,12 +1130,26 @@ class DAGScheduler(
event.taskInfo.attemptNumber, // this is a task attempt number
event.reason)
- // The success case is dealt with separately below, since we need to compute accumulator
- // updates before posting.
+ // Reconstruct task metrics. Note: this may be null if the task has failed.
+ val taskMetrics: TaskMetrics =
+ if (event.accumUpdates.nonEmpty) {
+ try {
+ TaskMetrics.fromAccumulatorUpdates(event.accumUpdates)
+ } catch {
+ case NonFatal(e) =>
+ logError(s"Error when attempting to reconstruct metrics for task $taskId", e)
+ null
+ }
+ } else {
+ null
+ }
+
+ // The success case is dealt with separately below.
+ // TODO: Why post it only for failed tasks in cancelled stages? Clarify semantics here.
if (event.reason != Success) {
val attemptId = task.stageAttemptId
- listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason,
- event.taskInfo, event.taskMetrics))
+ listenerBus.post(SparkListenerTaskEnd(
+ stageId, attemptId, taskType, event.reason, event.taskInfo, taskMetrics))
}
if (!stageIdToStage.contains(task.stageId)) {
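
With TaskMetrics removed from the event, the handler reconstructs metrics from the accumulator updates and degrades to a null TaskMetrics when no updates arrived or reconstruction fails. A rough model of that reconstruct-or-null pattern; the single-field TaskMetrics and the metric name are stand-ins, whereas the real TaskMetrics.fromAccumulatorUpdates reads Spark's internal metric accumulators by name:

    case class TaskMetrics(recordsRead: Long)  // illustrative stand-in

    def reconstructMetrics(updates: Seq[AccumulableInfo]): TaskMetrics =
      if (updates.isEmpty) {
        null  // mirrors the hunk: listeners receive null metrics here
      } else {
        try {
          val records = updates.find(_.name.contains("records.read"))
            .flatMap(_.update).map(_.asInstanceOf[Long]).getOrElse(0L)
          TaskMetrics(records)
        } catch {
          // Log-and-continue in the real handler; a malformed payload must
          // not fail the completion event.
          case scala.util.control.NonFatal(e) => null
        }
      }
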
@@ -1142,7 +1161,7 @@ class DAGScheduler(
event.reason match {
case Success =>
listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
- event.reason, event.taskInfo, event.taskMetrics))
+ event.reason, event.taskInfo, taskMetrics))
stage.pendingPartitions -= task.partitionId
task match {
case rt: ResultTask[_, _] =>
@@ -1291,7 +1310,8 @@ class DAGScheduler(
// Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
case exceptionFailure: ExceptionFailure =>
- // Do nothing here, left up to the TaskScheduler to decide how to handle user failures
+ // Tasks failed with exceptions might still have accumulator updates.
+ updateAccumulators(event)
case TaskResultLost =>
// Do nothing here; the TaskScheduler handles these failures and resubmits the task.
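
The last behavioral change above: updateAccumulators now also runs for ExceptionFailure, so accumulators can retain partial values from tasks that failed with an exception. Presumably this is what the new countFailedValues flag gates; a sketch of that filter, with the caveat that where exactly it is applied is an assumption of this sketch:

    // Assumed gating: on a failed task, only accumulators that opted in via
    // countFailedValues fold their partial values into the driver's totals.
    def updatesToApply(
        updates: Seq[AccumulableInfo],
        taskFailed: Boolean): Seq[AccumulableInfo] =
      if (taskFailed) updates.filter(_.countFailedValues) else updates
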
@@ -1637,7 +1657,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case GettingResultEvent(taskInfo) =>
dagScheduler.handleGetTaskResult(taskInfo)
- case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
+ case completion: CompletionEvent =>
dagScheduler.handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason, exception) =>