author | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-10-18 22:49:00 -0700 |
---|---|---|
committer | Matei Zaharia <matei@eecs.berkeley.edu> | 2013-10-18 22:49:00 -0700 |
commit | 599dcb0ddf740e028cc8faac163303be8f9400a6 (patch) | |
tree | 1c2be699552c17bf3860298570952e4048f00ed9 | |
parent | 8de9706b86f41a37464f55e1ffe5a246adc712d1 (diff) | |
parent | 806f3a3adb19dab2ffe864226b6e5438015222eb (diff) | |
Merge pull request #74 from rxin/kill
Job cancellation via job group id.
This PR adds a simple API for grouping together the jobs submitted from a thread and from any threads it spawns. It also allows cancelling all jobs in the group at once.
An example:

```scala
sc.setJobGroup("this_is_the_group_id", "some job description")
sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
```
In a separate thread:

```scala
sc.cancelJobGroup("this_is_the_group_id")
```
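Putting the two snippets together, a self-contained sketch of the intended usage (assuming an already-constructed SparkContext `sc`; the `future` wrapper is just one way to get a second thread, and in practice you would wait until the job's tasks have actually started before cancelling, as the new test in this PR does with a semaphore):

```scala
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.future

// Tag all jobs submitted from this thread (and threads spawned from it)
// with a group id, then kick off the long-running job.
val longRunning = future {
  sc.setJobGroup("this_is_the_group_id", "some job description")
  sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
}

// From any other thread: cancel every active job in the group.
sc.cancelJobGroup("this_is_the_group_id")
```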
4 files changed, 75 insertions(+), 7 deletions(-)
```diff
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 3ed9caa242..48bbc78795 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -287,8 +287,19 @@ class SparkContext(
     Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
 
   /** Set a human readable description of the current job. */
+  @deprecated("use setJobGroup", "0.8.1")
   def setJobDescription(value: String) {
-    setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
+    setJobGroup("", value)
+  }
+
+  def setJobGroup(groupId: String, description: String) {
+    setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
+    setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
+  }
+
+  def clearJobGroup() {
+    setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
+    setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
   }
 
   // Post init
@@ -866,10 +877,14 @@ class SparkContext(
       callSite,
       allowLocal = false,
       resultHandler,
-      null)
+      localProperties.get)
     new SimpleFutureAction(waiter, resultFunc)
   }
 
+  def cancelJobGroup(groupId: String) {
+    dagScheduler.cancelJobGroup(groupId)
+  }
+
   /**
    * Cancel all jobs that have been scheduled or are running.
    */
@@ -933,8 +948,11 @@ class SparkContext(
  * various Spark features.
  */
 object SparkContext {
+  val SPARK_JOB_DESCRIPTION = "spark.job.description"
+  val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
+
   implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
     def addInPlace(t1: Double, t2: Double): Double = t1 + t2
     def zero(initialValue: Double) = 0.0
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 15a04e6558..d84f5968df 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -277,11 +277,6 @@ class DAGScheduler(
       resultHandler: (Int, U) => Unit,
       properties: Properties = null): JobWaiter[U] =
   {
-    val jobId = nextJobId.getAndIncrement()
-    if (partitions.size == 0) {
-      return new JobWaiter[U](this, jobId, 0, resultHandler)
-    }
-
     // Check to make sure we are not launching a task on a partition that does not exist.
     val maxPartitions = rdd.partitions.length
     partitions.find(p => p >= maxPartitions).foreach { p =>
@@ -290,6 +285,11 @@ class DAGScheduler(
         "Total number of partitions: " + maxPartitions)
     }
 
+    val jobId = nextJobId.getAndIncrement()
+    if (partitions.size == 0) {
+      return new JobWaiter[U](this, jobId, 0, resultHandler)
+    }
+
     assert(partitions.size > 0)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
@@ -342,6 +342,11 @@ class DAGScheduler(
     eventQueue.put(JobCancelled(jobId))
   }
 
+  def cancelJobGroup(groupId: String) {
+    logInfo("Asked to cancel job group " + groupId)
+    eventQueue.put(JobGroupCancelled(groupId))
+  }
+
   /**
    * Cancel all jobs that are running or waiting in the queue.
    */
@@ -381,6 +386,17 @@ class DAGScheduler(
           taskSched.cancelTasks(stage.id)
         }
 
+      case JobGroupCancelled(groupId) =>
+        // Cancel all jobs belonging to this job group.
+        // First finds all active jobs with this group id, and then kill stages for them.
+        val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+          .map(_.jobId)
+        if (!jobIds.isEmpty) {
+          running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
+            taskSched.cancelTasks(stage.id)
+          }
+        }
+
       case AllJobsCancelled =>
         // Cancel all running jobs.
         running.foreach { stage =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index ee89bfb38d..a5769c6041 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -46,6 +46,8 @@ private[scheduler] case class JobSubmitted(
 
 private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
 
+private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
+
 private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
 
 private[scheduler]
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index a192651491..d8a0e983b2 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark
 
 import java.util.concurrent.Semaphore
 
+import scala.concurrent.Await
+import scala.concurrent.duration.Duration
 import scala.concurrent.future
 import scala.concurrent.ExecutionContext.Implicits.global
 
@@ -83,6 +85,36 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
     assert(sc.parallelize(1 to 10, 2).count === 10)
   }
 
+  test("job group") {
+    sc = new SparkContext("local[2]", "test")
+
+    // Add a listener to release the semaphore once any tasks are launched.
+    val sem = new Semaphore(0)
+    sc.dagScheduler.addSparkListener(new SparkListener {
+      override def onTaskStart(taskStart: SparkListenerTaskStart) {
+        sem.release()
+      }
+    })
+
+    // jobA is the one to be cancelled.
+    val jobA = future {
+      sc.setJobGroup("jobA", "this is a job to be cancelled")
+      sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
+    }
+
+    sc.clearJobGroup()
+    val jobB = sc.parallelize(1 to 100, 2).countAsync()
+
+    // Block until both tasks of job A have started and cancel job A.
+    sem.acquire(2)
+    sc.cancelJobGroup("jobA")
+    val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) }
+    assert(e.getMessage contains "cancel")
+
+    // Once A is cancelled, job B should finish fairly quickly.
+    assert(jobB.get() === 100)
+  }
+
   test("two jobs sharing the same stage") {
     // sem1: make sure cancel is issued after some tasks are launched
     // sem2: make sure the first stage is not finished until cancel is issued
```
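A note on the "threads spawned from it" part of the description: SparkContext keeps these per-thread properties in a `java.lang.InheritableThreadLocal`, so a thread created after `setJobGroup` is called starts out with its parent's group id, and the DAGScheduler can later match active jobs on `SPARK_JOB_GROUP_ID`. A minimal standalone sketch of that JDK mechanism, independent of Spark (the demo object and the copy-on-inherit override are illustrative, not Spark's code):

```scala
import java.util.Properties

object GroupInheritanceDemo {
  // Child threads receive a value derived from the parent's value
  // at Thread-construction time.
  val props = new InheritableThreadLocal[Properties] {
    override def childValue(parent: Properties): Properties =
      if (parent == null) null else new Properties(parent) // copy-on-inherit
  }

  def main(args: Array[String]) {
    props.set(new Properties())
    props.get().setProperty("spark.jobGroup.id", "this_is_the_group_id")

    val child = new Thread(new Runnable {
      def run() {
        // Prints the group id inherited from the parent thread.
        println(props.get().getProperty("spark.jobGroup.id"))
      }
    })
    child.start()
    child.join()
  }
}
```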