aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/scheduler/DAGScheduler.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala/spark/scheduler/DAGScheduler.scala')
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala33
1 files changed, 18 insertions, 15 deletions
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 6f4c6bffd7..29757b1178 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -16,8 +16,8 @@ import spark.storage.BlockManagerMaster
import spark.storage.BlockManagerId
/**
- * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
- * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
+ * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
+ * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
* schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
@@ -73,7 +73,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
// that's not going to be a realistic assumption in general
-
+
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
@@ -94,7 +94,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
cacheLocs(rdd.id)
}
-
+
def updateCacheLocs() {
cacheLocs = cacheTracker.getLocationsSnapshot()
}
@@ -104,7 +104,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now).
*/
- def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_,_], priority: Int): Stage = {
+ def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@@ -119,7 +119,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
* as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority.
*/
- def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_,_]], priority: Int): Stage = {
+ def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of splits is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
@@ -149,7 +149,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
cacheTracker.registerRDD(r.id, r.splits.size)
for (dep <- r.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
+ case shufDep: ShuffleDependency[_,_] =>
parents += getShuffleMapStage(shufDep, priority)
case _ =>
visit(dep.rdd)
@@ -172,7 +172,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
if (locs(p) == Nil) {
for (dep <- rdd.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
+ case shufDep: ShuffleDependency[_,_] =>
val mapStage = getShuffleMapStage(shufDep, stage.priority)
if (!mapStage.isAvailable) {
missing += mapStage
@@ -326,7 +326,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val rdd = job.finalStage.rdd
val split = rdd.splits(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
- val result = job.func(taskContext, rdd.iterator(split))
+ val result = job.func(taskContext, rdd.iterator(split, taskContext))
+ taskContext.executeOnCompleteCallbacks()
job.listener.taskSucceeded(0, result)
} catch {
case e: Exception =>
@@ -353,7 +354,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
}
-
+
def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
@@ -395,7 +396,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val task = event.task
val stage = idToStage(task.stageId)
event.reason match {
- case Success =>
+ case Success =>
logInfo("Completed " + task)
if (event.accumUpdates != null) {
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
@@ -479,8 +480,10 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
") for resubmision due to a fetch failure")
// Mark the map whose fetch failed as broken in the map stage
val mapStage = shuffleToMapStage(shuffleId)
- mapStage.removeOutputLoc(mapId, bmAddress)
- mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+ if (mapId != -1) {
+ mapStage.removeOutputLoc(mapId, bmAddress)
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+ }
logInfo("The failed fetch was from " + mapStage + " (" + mapStage.origin +
"); marking it for resubmission")
failed += mapStage
@@ -517,7 +520,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
updateCacheLocs()
}
}
-
+
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
@@ -549,7 +552,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
visitedRdds += rdd
for (dep <- rdd.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_,_] =>
+ case shufDep: ShuffleDependency[_,_] =>
val mapStage = getShuffleMapStage(shufDep, stage.priority)
if (!mapStage.isAvailable) {
visitedStages += mapStage