aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorImran Rashid <irashid@cloudera.com>2017-03-02 16:46:01 -0800
committerKay Ousterhout <kayousterhout@gmail.com>2017-03-02 16:46:01 -0800
commit8417a7ae6c0ea3fb8dc41bc492fc9513d1ad24af (patch)
tree47d8411d2324b6d85b85c1936588ab1b4ff3ca46
parent433d9eb6151a547af967cc1ac983a789bed60704 (diff)
downloadspark-8417a7ae6c0ea3fb8dc41bc492fc9513d1ad24af.tar.gz
spark-8417a7ae6c0ea3fb8dc41bc492fc9513d1ad24af.tar.bz2
spark-8417a7ae6c0ea3fb8dc41bc492fc9513d1ad24af.zip
[SPARK-19276][CORE] Fetch Failure handling robust to user error handling
## What changes were proposed in this pull request? Fault-tolerance in spark requires special handling of shuffle fetch failures. The Executor would catch FetchFailedException and send a special msg back to the driver. However, intervening user code could intercept that exception, and wrap it with something else. This even happens in SparkSQL. So rather than checking the thrown exception only, we'll store the fetch failure directly in the TaskContext, where users can't touch it. ## How was this patch tested? Added a test case which failed before the fix. Full test suite via jenkins. Author: Imran Rashid <irashid@cloudera.com> Closes #16639 from squito/SPARK-19276.
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContext.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContextImpl.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala33
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala13
-rw-r--r--core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala139
-rw-r--r--project/MimaExcludes.scala3
7 files changed, 198 insertions, 17 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 0fd777ed12..f0867ecb16 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}
@@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable {
*/
private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit
+ /**
+ * Record that this task has failed due to a fetch failure from a remote host. This allows
+ * fetch-failure handling to get triggered by the driver, regardless of intervening user-code.
+ */
+ private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit
+
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index c904e08391..dc0d128785 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._
private[spark] class TaskContextImpl(
@@ -56,6 +57,10 @@ private[spark] class TaskContextImpl(
// Whether the task has failed.
@volatile private var failed: Boolean = false
+ // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
+ // hide the exception. See SPARK-19276
+ @volatile private var _fetchFailedException: Option[FetchFailedException] = None
+
override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
@@ -126,4 +131,10 @@ private[spark] class TaskContextImpl(
taskMetrics.registerAccumulator(a)
}
+ private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
+ this._fetchFailedException = Option(fetchFailed)
+ }
+
+ private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
+
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 975a6e4eeb..790c1ae942 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -18,6 +18,7 @@
package org.apache.spark.executor
import java.io.{File, NotSerializableException}
+import java.lang.Thread.UncaughtExceptionHandler
import java.lang.management.ManagementFactory
import java.net.{URI, URL}
import java.nio.ByteBuffer
@@ -52,7 +53,8 @@ private[spark] class Executor(
executorHostname: String,
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
- isLocal: Boolean = false)
+ isLocal: Boolean = false,
+ uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler)
extends Logging {
logInfo(s"Starting executor ID $executorId on host $executorHostname")
@@ -78,7 +80,7 @@ private[spark] class Executor(
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
+ Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler)
}
// Start worker thread pool
@@ -342,6 +344,14 @@ private[spark] class Executor(
}
}
}
+ task.context.fetchFailed.foreach { fetchFailure =>
+ // uh-oh. it appears the user code has caught the fetch-failure without throwing any
+ // other exceptions. Its *possible* this is what the user meant to do (though highly
+ // unlikely). So we will log an error and keep going.
+ logError(s"TID ${taskId} completed successfully though internally it encountered " +
+ s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
+ s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
+ }
val taskFinish = System.currentTimeMillis()
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
@@ -402,8 +412,17 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
- case ffe: FetchFailedException =>
- val reason = ffe.toTaskFailedReason
+ case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
+ val reason = task.context.fetchFailed.get.toTaskFailedReason
+ if (!t.isInstanceOf[FetchFailedException]) {
+ // there was a fetch failure in the task, but some user code wrapped that exception
+ // and threw something else. Regardless, we treat it as a fetch failure.
+ val fetchFailedCls = classOf[FetchFailedException].getName
+ logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
+ s"failed, but the ${fetchFailedCls} was hidden by another " +
+ s"exception. Spark is handling this like a fetch failure and ignoring the " +
+ s"other exception: $t")
+ }
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
@@ -455,13 +474,17 @@ private[spark] class Executor(
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
- SparkUncaughtExceptionHandler.uncaughtException(t)
+ uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
}
} finally {
runningTasks.remove(taskId)
}
}
+
+ private def hasFetchFailure: Boolean = {
+ task != null && task.context != null && task.context.fetchFailed.isDefined
+ }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 7b726d5659..70213722aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,19 +17,14 @@
package org.apache.spark.scheduler
-import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import java.util.Properties
-import scala.collection.mutable
-import scala.collection.mutable.HashMap
-
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.config.APP_CALLER_CONTEXT
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util._
/**
@@ -137,6 +132,8 @@ private[spark] abstract class Task[T](
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
+ // Though we unset the ThreadLocal here, the context member variable itself is still queried
+ // directly in the TaskRunner to check for FetchFailedExceptions.
TaskContext.unset()
}
}
@@ -156,7 +153,7 @@ private[spark] abstract class Task[T](
var epoch: Long = -1
// Task context, to be initialized in run().
- @transient protected var context: TaskContextImpl = _
+ @transient var context: TaskContextImpl = _
// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 498c12e196..265a8acfa8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -17,7 +17,7 @@
package org.apache.spark.shuffle
-import org.apache.spark.{FetchFailed, TaskFailedReason}
+import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
@@ -26,6 +26,11 @@ import org.apache.spark.util.Utils
* back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage.
*
* Note that bmAddress can be null.
+ *
+ * To prevent user code from hiding this fetch failure, in the constructor we call
+ * [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately
+ * after creating it -- you cannot create it, check some condition, and then decide to ignore it
+ * (or risk triggering any other exceptions). See SPARK-19276.
*/
private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
@@ -45,6 +50,12 @@ private[spark] class FetchFailedException(
this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
}
+ // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code
+ // which intercepts this exception (possibly wrapping it), the Executor can still tell there was
+ // a fetch failure, and send the correct error msg back to the driver. We wrap with an Option
+ // because the TaskContext is not defined in some test cases.
+ Option(TaskContext.get()).map(_.setFetchFailed(this))
+
def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
Utils.exceptionString(this))
}
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index b743ff5376..8150fff2d0 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.executor
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.lang.Thread.UncaughtExceptionHandler
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{CountDownLatch, TimeUnit}
@@ -27,7 +28,7 @@ import scala.concurrent.duration._
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, eq => meq}
-import org.mockito.Mockito.{inOrder, when}
+import org.mockito.Mockito.{inOrder, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.concurrent.Eventually
@@ -37,9 +38,12 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.memory.MemoryManager
import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.scheduler.{FakeTask, TaskDescription}
+import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.storage.BlockManagerId
class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
@@ -123,6 +127,75 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}
}
+ test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") {
+ val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
+ sc = new SparkContext(conf)
+ val serializer = SparkEnv.get.closureSerializer.newInstance()
+ val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
+
+ // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides
+ // the fetch failure. The executor should still tell the driver that the task failed due to a
+ // fetch failure, not a generic exception from user code.
+ val inputRDD = new FetchFailureThrowingRDD(sc)
+ val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+ val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
+ val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
+ val task = new ResultTask(
+ stageId = 1,
+ stageAttemptId = 0,
+ taskBinary = taskBinary,
+ partition = secondRDD.partitions(0),
+ locs = Seq(),
+ outputId = 0,
+ localProperties = new Properties(),
+ serializedTaskMetrics = serializedTaskMetrics
+ )
+
+ val serTask = serializer.serialize(task)
+ val taskDescription = createFakeTaskDescription(serTask)
+
+ val failReason = runTaskAndGetFailReason(taskDescription)
+ assert(failReason.isInstanceOf[FetchFailed])
+ }
+
+ test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+ // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
+ // may be a false positive. And we should call the uncaught exception handler.
+ val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
+ sc = new SparkContext(conf)
+ val serializer = SparkEnv.get.closureSerializer.newInstance()
+ val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
+
+ // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
+ // the fetch failure as a false positive, and just do normal OOM handling.
+ val inputRDD = new FetchFailureThrowingRDD(sc)
+ val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+ val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
+ val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
+ val task = new ResultTask(
+ stageId = 1,
+ stageAttemptId = 0,
+ taskBinary = taskBinary,
+ partition = secondRDD.partitions(0),
+ locs = Seq(),
+ outputId = 0,
+ localProperties = new Properties(),
+ serializedTaskMetrics = serializedTaskMetrics
+ )
+
+ val serTask = serializer.serialize(task)
+ val taskDescription = createFakeTaskDescription(serTask)
+
+ val (failReason, uncaughtExceptionHandler) =
+ runTaskGetFailReasonAndExceptionHandler(taskDescription)
+ // make sure the task failure just looks like a OOM, not a fetch failure
+ assert(failReason.isInstanceOf[ExceptionFailure])
+ val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+ verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+ assert(exceptionCaptor.getAllValues.size === 1)
+ assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
+ }
+
test("Gracefully handle error in task deserialization") {
val conf = new SparkConf
val serializer = new JavaSerializer(conf)
@@ -169,13 +242,20 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}
private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
+ runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+ }
+
+ private def runTaskGetFailReasonAndExceptionHandler(
+ taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
val mockBackend = mock[ExecutorBackend]
+ val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
var executor: Executor = null
try {
- executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
+ executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
+ uncaughtExceptionHandler = mockUncaughtExceptionHandler)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockBackend, taskDescription)
- eventually(timeout(5 seconds), interval(10 milliseconds)) {
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
assert(executor.numRunningTasks === 0)
}
} finally {
@@ -193,7 +273,56 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
// second update is more interesting
val failureData = statusCaptor.getAllValues.get(1)
- SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
+ val failReason =
+ SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
+ (failReason, mockUncaughtExceptionHandler)
+ }
+}
+
+class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
+ override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+ new Iterator[Int] {
+ override def hasNext: Boolean = true
+ override def next(): Int = {
+ throw new FetchFailedException(
+ bmAddress = BlockManagerId("1", "hostA", 1234),
+ shuffleId = 0,
+ mapId = 0,
+ reduceId = 0,
+ message = "fake fetch failure"
+ )
+ }
+ }
+ }
+ override protected def getPartitions: Array[Partition] = {
+ Array(new SimplePartition)
+ }
+}
+
+class SimplePartition extends Partition {
+ override def index: Int = 0
+}
+
+class FetchFailureHidingRDD(
+ sc: SparkContext,
+ val input: FetchFailureThrowingRDD,
+ throwOOM: Boolean) extends RDD[Int](input) {
+ override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+ val inItr = input.compute(split, context)
+ try {
+ Iterator(inItr.size)
+ } catch {
+ case t: Throwable =>
+ if (throwOOM) {
+ throw new OutOfMemoryError("OOM while handling another exception")
+ } else {
+ throw new RuntimeException("User Exception that hides the original exception", t)
+ }
+ }
+ }
+
+ override protected def getPartitions: Array[Partition] = {
+ Array(new SimplePartition)
}
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 511686fb4f..56b8c0b95e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -55,6 +55,9 @@ object MimaExcludes {
// [SPARK-14272][ML] Add logLikelihood in GaussianMixtureSummary
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this"),
+ // [SPARK-19267] Fetch Failure handling robust to user error handling
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed"),
+
// [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"),