diff options
author | Reynold Xin <rxin@apache.org> | 2013-10-10 00:28:00 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2013-10-10 00:28:00 -0700 |
commit | 0353f74a9a6882f4faa987a70a256786540a8727 (patch) | |
tree | 0ad28d48cef8ada458ec593b26ec1cf31c034de9 | |
parent | dbae7795ba489bfc1fedb88155bf42bb4992b006 (diff) | |
download | spark-0353f74a9a6882f4faa987a70a256786540a8727.tar.gz spark-0353f74a9a6882f4faa987a70a256786540a8727.tar.bz2 spark-0353f74a9a6882f4faa987a70a256786540a8727.zip |
Put the job cancellation handling into the DAGScheduler's main event loop.
8 files changed, 69 insertions, 44 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 714f4303c3..93303a9d36 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -338,15 +338,7 @@ class DAGScheduler( */ def killJob(jobId: Int): Unit = this.synchronized { logInfo("Asked to kill job " + jobId) - activeJobs.find(job => job.jobId == jobId).foreach { job => - killStage(job, job.finalStage) - } - - def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized { - logDebug("Killing stage %s".format(stage.id)) - taskSched.killTasks(stage.id) - stage.parents.foreach(parentStage => killStage(job, parentStage)) - } + eventQueue.put(JobCancelled(jobId)) } /** @@ -375,6 +367,12 @@ class DAGScheduler( submitStage(finalStage) } + case JobCancelled(jobId) => + // Cancel a job: find all the running stages that are linked to this job, and cancel them. 
+ running.find(_.jobId == jobId).foreach { stage => + taskSched.cancelTasks(stage.id) + } + case ExecutorGained(execId, host) => handleExecutorGained(execId, host) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 8dd85694ab..0d4d4edc55 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -44,6 +44,8 @@ private[scheduler] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent +private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent + private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 62b521ad45..466baf9913 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -54,7 +54,7 @@ trait SparkListener { /** * Called when a task starts */ - def onTaskStart(taskEnd: SparkListenerTaskStart) { } + def onTaskStart(taskStart: SparkListenerTaskStart) { } /** * Called when a task ends diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index d25b0a5e0d..6a51efe8d6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -45,8 +45,8 @@ private[spark] trait TaskScheduler { // Submit a sequence of tasks to run. def submitTasks(taskSet: TaskSet): Unit - // Kill the stage. - def killTasks(stageId: Int) + // Cancel a stage. + def cancelTasks(stageId: Int) // Set a listener for upcalls. 
This is guaranteed to be set before submitTasks is called. def setListener(listener: TaskSchedulerListener): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index be0dabf4b9..031d0b1ef7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -17,7 +17,6 @@ package org.apache.spark.scheduler.cluster -import java.lang.{Boolean => JBoolean} import java.nio.ByteBuffer import java.util.concurrent.atomic.AtomicLong import java.util.{TimerTask, Timer} @@ -171,28 +170,37 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.reviveOffers() } - override def killTasks(stageId: Int): Unit = synchronized { - schedulableBuilder.getTaskSetManagers(stageId).foreach { t => - // Notify the executors to kill the tasks. - val ts = t.asInstanceOf[TaskSetManager].taskSet - val taskIds = taskSetTaskIds(ts.id) - taskIds.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId) + override def cancelTasks(stageId: Int): Unit = synchronized { + logInfo("Cancelling stage " + stageId) + schedulableBuilder.getTaskSetManagers(stageId).foreach { case tsm: TaskSetManager => + // There are two possible cases here: + // 1. The task set manager has been created and some tasks have been scheduled. + // In this case, send a kill signal to the executors to kill the task. + // 2. The task set manager has been created but no tasks has been scheduled. In this case, + // simply abort the task set. 
+ val taskIds = taskSetTaskIds(tsm.taskSet.id) + if (taskIds.size > 0) { + taskIds.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) + } + } else { + tsm.error("Stage %d was cancelled before any tasks was launched".format(stageId)) } } } - def taskSetFinished(manager: TaskSetManager) { - this.synchronized { - if (activeTaskSets.contains(manager.taskSet.id)) { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds.remove(manager.taskSet.id) - } + def taskSetFinished(manager: TaskSetManager): Unit = synchronized { + // Check to see if the given task set has been removed. This is possible in the case of + // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has + // more than one running tasks). 
+ if (activeTaskSets.contains(manager.taskSet.id)) { + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds.remove(manager.taskSet.id) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 35762b9b01..e132182231 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -138,7 +138,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: } } - override def killTasks(stageId: Int): Unit = synchronized { + override def cancelTasks(stageId: Int): Unit = synchronized { schedulableBuilder.getTaskSetManagers(stageId).foreach { sched => val taskIds = taskSetTaskIds(sched.asInstanceOf[TaskSetManager].taskSet.id) for (tid <- taskIds) { diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 0fd96ed3b1..758670bdbf 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -29,6 +29,8 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.SparkContext._ import org.apache.spark.{SparkContext, SparkException, LocalSparkContext} +import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.scheduler._ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll { @@ -46,24 +48,39 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll { lazy val zeroPartRdd = new EmptyRDD[Int](sc) - test("job 
cancellation") { - val f = sc.parallelize(1 to 1000, 2).map { i => Thread.sleep(1000); i }.countAsync() + test("job cancellation before any tasks is launched") { + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() + future { f.cancel() } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + } + test("job cancellation after some tasks have been launched") { + // Add a listener to release the semaphore once any tasks are launched. val sem = new Semaphore(0) + sc.dagScheduler.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() future { - //sem.acquire() - Thread.sleep(1000) + // Wait until some tasks were launched before we cancel the job. + sem.acquire() f.cancel() - println("killing previous job") - } - - intercept[SparkException] { - println("lalalalalala") - println(f.get()) - println("hahahahah") } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) + } + test("cancelling take action") { + val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) + future { f.cancel() } + val e = intercept[SparkException] { f.get() } + assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } + // // test("countAsync") { // assert(zeroPartRdd.countAsync().get() === 0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6643c9d504..5e7544452e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -60,7 +60,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter 
with LocalSparkCont taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) taskSets += taskSet } - override def killTasks(stageId: Int) {} + override def cancelTasks(stageId: Int) {} override def setListener(listener: TaskSchedulerListener) = {} override def defaultParallelism() = 2 } |