diff options
3 files changed, 25 insertions, 9 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index a70be16f77..3904f7d106 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -433,6 +433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Thread Local variable that can be used by users to pass information down the stack private val localProperties = new InheritableThreadLocal[Properties] { override protected def childValue(parent: Properties): Properties = new Properties(parent) + override protected def initialValue(): Properties = new Properties() } /** @@ -474,9 +475,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Spark fair scheduler pool. */ def setLocalProperty(key: String, value: String) { - if (localProperties.get() == null) { - localProperties.set(new Properties()) - } if (value == null) { localProperties.get.remove(key) } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d35b4f9dba..7227fa9da4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -493,7 +493,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): JobWaiter[U] = { + properties: Properties): JobWaiter[U] = { // Check to make sure we are not launching a task on a partition that does not exist. val maxPartitions = rdd.partitions.length partitions.find(p => p >= maxPartitions || p < 0).foreach { p => @@ -522,7 +522,7 @@ class DAGScheduler( callSite: CallSite, allowLocal: Boolean, resultHandler: (Int, U) => Unit, - properties: Properties = null): Unit = { + properties: Properties): Unit = { val start = System.nanoTime val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { @@ -542,7 +542,7 @@ class DAGScheduler( evaluator: ApproximateEvaluator[U, R], callSite: CallSite, timeout: Long, - properties: Properties = null): PartialResult[R] = { + properties: Properties): PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray @@ -689,7 +689,7 @@ class DAGScheduler( // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. val activeInGroup = activeJobs.filter(activeJob => - groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + Option(activeJob.properties).exists(_.get(SparkContext.SPARK_JOB_GROUP_ID) == groupId)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) submitWaitingStages() @@ -736,7 +736,7 @@ class DAGScheduler( allowLocal: Boolean, callSite: CallSite, listener: JobListener, - properties: Properties = null) { + properties: Properties) { var finalStage: ResultStage = null try { // New stage creation may throw an exception if, for example, jobs are run on a diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index b07c4d93db..c7301a30d8 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import java.util.concurrent.TimeUnit import com.google.common.base.Charsets._ import com.google.common.io.Files @@ -25,9 +26,11 @@ import com.google.common.io.Files import org.scalatest.FunSuite import org.apache.hadoop.io.BytesWritable - import org.apache.spark.util.Utils +import scala.concurrent.Await +import scala.concurrent.duration.Duration + class SparkContextSuite extends FunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { @@ -173,4 +176,19 @@ class SparkContextSuite extends FunSuite with LocalSparkContext { sc.stop() } } + + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { + try { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)}) + sc.cancelJobGroup("nonExistGroupId") + Await.ready(future, Duration(2, TimeUnit.SECONDS)) + + // In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause + // SparkContext to shutdown, so the following assertion will fail. + assert(sc.parallelize(1 to 10).count() == 10L) + } finally { + sc.stop() + } + } } |