diff options
Diffstat (limited to 'core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala')
-rw-r--r-- | core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 102 |
1 files changed, 61 insertions, 41 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6b01a10fc1..897479b500 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -208,11 +208,10 @@ class DAGScheduler( task: Task[_], reason: TaskEndReason, result: Any, - accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics): Unit = { + accumUpdates: Seq[AccumulableInfo], + taskInfo: TaskInfo): Unit = { eventProcessLoop.post( - CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)) + CompletionEvent(task, reason, result, accumUpdates, taskInfo)) } /** @@ -222,9 +221,10 @@ class DAGScheduler( */ def executorHeartbeatReceived( execId: String, - taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) + // (taskId, stageId, stageAttemptId, accumUpdates) + accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = { - listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) + listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates)) blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } @@ -1074,39 +1074,43 @@ class DAGScheduler( } } - /** Merge updates from a task to our local accumulator values */ + /** + * Merge local values from a task into the corresponding accumulators previously registered + * here on the driver. + * + * Although accumulators themselves are not thread-safe, this method is called only from one + * thread, the one that runs the scheduling loop. This means we only handle one task + * completion event at a time so we don't need to worry about locking the accumulators. + * This still doesn't stop the caller from updating the accumulator outside the scheduler, + * but that's not our problem since there's nothing we can do about that. + */ private def updateAccumulators(event: CompletionEvent): Unit = { val task = event.task val stage = stageIdToStage(task.stageId) - if (event.accumUpdates != null) { - try { - Accumulators.add(event.accumUpdates) - - event.accumUpdates.foreach { case (id, partialValue) => - // In this instance, although the reference in Accumulators.originals is a WeakRef, - // it's guaranteed to exist since the event.accumUpdates Map exists - - val acc = Accumulators.originals(id).get match { - case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] - case None => throw new NullPointerException("Non-existent reference to Accumulator") - } - - // To avoid UI cruft, ignore cases where value wasn't updated - if (acc.name.isDefined && partialValue != acc.zero) { - val name = acc.name.get - val value = s"${acc.value}" - stage.latestInfo.accumulables(id) = - new AccumulableInfo(id, name, None, value, acc.isInternal) - event.taskInfo.accumulables += - new AccumulableInfo(id, name, Some(s"$partialValue"), value, acc.isInternal) - } + try { + event.accumUpdates.foreach { ainfo => + assert(ainfo.update.isDefined, "accumulator from task should have a partial value") + val id = ainfo.id + val partialValue = ainfo.update.get + // Find the corresponding accumulator on the driver and update it + val acc: Accumulable[Any, Any] = Accumulators.get(id) match { + case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] + case None => + throw new SparkException(s"attempted to access non-existent accumulator $id") + } + acc ++= partialValue + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && partialValue != acc.zero) { + val name = acc.name + stage.latestInfo.accumulables(id) = new AccumulableInfo( + id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues) + event.taskInfo.accumulables += new AccumulableInfo( + id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues) } - } catch { - // If we see an exception during accumulator update, just log the - // error and move on. - case e: Exception => - logError(s"Failed to update accumulators for $task", e) } + } catch { + case NonFatal(e) => + logError(s"Failed to update accumulators for task ${task.partitionId}", e) } } @@ -1116,6 +1120,7 @@ class DAGScheduler( */ private[scheduler] def handleTaskCompletion(event: CompletionEvent) { val task = event.task + val taskId = event.taskInfo.id val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) @@ -1125,12 +1130,26 @@ class DAGScheduler( event.taskInfo.attemptNumber, // this is a task attempt number event.reason) - // The success case is dealt with separately below, since we need to compute accumulator - // updates before posting. + // Reconstruct task metrics. Note: this may be null if the task has failed. + val taskMetrics: TaskMetrics = + if (event.accumUpdates.nonEmpty) { + try { + TaskMetrics.fromAccumulatorUpdates(event.accumUpdates) + } catch { + case NonFatal(e) => + logError(s"Error when attempting to reconstruct metrics for task $taskId", e) + null + } + } else { + null + } + + // The success case is dealt with separately below. + // TODO: Why post it only for failed tasks in cancelled stages? Clarify semantics here. if (event.reason != Success) { val attemptId = task.stageAttemptId - listenerBus.post(SparkListenerTaskEnd(stageId, attemptId, taskType, event.reason, - event.taskInfo, event.taskMetrics)) + listenerBus.post(SparkListenerTaskEnd( + stageId, attemptId, taskType, event.reason, event.taskInfo, taskMetrics)) } if (!stageIdToStage.contains(task.stageId)) { @@ -1142,7 +1161,7 @@ class DAGScheduler( event.reason match { case Success => listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, - event.reason, event.taskInfo, event.taskMetrics)) + event.reason, event.taskInfo, taskMetrics)) stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => @@ -1291,7 +1310,8 @@ class DAGScheduler( // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits case exceptionFailure: ExceptionFailure => - // Do nothing here, left up to the TaskScheduler to decide how to handle user failures + // Tasks failed with exceptions might still have accumulator updates. + updateAccumulators(event) case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. @@ -1637,7 +1657,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case GettingResultEvent(taskInfo) => dagScheduler.handleGetTaskResult(taskInfo) - case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) => + case completion: CompletionEvent => dagScheduler.handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason, exception) => |