Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/CacheManager.scala          | 15
-rw-r--r--  core/src/main/scala/org/apache/spark/InterruptibleIterator.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskKilledException.scala   | 23
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala     |  8
-rw-r--r--  core/src/test/scala/org/apache/spark/JobCancellationSuite.scala  | 43
5 files changed, 86 insertions(+), 15 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index c7893f288b..811610c657 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -47,7 +47,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (loading.contains(key)) {
logInfo("Another thread is loading %s, waiting for it to finish...".format(key))
while (loading.contains(key)) {
- try {loading.wait()} catch {case _ : Throwable =>}
+ try {
+ loading.wait()
+ } catch {
+ case e: Exception =>
+ logWarning(s"Got an exception while waiting for another thread to load $key", e)
+ }
}
logInfo("Finished waiting for %s".format(key))
/* See whether someone else has successfully loaded it. The main way this would fail
@@ -72,7 +77,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val computedValues = rdd.computeOrReadCheckpoint(split, context)
// Persist the result, so long as the task is not running locally
- if (context.runningLocally) { return computedValues }
+ if (context.runningLocally) {
+ return computedValues
+ }
// Keep track of blocks with updated statuses
var updatedBlocks = Seq[(BlockId, BlockStatus)]()
@@ -88,7 +95,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
blockManager.get(key) match {
case Some(values) =>
- new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
+ values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Failure to store %s".format(key))
throw new Exception("Block manager failed to return persisted valued")
@@ -107,7 +114,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val metrics = context.taskMetrics
metrics.updatedBlocks = Some(updatedBlocks)
- returnValue
+ new InterruptibleIterator(context, returnValue)
} finally {
loading.synchronized {
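
Taken together, the CacheManager hunks above change the shape of getOrCompute: exceptions raised while waiting on another loader are now logged instead of silently swallowed, locally-run tasks get their computed values back without caching them, and the InterruptibleIterator wrapping is applied once to the final return value rather than only on the cache-hit branch. A simplified, self-contained sketch of the resulting control flow (the names below are illustrative stand-ins, not Spark's API):

    import scala.collection.mutable

    // Sketch only: `isKilled` stands in for TaskContext.interrupted, `store` for the BlockManager.
    object CacheSketch {
      class KilledException extends RuntimeException

      def getOrCompute[T](key: String,
                          compute: () => Iterator[T],
                          store: mutable.Map[String, Seq[T]],
                          runningLocally: Boolean,
                          isKilled: () => Boolean): Iterator[T] = {
        val values: Iterator[T] = store.get(key) match {
          case Some(cached) => cached.iterator        // someone else already cached it
          case None =>
            val computed = compute()
            if (runningLocally) return computed       // locally-run task: hand back values, skip caching
            val materialized = computed.toSeq         // "persist", then re-read for the caller
            store(key) = materialized
            materialized.iterator
        }
        // Wrap once, on the final return value, so every path honours cancellation.
        new Iterator[T] {
          def hasNext: Boolean = if (isKilled()) throw new KilledException else values.hasNext
          def next(): T = values.next()
        }
      }
    }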
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
index fd1802ba2f..ec11dbbffa 100644
--- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -24,7 +24,17 @@ package org.apache.spark
private[spark] class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
extends Iterator[T] {
- def hasNext: Boolean = !context.interrupted && delegate.hasNext
+ def hasNext: Boolean = {
+ // TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt
+ // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
+ // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
+ // introduces an expensive read fence.
+ if (context.interrupted) {
+ throw new TaskKilledException
+ } else {
+ delegate.hasNext
+ }
+ }
def next(): T = delegate.next()
}
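
The behavioural change in InterruptibleIterator is that a consumer whose task has been marked interrupted now gets a TaskKilledException from hasNext instead of an iterator that quietly reports itself exhausted. A minimal, self-contained sketch of that semantics (local stand-in names, not the actual Spark classes):

    // Sketch only: `KillFlag` stands in for TaskContext, `KilledException` for TaskKilledException.
    class KillFlag { @volatile var interrupted: Boolean = false }
    class KilledException extends RuntimeException

    class Interruptible[T](flag: KillFlag, delegate: Iterator[T]) extends Iterator[T] {
      def hasNext: Boolean =
        if (flag.interrupted) throw new KilledException else delegate.hasNext
      def next(): T = delegate.next()
    }

    object InterruptibleDemo extends App {
      val flag = new KillFlag
      val it   = new Interruptible(flag, Iterator.from(1))
      println(it.next())        // 1 -- iteration proceeds normally
      flag.interrupted = true   // the executor flips the flag when the task is killed
      try it.hasNext catch {    // hasNext now throws rather than returning false
        case _: KilledException => println("iteration aborted: task killed")
      }
    }

Throwing here is what allows the executor's catch clause (see Executor.scala below) to classify the unwinding task as killed rather than as an ordinary failure.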
diff --git a/core/src/main/scala/org/apache/spark/TaskKilledException.scala b/core/src/main/scala/org/apache/spark/TaskKilledException.scala
new file mode 100644
index 0000000000..cbd6b2866e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskKilledException.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Exception for a task getting killed.
+ */
+private[spark] class TaskKilledException extends RuntimeException
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 914bc205ce..272bcda5f8 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -161,8 +161,6 @@ private[spark] class Executor(
class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {
- object TaskKilledException extends Exception
-
@volatile private var killed = false
@volatile private var task: Task[Any] = _
@@ -200,7 +198,7 @@ private[spark] class Executor(
// causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
// exception will be caught by the catch block, leading to an incorrect ExceptionFailure
// for the task.
- throw TaskKilledException
+ throw new TaskKilledException
}
attemptedTask = Some(task)
@@ -214,7 +212,7 @@ private[spark] class Executor(
// If the task has been killed, let's fail it.
if (task.killed) {
- throw TaskKilledException
+ throw new TaskKilledException
}
val resultSer = SparkEnv.get.serializer.newInstance()
@@ -257,7 +255,7 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}
- case TaskKilledException | _: InterruptedException if task.killed => {
+ case _: TaskKilledException | _: InterruptedException if task.killed => {
logInfo("Executor killed task " + taskId)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}
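
With the per-TaskRunner singleton removed, TaskKilledException is now an ordinary class shared across the package, so InterruptibleIterator (and through it the CacheManager path) can throw it as well, and the executor's catch clause switches from matching one instance to matching by type. An illustrative, self-contained sketch of that type-based classification (stand-in names, not the actual Executor code):

    // Sketch only: `Killed` stands in for TaskKilledException; statuses are plain strings.
    object KilledMatchSketch extends App {
      class Killed extends RuntimeException

      def runAndReport(body: () => Unit): String =
        try { body(); "FINISHED" } catch {
          // Type-based match: any `new Killed`, wherever it was constructed, and thread
          // interrupts are both classified as KILLED rather than FAILED.
          case _: Killed | _: InterruptedException => "KILLED"
          case e: Exception => "FAILED: " + e.getMessage
        }

      println(runAndReport(() => throw new Killed))          // prints KILLED
      println(runAndReport(() => sys.error("task blew up"))) // prints FAILED: task blew up
    }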
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 16cfdf11c4..2c8ef405c9 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -84,6 +84,35 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(sc.parallelize(1 to 10, 2).count === 10)
}
+ test("do not put partially executed partitions into cache") {
+ // In this test case, we create a scenario in which a partition is only partially executed,
+ // and make sure CacheManager does not put that partially executed partition into the
+ // BlockManager.
+ import JobCancellationSuite._
+ sc = new SparkContext("local", "test")
+
+ // Run from 1 to 10, and then block and wait for the task to be killed.
+ val rdd = sc.parallelize(1 to 1000, 2).map { x =>
+ if (x > 10) {
+ taskStartedSemaphore.release()
+ taskCancelledSemaphore.acquire()
+ }
+ x
+ }.cache()
+
+ val rdd1 = rdd.map(x => x)
+
+ future {
+ taskStartedSemaphore.acquire()
+ sc.cancelAllJobs()
+ taskCancelledSemaphore.release(100000)
+ }
+
+ intercept[SparkException] { rdd1.count() }
+ // If the partial block is put into cache, rdd.count() would return a number less than 1000.
+ assert(rdd.count() === 1000)
+ }
+
test("job group") {
sc = new SparkContext("local[2]", "test")
@@ -114,7 +143,6 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(jobB.get() === 100)
}
-
test("job group with interruption") {
sc = new SparkContext("local[2]", "test")
@@ -145,15 +173,14 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(jobB.get() === 100)
}
-/*
- test("two jobs sharing the same stage") {
+ ignore("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
// sem2: make sure the first stage is not finished until cancel is issued
val sem1 = new Semaphore(0)
val sem2 = new Semaphore(0)
sc = new SparkContext("local[2]", "test")
- sc.dagScheduler.addSparkListener(new SparkListener {
+ sc.addSparkListener(new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart) {
sem1.release()
}
@@ -179,7 +206,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
intercept[SparkException] { f1.get() }
intercept[SparkException] { f2.get() }
}
- */
+
def testCount() {
// Cancel before launching any tasks
{
@@ -238,3 +265,9 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
}
}
}
+
+
+object JobCancellationSuite {
+ val taskStartedSemaphore = new Semaphore(0)
+ val taskCancelledSemaphore = new Semaphore(0)
+}
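
The new test relies on the two semaphores in this companion object for a driver/task handshake: a task that has started partially executing a cached partition releases taskStartedSemaphore, and the driver-side future then cancels all jobs and releases taskCancelledSemaphore so every blocked task can proceed and observe the kill. A stripped-down sketch of that handshake with plain threads in place of Spark tasks (illustrative only):

    // Sketch only: plain threads and a volatile flag in place of Spark tasks and cancelAllJobs().
    import java.util.concurrent.Semaphore

    object HandshakeSketch extends App {
      val taskStarted   = new Semaphore(0)
      val taskCancelled = new Semaphore(0)
      @volatile var killed = false

      // Stand-in for the partially executed task: announce that work is in flight,
      // then block until the cancelling side releases it.
      val task = new Thread(new Runnable {
        def run(): Unit = {
          taskStarted.release()
          taskCancelled.acquire()
          if (killed) println("task observed the kill and stopped early")
        }
      })
      task.start()

      taskStarted.acquire()         // wait until a task is actually running
      killed = true                 // stand-in for sc.cancelAllJobs()
      taskCancelled.release(100000) // over-release so every waiting task is unblocked
      task.join()
    }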