aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorReynold Xin <rxin@apache.org>2013-10-10 00:28:00 -0700
committerReynold Xin <rxin@apache.org>2013-10-10 00:28:00 -0700
commit0353f74a9a6882f4faa987a70a256786540a8727 (patch)
tree0ad28d48cef8ada458ec593b26ec1cf31c034de9
parentdbae7795ba489bfc1fedb88155bf42bb4992b006 (diff)
downloadspark-0353f74a9a6882f4faa987a70a256786540a8727.tar.gz
spark-0353f74a9a6882f4faa987a70a256786540a8727.tar.bz2
spark-0353f74a9a6882f4faa987a70a256786540a8727.zip
Put the job cancellation handling into the dagscheduler's main event loop.
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala46
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala39
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala2
8 files changed, 69 insertions, 44 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 714f4303c3..93303a9d36 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -338,15 +338,7 @@ class DAGScheduler(
*/
def killJob(jobId: Int): Unit = this.synchronized {
logInfo("Asked to kill job " + jobId)
- activeJobs.find(job => job.jobId == jobId).foreach { job =>
- killStage(job, job.finalStage)
- }
-
- def killStage(job: ActiveJob, stage: Stage): Unit = this.synchronized {
- logDebug("Killing stage %s".format(stage.id))
- taskSched.killTasks(stage.id)
- stage.parents.foreach(parentStage => killStage(job, parentStage))
- }
+ eventQueue.put(JobCancelled(jobId))
}
/**
@@ -375,6 +367,12 @@ class DAGScheduler(
submitStage(finalStage)
}
+ case JobCancelled(jobId) =>
+ // Cancel a job: find all the running stages that are linked to this job, and cancel them.
+ running.find(_.jobId == jobId).foreach { stage =>
+ taskSched.cancelTasks(stage.id)
+ }
+
case ExecutorGained(execId, host) =>
handleExecutorGained(execId, host)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 8dd85694ab..0d4d4edc55 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -44,6 +44,8 @@ private[scheduler] case class JobSubmitted(
properties: Properties = null)
extends DAGSchedulerEvent
+private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
+
private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 62b521ad45..466baf9913 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -54,7 +54,7 @@ trait SparkListener {
/**
* Called when a task starts
*/
- def onTaskStart(taskEnd: SparkListenerTaskStart) { }
+ def onTaskStart(taskStart: SparkListenerTaskStart) { }
/**
* Called when a task ends
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index d25b0a5e0d..6a51efe8d6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -45,8 +45,8 @@ private[spark] trait TaskScheduler {
// Submit a sequence of tasks to run.
def submitTasks(taskSet: TaskSet): Unit
- // Kill the stage.
- def killTasks(stageId: Int)
+ // Cancel a stage.
+ def cancelTasks(stageId: Int)
// Set a listener for upcalls. This is guaranteed to be set before submitTasks is called.
def setListener(listener: TaskSchedulerListener): Unit
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index be0dabf4b9..031d0b1ef7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -17,7 +17,6 @@
package org.apache.spark.scheduler.cluster
-import java.lang.{Boolean => JBoolean}
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
import java.util.{TimerTask, Timer}
@@ -171,28 +170,37 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.reviveOffers()
}
- override def killTasks(stageId: Int): Unit = synchronized {
- schedulableBuilder.getTaskSetManagers(stageId).foreach { t =>
- // Notify the executors to kill the tasks.
- val ts = t.asInstanceOf[TaskSetManager].taskSet
- val taskIds = taskSetTaskIds(ts.id)
- taskIds.foreach { tid =>
- val execId = taskIdToExecutorId(tid)
- backend.killTask(tid, execId)
+ override def cancelTasks(stageId: Int): Unit = synchronized {
+ logInfo("Cancelling stage " + stageId)
+ schedulableBuilder.getTaskSetManagers(stageId).foreach { case tsm: TaskSetManager =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task.
+ // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+ // simply abort the task set.
+ val taskIds = taskSetTaskIds(tsm.taskSet.id)
+ if (taskIds.size > 0) {
+ taskIds.foreach { tid =>
+ val execId = taskIdToExecutorId(tid)
+ backend.killTask(tid, execId)
+ }
+ } else {
+ tsm.error("Stage %d was cancelled before any tasks was launched".format(stageId))
}
}
}
- def taskSetFinished(manager: TaskSetManager) {
- this.synchronized {
- if (activeTaskSets.contains(manager.taskSet.id)) {
- activeTaskSets -= manager.taskSet.id
- manager.parent.removeSchedulable(manager)
- logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
- taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
- taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
- taskSetTaskIds.remove(manager.taskSet.id)
- }
+ def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
+ // Check to see if the given task set has been removed. This is possible in the case of
+ // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
+ // more than one running tasks).
+ if (activeTaskSets.contains(manager.taskSet.id)) {
+ activeTaskSets -= manager.taskSet.id
+ manager.parent.removeSchedulable(manager)
+ logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
+ taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+ taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
+ taskSetTaskIds.remove(manager.taskSet.id)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index 35762b9b01..e132182231 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -138,7 +138,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
- override def killTasks(stageId: Int): Unit = synchronized {
+ override def cancelTasks(stageId: Int): Unit = synchronized {
schedulableBuilder.getTaskSetManagers(stageId).foreach { sched =>
val taskIds = taskSetTaskIds(sched.asInstanceOf[TaskSetManager].taskSet.id)
for (tid <- taskIds) {
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 0fd96ed3b1..758670bdbf 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -29,6 +29,8 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.spark.SparkContext._
import org.apache.spark.{SparkContext, SparkException, LocalSparkContext}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
+import org.apache.spark.scheduler._
class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
@@ -46,24 +48,39 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll {
lazy val zeroPartRdd = new EmptyRDD[Int](sc)
- test("job cancellation") {
- val f = sc.parallelize(1 to 1000, 2).map { i => Thread.sleep(1000); i }.countAsync()
+ test("job cancellation before any tasks is launched") {
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
+ future { f.cancel() }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+ test("job cancellation after some tasks have been launched") {
+ // Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
+ sc.dagScheduler.addSparkListener(new SparkListener {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ sem.release()
+ }
+ })
+
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
future {
- //sem.acquire()
- Thread.sleep(1000)
+ // Wait until some tasks were launched before we cancel the job.
+ sem.acquire()
f.cancel()
- println("killing previous job")
- }
-
- intercept[SparkException] {
- println("lalalalalala")
- println(f.get())
- println("hahahahah")
}
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+ }
+ test("cancelling take action") {
+ val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000)
+ future { f.cancel() }
+ val e = intercept[SparkException] { f.get() }
+ assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
}
+
//
// test("countAsync") {
// assert(zeroPartRdd.countAsync().get() === 0)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 6643c9d504..5e7544452e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -60,7 +60,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
taskSets += taskSet
}
- override def killTasks(stageId: Int) {}
+ override def cancelTasks(stageId: Int) {}
override def setListener(listener: TaskSchedulerListener) = {}
override def defaultParallelism() = 2
}