author     Imran Rashid <irashid@cloudera.com>        2017-02-24 13:03:37 -0800
committer  Kay Ousterhout <kayousterhout@gmail.com>   2017-02-24 13:03:37 -0800
commit     5f74148bb45912b9f867174de46e246215c93ee1 (patch)
tree       54a04ad9e57e038f4339d9e4224d9c1261c8b800 /core/src/test/scala
parent     5cbd3b59ba6735c59153416fa15721af6da09acf (diff)
[SPARK-19597][CORE] test case for task deserialization errors
Adds a test case that ensures that Executors gracefully handle a task that
fails to deserialize, by sending back a reasonable failure message. This does
not change any behavior (the prior behavior was already correct); it just adds
a test case to prevent regression.

Author: Imran Rashid <irashid@cloudera.com>

Closes #16930 from squito/executor_task_deserialization.
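The heart of the new test is a task that serializes fine but always fails to
deserialize. Below is a minimal, self-contained sketch of that trick; the class
name FailsOnRead and the demo main are illustrative only, while the patch's
actual version, NonDeserializableTask, extends FakeTask so the Executor can
launch it as a real task:

import java.io._

// Serializes cleanly (nothing to write), but always throws on read, so the
// Executor's task-deserialization error path fires deterministically.
class FailsOnRead extends Externalizable {
  def writeExternal(out: ObjectOutput): Unit = {}
  def readExternal(in: ObjectInput): Unit =
    throw new RuntimeException("failure in deserialization")
}

object FailsOnRead {
  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    new ObjectOutputStream(bytes).writeObject(new FailsOnRead) // succeeds
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    try in.readObject()                                        // always throws
    catch { case e: RuntimeException => println(s"failed to read: ${e.getMessage}") }
  }
}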
Diffstat (limited to 'core/src/test/scala')
-rw-r--r--  core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala  139
1 file changed, 106 insertions(+), 33 deletions(-)
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index f94baaa30d..b743ff5376 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -17,16 +17,21 @@
package org.apache.spark.executor
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.nio.ByteBuffer
import java.util.Properties
-import java.util.concurrent.CountDownLatch
+import java.util.concurrent.{CountDownLatch, TimeUnit}
import scala.collection.mutable.Map
+import scala.concurrent.duration._
-import org.mockito.Matchers._
-import org.mockito.Mockito.{mock, when}
+import org.mockito.ArgumentCaptor
+import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito.{inOrder, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
+import org.scalatest.concurrent.Eventually
+import org.scalatest.mock.MockitoSugar
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
@@ -36,35 +41,15 @@ import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.{FakeTask, TaskDescription}
import org.apache.spark.serializer.JavaSerializer
-class ExecutorSuite extends SparkFunSuite {
+class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") {
// mock some objects to make Executor.launchTask() happy
val conf = new SparkConf
val serializer = new JavaSerializer(conf)
- val mockEnv = mock(classOf[SparkEnv])
- val mockRpcEnv = mock(classOf[RpcEnv])
- val mockMetricsSystem = mock(classOf[MetricsSystem])
- val mockMemoryManager = mock(classOf[MemoryManager])
- when(mockEnv.conf).thenReturn(conf)
- when(mockEnv.serializer).thenReturn(serializer)
- when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
- when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
- when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
- when(mockEnv.closureSerializer).thenReturn(serializer)
- val fakeTaskMetrics = serializer.newInstance().serialize(TaskMetrics.registered).array()
- val serializedTask = serializer.newInstance().serialize(
- new FakeTask(0, 0, Nil, fakeTaskMetrics))
- val taskDescription = new TaskDescription(
- taskId = 0,
- attemptNumber = 0,
- executorId = "",
- name = "",
- index = 0,
- addedFiles = Map[String, Long](),
- addedJars = Map[String, Long](),
- properties = new Properties,
- serializedTask)
+ val env = createMockEnv(conf, serializer)
+ val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0))
+ val taskDescription = createFakeTaskDescription(serializedTask)
// we use latches to force the program to run in this order:
// +-----------------------------+---------------------------------------+
@@ -86,7 +71,7 @@ class ExecutorSuite extends SparkFunSuite {
val executorSuiteHelper = new ExecutorSuiteHelper
- val mockExecutorBackend = mock(classOf[ExecutorBackend])
+ val mockExecutorBackend = mock[ExecutorBackend]
when(mockExecutorBackend.statusUpdate(any(), any(), any()))
.thenAnswer(new Answer[Unit] {
var firstTime = true
@@ -102,8 +87,8 @@ class ExecutorSuite extends SparkFunSuite {
val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState]
executorSuiteHelper.taskState = taskState
val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer]
- executorSuiteHelper.testFailedReason
- = serializer.newInstance().deserialize(taskEndReason)
+ executorSuiteHelper.testFailedReason =
+ serializer.newInstance().deserialize(taskEndReason)
// let the main test thread check `taskState` and `testFailedReason`
executorSuiteHelper.latch3.countDown()
}
@@ -112,16 +97,20 @@ class ExecutorSuite extends SparkFunSuite {
var executor: Executor = null
try {
- executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true)
+ executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockExecutorBackend, taskDescription)
- executorSuiteHelper.latch1.await()
+ if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) {
+ fail("executor did not send first status update in time")
+ }
// we know the task will be started, but not yet deserialized, because of the latches we
// use in mockExecutorBackend.
executor.killAllTasks(true)
executorSuiteHelper.latch2.countDown()
- executorSuiteHelper.latch3.await()
+ if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) {
+ fail("executor did not send second status update in time")
+ }
// `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED`
assert(executorSuiteHelper.testFailedReason === TaskKilled)
@@ -133,6 +122,79 @@ class ExecutorSuite extends SparkFunSuite {
}
}
}
+
+ test("Gracefully handle error in task deserialization") {
+ val conf = new SparkConf
+ val serializer = new JavaSerializer(conf)
+ val env = createMockEnv(conf, serializer)
+ val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask)
+ val taskDescription = createFakeTaskDescription(serializedTask)
+
+ val failReason = runTaskAndGetFailReason(taskDescription)
+ failReason match {
+ case ef: ExceptionFailure =>
+ assert(ef.exception.isDefined)
+ assert(ef.exception.get.getMessage() === NonDeserializableTask.errorMsg)
+ case _ =>
+ fail(s"unexpected failure type: $failReason")
+ }
+ }
+
+ private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
+ val mockEnv = mock[SparkEnv]
+ val mockRpcEnv = mock[RpcEnv]
+ val mockMetricsSystem = mock[MetricsSystem]
+ val mockMemoryManager = mock[MemoryManager]
+ when(mockEnv.conf).thenReturn(conf)
+ when(mockEnv.serializer).thenReturn(serializer)
+ when(mockEnv.rpcEnv).thenReturn(mockRpcEnv)
+ when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem)
+ when(mockEnv.memoryManager).thenReturn(mockMemoryManager)
+ when(mockEnv.closureSerializer).thenReturn(serializer)
+ SparkEnv.set(mockEnv)
+ mockEnv
+ }
+
+ private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = {
+ new TaskDescription(
+ taskId = 0,
+ attemptNumber = 0,
+ executorId = "",
+ name = "",
+ index = 0,
+ addedFiles = Map[String, Long](),
+ addedJars = Map[String, Long](),
+ properties = new Properties,
+ serializedTask)
+ }
+
+ private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
+ val mockBackend = mock[ExecutorBackend]
+ var executor: Executor = null
+ try {
+ executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
+ // the task will be launched in a dedicated worker thread
+ executor.launchTask(mockBackend, taskDescription)
+ eventually(timeout(5 seconds), interval(10 milliseconds)) {
+ assert(executor.numRunningTasks === 0)
+ }
+ } finally {
+ if (executor != null) {
+ executor.stop()
+ }
+ }
+ val orderedMock = inOrder(mockBackend)
+ val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
+ orderedMock.verify(mockBackend)
+ .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture())
+ orderedMock.verify(mockBackend)
+ .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture())
+ // first statusUpdate for RUNNING has empty data
+ assert(statusCaptor.getAllValues().get(0).remaining() === 0)
+ // second update is more interesting
+ val failureData = statusCaptor.getAllValues.get(1)
+ SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
+ }
}
// Helps to test("SPARK-15963")
@@ -145,3 +207,14 @@ private class ExecutorSuiteHelper {
@volatile var taskState: TaskState = _
@volatile var testFailedReason: TaskFailedReason = _
}
+
+private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable {
+ def writeExternal(out: ObjectOutput): Unit = {}
+ def readExternal(in: ObjectInput): Unit = {
+ throw new RuntimeException(NonDeserializableTask.errorMsg)
+ }
+}
+
+private object NonDeserializableTask {
+ val errorMsg = "failure in deserialization"
+}