From 686a420ddc33407050d9019711cbe801fc352fa3 Mon Sep 17 00:00:00 2001
From: Mark Hamstra
Date: Fri, 22 Nov 2013 10:20:09 -0800
Subject: Refactoring to make job removal, stage removal, task cancellation clearer

---
 .../org/apache/spark/scheduler/DAGScheduler.scala | 76 +++++++++++-----------
 1 file changed, 39 insertions(+), 37 deletions(-)

(limited to 'core')

diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 6f9d4d52a4..b8b3ac0b43 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -370,9 +370,11 @@ class DAGScheduler(
     }
   }
 
-  // Removes job and applies p to any stages that aren't needed by any other jobs
-  private def forIndependentStagesOfRemovedJob(jobId: Int)(p: Int => Unit) {
+  // Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that
+  // were removed and whose associated tasks may need to be cancelled.
+  private def removeJobAndIndependentStages(jobId: Int): Set[Int] = {
     val registeredStages = jobIdToStageIds(jobId)
+    val independentStages = new HashSet[Int]()
     if (registeredStages.isEmpty) {
       logError("No stages registered for job " + jobId)
     } else {
@@ -382,49 +384,51 @@ class DAGScheduler(
           logError("Job %d not registered for stage %d even though that stage was registered for the job"
             .format(jobId, stageId))
         } else {
+          def removeStage(stageId: Int) {
+            // data structures based on Stage
+            stageIdToStage.get(stageId).foreach { s =>
+              if (running.contains(s)) {
+                logDebug("Removing running stage %d".format(stageId))
+                running -= s
+              }
+              stageToInfos -= s
+              shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove)
+              if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) {
+                logDebug("Removing pending status for stage %d".format(stageId))
+              }
+              pendingTasks -= s
+              if (waiting.contains(s)) {
+                logDebug("Removing stage %d from waiting set.".format(stageId))
+                waiting -= s
+              }
+              if (failed.contains(s)) {
+                logDebug("Removing stage %d from failed set.".format(stageId))
+                failed -= s
+              }
+            }
+            // data structures based on StageId
+            stageIdToStage -= stageId
+            stageIdToJobIds -= stageId
+
+            logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size))
+          }
+
           jobSet -= jobId
           if (jobSet.isEmpty) { // no other job needs this stage
-            p(stageId)
+            independentStages += stageId
+            removeStage(stageId)
           }
         }
       }
     }
-  }
-
-  private def removeStage(stageId: Int) {
-    // data structures based on Stage
-    stageIdToStage.get(stageId).foreach { s =>
-      if (running.contains(s)) {
-        logDebug("Removing running stage %d".format(stageId))
-        running -= s
-      }
-      stageToInfos -= s
-      shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove(_))
-      if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) {
-        logDebug("Removing pending status for stage %d".format(stageId))
-      }
-      pendingTasks -= s
-      if (waiting.contains(s)) {
-        logDebug("Removing stage %d from waiting set.".format(stageId))
-        waiting -= s
-      }
-      if (failed.contains(s)) {
-        logDebug("Removing stage %d from failed set.".format(stageId))
-        failed -= s
-      }
-    }
-    // data structures based on StageId
-    stageIdToStage -= stageId
-    stageIdToJobIds -= stageId
-
-    logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size))
+    independentStages.toSet
   }
 
   private def jobIdToStageIdsRemove(jobId: Int) {
     if (!jobIdToStageIds.contains(jobId)) {
       logDebug("Trying to remove unregistered job " + jobId)
     } else {
-      forIndependentStagesOfRemovedJob(jobId) { removeStage }
+      removeJobAndIndependentStages(jobId)
       jobIdToStageIds -= jobId
     }
   }
@@ -987,10 +991,8 @@ class DAGScheduler(
       if (!jobIdToStageIds.contains(jobId)) {
         logDebug("Trying to cancel unregistered job " + jobId)
       } else {
-        forIndependentStagesOfRemovedJob(jobId) { stageId =>
-          taskSched.cancelTasks(stageId)
-          removeStage(stageId)
-        }
+        val independentStages = removeJobAndIndependentStages(jobId)
+        independentStages.foreach { taskSched.cancelTasks }
         val error = new SparkException("Job %d cancelled".format(jobId))
         val job = idToActiveJob(jobId)
         job.listener.jobFailed(error)
-- 
cgit v1.2.3
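
For readers skimming the patch, a minimal, self-contained Scala sketch of the call pattern it introduces: removeJobAndIndependentStages returns the ids of the stages that no remaining job needs, and the cancellation path then invokes taskSched.cancelTasks only for those ids. The maps, the simplified cleanup, and the main driver below are stand-ins invented for illustration, not the actual DAGScheduler code.

import scala.collection.mutable.{HashMap, HashSet}

// Toy stand-in for the scheduler bookkeeping touched by this patch; the real
// DAGScheduler also maintains running/waiting/failed/pendingTasks and the
// Stage objects themselves, all omitted here.
object JobRemovalSketch {
  val jobIdToStageIds = new HashMap[Int, HashSet[Int]]()
  val stageIdToJobIds = new HashMap[Int, HashSet[Int]]()

  // Same contract as the refactored method: drop the job from every stage's
  // job set, remove stages that no other job still needs, and return the ids
  // of those removed ("independent") stages.
  def removeJobAndIndependentStages(jobId: Int): Set[Int] = {
    val independentStages = new HashSet[Int]()
    for (stageId <- jobIdToStageIds.getOrElse(jobId, HashSet.empty[Int])) {
      stageIdToJobIds.get(stageId).foreach { jobSet =>
        jobSet -= jobId
        if (jobSet.isEmpty) { // no other job needs this stage
          independentStages += stageId
          stageIdToJobIds -= stageId // stands in for the full removeStage cleanup
        }
      }
    }
    independentStages.toSet
  }

  def main(args: Array[String]): Unit = {
    // Job 1 owns stages 10 and 11; stage 11 is shared with job 2.
    jobIdToStageIds(1) = HashSet(10, 11)
    jobIdToStageIds(2) = HashSet(11)
    stageIdToJobIds(10) = HashSet(1)
    stageIdToJobIds(11) = HashSet(1, 2)

    // Cancellation path after the patch: cancel tasks only for the stages
    // that came back as independent (here, just stage 10).
    val independentStages = removeJobAndIndependentStages(1)
    jobIdToStageIds -= 1 // the caller also forgets the job -> stages mapping
    independentStages.foreach { stageId => println("cancelTasks(%d)".format(stageId)) }
  }
}

Returning the independent stage ids, rather than passing a callback into the removal loop, is what lets jobIdToStageIdsRemove and the cancellation path share a single removal routine while only the latter cancels tasks.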