author     Liwei Lin <lwlin7@gmail.com>           2016-06-24 10:09:04 -0500
committer  Imran Rashid <irashid@cloudera.com>    2016-06-24 10:09:04 -0500
commit     a4851ed05053a9b7545a258c9159fd529225c455 (patch)
tree       74be27f35aebedc1efb57f3af76460dd6d7e2e40 /core
parent     be88383e15a86d094963de5f7e8792510bc990de (diff)
[SPARK-15963][CORE] Catch `TaskKilledException` correctly in Executor.TaskRunner
## The problem

Before this change, if either of the following happened to a task, the task would be marked as `FAILED` instead of `KILLED`:

- the task was killed before it was deserialized
- `executor.kill()` marked `taskRunner.killed`, but the worker thread threw the `TaskKilledException` before `task.killed()` was called

The reason is that in the `catch` block of the current [Executor.TaskRunner](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/executor/Executor.scala#L362) implementation, we are mistakenly catching:

```scala
case _: TaskKilledException | _: InterruptedException if task.killed => ...
```

the semantics of which is:

- **(**`TaskKilledException` **OR** `InterruptedException`**)** **AND** `task.killed`

So when a `TaskKilledException` is thrown but `task.killed` is not yet marked, the task is marked as `FAILED` when it should really be `KILLED`.

## What changes were proposed in this pull request?

This patch changes the catch condition's semantics from:

- **(**`TaskKilledException` **OR** `InterruptedException`**)** **AND** `task.killed`

to:

- `TaskKilledException` **OR** **(**`InterruptedException` **AND** `task.killed`**)**

so that a `TaskKilledException` is always caught and the task is correctly marked as `KILLED`.

## How was this patch tested?

Added a unit test which failed before the change; ran the new test 1000 times manually.

Author: Liwei Lin <lwlin7@gmail.com>

Closes #13685 from lw-lin/fix-task-killed.
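To make the guard semantics concrete, here is a minimal, self-contained Scala sketch (not Spark code; `GuardSemanticsDemo`, `classify`, `classifyFixed`, and the local exception class are illustrative names only):

```scala
// Standalone sketch of Scala's pattern-alternative guard semantics.
object GuardSemanticsDemo {
  class TaskKilledException extends RuntimeException

  // Old form: the guard binds to the WHOLE alternative, i.e.
  // (TaskKilledException OR InterruptedException) AND killed.
  def classify(e: Throwable, killed: Boolean): String = e match {
    case _: TaskKilledException | _: InterruptedException if killed => "KILLED"
    case _ => "FAILED"
  }

  // Fixed form: TaskKilledException OR (InterruptedException AND killed).
  def classifyFixed(e: Throwable, killed: Boolean): String = e match {
    case _: TaskKilledException => "KILLED"
    case _: InterruptedException if killed => "KILLED"
    case _ => "FAILED"
  }

  def main(args: Array[String]): Unit = {
    val e = new TaskKilledException
    println(classify(e, killed = false))      // FAILED -- the bug
    println(classifyFixed(e, killed = false)) // KILLED -- the fix
  }
}
```

Running it with `killed = false` prints `FAILED` under the old pattern and `KILLED` under the fixed one, which is exactly the behavioral change this patch makes.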
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala        7
-rw-r--r--  core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala   139
2 files changed, 145 insertions(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 9a017f29f7..fbf2b86db1 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -359,11 +359,16 @@ private[spark] class Executor(
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
-        case _: TaskKilledException | _: InterruptedException if task.killed =>
+        case _: TaskKilledException =>
           logInfo(s"Executor killed $taskName (TID $taskId)")
           setTaskFinishedAndClearInterruptStatus()
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
 
+        case _: InterruptedException if task.killed =>
+          logInfo(s"Executor interrupted and killed $taskName (TID $taskId)")
+          setTaskFinishedAndClearInterruptStatus()
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+
         case CausedBy(cDE: CommitDeniedException) =>
           val reason = cDE.toTaskEndReason
           setTaskFinishedAndClearInterruptStatus()
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
new file mode 100644
index 0000000000..3e69894b10
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import java.nio.ByteBuffer
+import java.util.concurrent.CountDownLatch
+
+import scala.collection.mutable.HashMap
+
+import org.mockito.Matchers._
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.memory.MemoryManager
+import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.scheduler.{FakeTask, Task}
+import org.apache.spark.serializer.JavaSerializer
+
+class ExecutorSuite extends SparkFunSuite {
+
+ test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
+ // mock some objects to make Executor.launchTask() happy
+ val conf = new SparkConf
+ val serializer = new JavaSerializer(conf)
+ val mockEnv = mock(classOf[SparkEnv])
+ val mockRpcEnv = mock(classOf[RpcEnv])
+ val mockMetricsSystem = mock(classOf[MetricsSystem])
+ val mockMemoryManager = mock(classOf[MemoryManager])
+ when(mockEnv.conf).thenReturn(conf)
+ when(mockEnv.serializer).thenReturn(serializer)
+ when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
+ when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
+ when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
+ when(mockEnv.closureSerializer).thenReturn(serializer)
+ val serializedTask =
+ Task.serializeWithDependencies(
+ new FakeTask(0),
+ HashMap[String, Long](),
+ HashMap[String, Long](),
+ serializer.newInstance())
+
+ // we use latches to force the program to run in this order:
+ // +-----------------------------+---------------------------------------+
+ // | main test thread | worker thread |
+ // +-----------------------------+---------------------------------------+
+ // | executor.launchTask() | |
+ // | | TaskRunner.run() begins |
+ // | | ... |
+ // | | execBackend.statusUpdate // 1st time |
+ // | executor.killAllTasks(true) | |
+ // | | ... |
+ // | | task = ser.deserialize |
+ // | | ... |
+ // | | execBackend.statusUpdate // 2nd time |
+ // | | ... |
+ // | | TaskRunner.run() ends |
+ // | check results | |
+ // +-----------------------------+---------------------------------------+
+
+ val executorSuiteHelper = new ExecutorSuiteHelper
+
+ val mockExecutorBackend = mock(classOf[ExecutorBackend])
+ when(mockExecutorBackend.statusUpdate(any(), any(), any()))
+ .thenAnswer(new Answer[Unit] {
+ var firstTime = true
+ override def answer(invocationOnMock: InvocationOnMock): Unit = {
+ if (firstTime) {
+ executorSuiteHelper.latch1.countDown()
+ // here between latch1 and latch2, executor.killAllTasks() is called
+ executorSuiteHelper.latch2.await()
+ firstTime = false
+ }
+ else {
+ // save the returned `taskState` and `testFailedReason` into `executorSuiteHelper`
+ val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState]
+ executorSuiteHelper.taskState = taskState
+ val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer]
+ executorSuiteHelper.testFailedReason
+ = serializer.newInstance().deserialize(taskEndReason)
+ // let the main test thread check `taskState` and `testFailedReason`
+ executorSuiteHelper.latch3.countDown()
+ }
+ }
+ })
+
+ var executor: Executor = null
+ try {
+ executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true)
+ // the task will be launched in a dedicated worker thread
+ executor.launchTask(mockExecutorBackend, 0, 0, "", serializedTask)
+
+ executorSuiteHelper.latch1.await()
+ // we know the task will be started, but not yet deserialized, because of the latches we
+ // use in mockExecutorBackend.
+ executor.killAllTasks(true)
+ executorSuiteHelper.latch2.countDown()
+ executorSuiteHelper.latch3.await()
+
+ // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
+ assert(executorSuiteHelper.testFailedReason === TaskKilled)
+ assert(executorSuiteHelper.taskState === TaskState.KILLED)
+ }
+ finally {
+ if (executor != null) {
+ executor.stop()
+ }
+ }
+ }
+}
+
+// Helper used by the "SPARK-15963" test above
+private class ExecutorSuiteHelper {
+
+  val latch1 = new CountDownLatch(1)
+  val latch2 = new CountDownLatch(1)
+  val latch3 = new CountDownLatch(1)
+
+  @volatile var taskState: TaskState = _
+  @volatile var testFailedReason: TaskFailedReason = _
+}
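For readers unfamiliar with the latch choreography the test relies on, here is a stripped-down, Spark-independent sketch of the same technique; `LatchHandshakeDemo` and the `killed` flag are illustrative names, not part of the patch:

```scala
import java.util.concurrent.CountDownLatch

// Sketch of forcing a deterministic interleaving between the main
// thread and a worker thread, as the test above does with latch1/latch2.
object LatchHandshakeDemo {
  @volatile private var killed = false

  def main(args: Array[String]): Unit = {
    val latch1 = new CountDownLatch(1) // worker -> main: "reached checkpoint"
    val latch2 = new CountDownLatch(1) // main -> worker: "flag set, carry on"

    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        latch1.countDown() // signal the main thread, then block until it acts
        latch2.await()
        // by construction, `killed` was set before the worker resumed
        println(s"worker resumed, killed = $killed")
      }
    })
    worker.start()

    latch1.await()     // wait until the worker reaches its checkpoint
    killed = true      // mutate shared state at a known point in the schedule
    latch2.countDown() // release the worker
    worker.join()
  }
}
```

The test applies this handshake inside the mocked `statusUpdate` so that `executor.killAllTasks(true)` is guaranteed to run after the task has started but before it is deserialized, which is exactly the window where the old catch clause misclassified the kill.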