diff options
author | Liwei Lin <lwlin7@gmail.com> | 2016-06-24 10:09:04 -0500 |
---|---|---|
committer | Imran Rashid <irashid@cloudera.com> | 2016-06-24 10:09:04 -0500 |
commit | a4851ed05053a9b7545a258c9159fd529225c455 (patch) | |
tree | 74be27f35aebedc1efb57f3af76460dd6d7e2e40 /core/src | |
parent | be88383e15a86d094963de5f7e8792510bc990de (diff) | |
download | spark-a4851ed05053a9b7545a258c9159fd529225c455.tar.gz spark-a4851ed05053a9b7545a258c9159fd529225c455.tar.bz2 spark-a4851ed05053a9b7545a258c9159fd529225c455.zip |
[SPARK-15963][CORE] Catch `TaskKilledException` correctly in Executor.TaskRunner
## The problem
Before this change, if either of the following cases happened to a task , the task would be marked as `FAILED` instead of `KILLED`:
- the task was killed before it was deserialized
- `executor.kill()` marked `taskRunner.killed`, but before calling `task.killed()` the worker thread threw the `TaskKilledException`
The reason is, in the `catch` block of the current [Executor.TaskRunner](https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/executor/Executor.scala#L362)'s implementation, we are mistakenly catching:
```scala
case _: TaskKilledException | _: InterruptedException if task.killed => ...
```
the semantics of which is:
- **(**`TaskKilledException` **OR** `InterruptedException`**)** **AND** `task.killed`
Then when `TaskKilledException` is thrown but `task.killed` is not marked, we would mark the task as `FAILED` (which should really be `KILLED`).
## What changes were proposed in this pull request?
This patch alters the catch condition's semantics from:
- **(**`TaskKilledException` **OR** `InterruptedException`**)** **AND** `task.killed`
to
- `TaskKilledException` **OR** **(**`InterruptedException` **AND** `task.killed`**)**
so that we can catch `TaskKilledException` correctly and mark the task as `KILLED` correctly.
## How was this patch tested?
Added unit test which failed before the change, ran new test 1000 times manually
Author: Liwei Lin <lwlin7@gmail.com>
Closes #13685 from lw-lin/fix-task-killed.
Diffstat (limited to 'core/src')
-rw-r--r-- | core/src/main/scala/org/apache/spark/executor/Executor.scala | 7 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala | 139 |
2 files changed, 145 insertions, 1 deletions
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9a017f29f7..fbf2b86db1 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -359,11 +359,16 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case _: TaskKilledException | _: InterruptedException if task.killed => + case _: TaskKilledException => logInfo(s"Executor killed $taskName (TID $taskId)") setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + case _: InterruptedException if task.killed => + logInfo(s"Executor interrupted and killed $taskName (TID $taskId)") + setTaskFinishedAndClearInterruptStatus() + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskEndReason setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala new file mode 100644 index 0000000000..3e69894b10 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.nio.ByteBuffer +import java.util.concurrent.CountDownLatch + +import scala.collection.mutable.HashMap + +import org.mockito.Matchers._ +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.TaskState.TaskState +import org.apache.spark.memory.MemoryManager +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.scheduler.{FakeTask, Task} +import org.apache.spark.serializer.JavaSerializer + +class ExecutorSuite extends SparkFunSuite { + + test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") { + // mock some objects to make Executor.launchTask() happy + val conf = new SparkConf + val serializer = new JavaSerializer(conf) + val mockEnv = mock(classOf[SparkEnv]) + val mockRpcEnv = mock(classOf[RpcEnv]) + val mockMetricsSystem = mock(classOf[MetricsSystem]) + val mockMemoryManager = mock(classOf[MemoryManager]) + when(mockEnv.conf).thenReturn(conf) + when(mockEnv.serializer).thenReturn(serializer) + when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) + when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) + when(mockEnv.memoryManager).thenReturn(mockMemoryManager) + when(mockEnv.closureSerializer).thenReturn(serializer) + val serializedTask = + Task.serializeWithDependencies( + new FakeTask(0), + HashMap[String, Long](), + HashMap[String, Long](), + serializer.newInstance()) + + // we use latches to force the program to run in this order: + // +-----------------------------+---------------------------------------+ + // | main test thread | worker thread | + // +-----------------------------+---------------------------------------+ + // | executor.launchTask() | | + // | | TaskRunner.run() begins | + // | | ... | + // | | execBackend.statusUpdate // 1st time | + // | executor.killAllTasks(true) | | + // | | ... | + // | | task = ser.deserialize | + // | | ... | + // | | execBackend.statusUpdate // 2nd time | + // | | ... | + // | | TaskRunner.run() ends | + // | check results | | + // +-----------------------------+---------------------------------------+ + + val executorSuiteHelper = new ExecutorSuiteHelper + + val mockExecutorBackend = mock(classOf[ExecutorBackend]) + when(mockExecutorBackend.statusUpdate(any(), any(), any())) + .thenAnswer(new Answer[Unit] { + var firstTime = true + override def answer(invocationOnMock: InvocationOnMock): Unit = { + if (firstTime) { + executorSuiteHelper.latch1.countDown() + // here between latch1 and latch2, executor.killAllTasks() is called + executorSuiteHelper.latch2.await() + firstTime = false + } + else { + // save the returned `taskState` and `testFailedReason` into `executorSuiteHelper` + val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState] + executorSuiteHelper.taskState = taskState + val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer] + executorSuiteHelper.testFailedReason + = serializer.newInstance().deserialize(taskEndReason) + // let the main test thread check `taskState` and `testFailedReason` + executorSuiteHelper.latch3.countDown() + } + } + }) + + var executor: Executor = null + try { + executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true) + // the task will be launched in a dedicated worker thread + executor.launchTask(mockExecutorBackend, 0, 0, "", serializedTask) + + executorSuiteHelper.latch1.await() + // we know the task will be started, but not yet deserialized, because of the latches we + // use in mockExecutorBackend. + executor.killAllTasks(true) + executorSuiteHelper.latch2.countDown() + executorSuiteHelper.latch3.await() + + // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED` + assert(executorSuiteHelper.testFailedReason === TaskKilled) + assert(executorSuiteHelper.taskState === TaskState.KILLED) + } + finally { + if (executor != null) { + executor.stop() + } + } + } +} + +// Helps to test("SPARK-15963") +private class ExecutorSuiteHelper { + + val latch1 = new CountDownLatch(1) + val latch2 = new CountDownLatch(1) + val latch3 = new CountDownLatch(1) + + @volatile var taskState: TaskState = _ + @volatile var testFailedReason: TaskFailedReason = _ +} |