Diffstat (limited to 'core/src/main/scala/spark/scheduler/DAGScheduler.scala')
 core/src/main/scala/spark/scheduler/DAGScheduler.scala | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 5c71207d43..29757b1178 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -16,8 +16,8 @@ import spark.storage.BlockManagerMaster
import spark.storage.BlockManagerId
/**
- * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
- * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
+ * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
+ * each job, keeps track of which RDDs and stage outputs are materialized, and computes a minimal
* schedule to run the job. Subclasses only need to implement the code to send a task to the cluster
* and to report fetch failures (the submitTasks method, and code to add CompletionEvents).
*/
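
The comment above names the whole extension contract: submitTasks to ship work out, and CompletionEvents to bring results back. A minimal wiring sketch under those assumptions (the taskEnded callback name is taken from the TaskSchedulerListener interface this class implements; everything else is illustrative):

    // Sketch only: DAGScheduler drives a pluggable TaskScheduler and
    // registers itself as that scheduler's listener.
    val ts: TaskScheduler = /* cluster-specific implementation */ null
    val dag = new DAGScheduler(ts)   // per the class header below
    // ts.submitTasks(...) sends tasks out; as they finish, ts reports back
    // through listener callbacks such as taskEnded, which the DAGScheduler
    // turns into CompletionEvents on its internal queue.
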
@@ -73,7 +73,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val deadHosts = new HashSet[String] // TODO: The code currently assumes these can't come back;
// that's not going to be a realistic assumption in general
-
+
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
val running = new HashSet[Stage] // Stages we are running right now
val failed = new HashSet[Stage] // Stages that must be resubmitted due to fetch failures
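
These three sets are the stage state machine: a stage waits until every parent stage's output exists, runs, and drops into failed on a fetch failure so it can be resubmitted. Conceptually the promotion step looks like this (illustrative sketch, not code from this file; it assumes Stage exposes its parent stages as parents):

    // Promote any waiting stage whose parents are all finished.
    for (stage <- waiting.toList
         if stage.parents.forall(p => !waiting(p) && !running(p) && !failed(p))) {
      waiting -= stage
      running += stage   // the real code then calls submitMissingTasks(stage)
    }
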
@@ -94,7 +94,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
cacheLocs(rdd.id)
}
-
+
def updateCacheLocs() {
cacheLocs = cacheTracker.getLocationsSnapshot()
}
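
getCacheLocs is a per-partition lookup: element i of the returned array lists the hosts that have partition i of the RDD cached, as of the last CacheTracker snapshot. Illustrative use (someRdd is an assumed RDD):

    updateCacheLocs()                               // refresh the snapshot
    val locs: Array[List[String]] = getCacheLocs(someRdd)
    val hostsForSplit0: List[String] = locs(0)      // hosts caching split 0
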
@@ -326,7 +326,8 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val rdd = job.finalStage.rdd
val split = rdd.splits(job.partitions(0))
val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
- val result = job.func(taskContext, rdd.iterator(split))
+ val result = job.func(taskContext, rdd.iterator(split, taskContext))
+ taskContext.executeOnCompleteCallbacks()
job.listener.taskSucceeded(0, result)
} catch {
case e: Exception =>
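
This hunk is the substance of the patch: jobs run locally on the driver now get the same TaskContext treatment as tasks on executors. The two-argument rdd.iterator(split, taskContext) lets the computation register per-task cleanup on the context, and executeOnCompleteCallbacks() fires that cleanup when the job finishes. A sketch of the callback contract, assuming the addOnCompleteCallback registration method that pairs with executeOnCompleteCallbacks (an RDD reading Hadoop input would register its reader.close() this way):

    val ctx = new TaskContext(0, 0, 0)   // stage id, split index, attempt id
    ctx.addOnCompleteCallback(() => {
      // e.g. close a record reader, release per-task buffers
      println("task finished, cleaning up")
    })
    // ... fully consume rdd.iterator(split, ctx) ...
    ctx.executeOnCompleteCallbacks()     // runs each registered callback once
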
@@ -353,7 +354,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
}
}
}
-
+
def submitMissingTasks(stage: Stage) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
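
pendingTasks, referenced in the comment above, maps each submitted stage to the tasks that have not finished yet; a stage is done when its entry drains. A hedged sketch of that bookkeeping (the HashMap-of-HashSet shape is an assumption about this file, not shown in the hunk):

    // Remember what we just submitted for this stage.
    val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet[Task[_]])
    myPending ++= tasks
    // On each CompletionEvent:  myPending -= event.task
    // When myPending.isEmpty, every output of the stage is materialized.
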
@@ -395,7 +396,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
val task = event.task
val stage = idToStage(task.stageId)
event.reason match {
- case Success =>
+ case Success =>
logInfo("Completed " + task)
if (event.accumUpdates != null) {
Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
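
On Success, accumulator deltas computed inside the task ride back in event.accumUpdates and are merged into the driver-side totals by Accumulators.add. The TODO is real: a resubmitted task that succeeds twice gets merged twice, double-counting. The round trip in user-facing terms (sc is an assumed SparkContext):

    val acc = sc.accumulator(0)
    sc.parallelize(1 to 100).foreach(i => acc += i)
    // Each task ships its local additions back as accumUpdates; after the
    // job, acc.value == 5050 unless some task ran twice (the TODO above).
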
@@ -519,7 +520,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with
updateCacheLocs()
}
}
-
+
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
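
This is the one path where events flow from the TaskScheduler into the DAGScheduler by explicit injection rather than by task completion. A hypothetical call site inside a TaskScheduler that has given up on a task set (the reason string and surrounding wiring are assumptions):

    // listener is the TaskSchedulerListener, i.e. this DAGScheduler.
    listener.taskSetFailed(taskSet, "task set lost: too many task failures")
    // taskSetFailed -> abortStage -> every dependent job is failed.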