-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala                  |   7
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextImpl.scala              |  11
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala            |  33
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala               |   9
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala |  13
-rw-r--r--  core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala       | 139
-rw-r--r--  project/MimaExcludes.scala                                               |   3
7 files changed, 198 insertions, 17 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 0fd777ed12..f0867ecb16 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}
@@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable {
*/
private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit
+ /**
+ * Record that this task has failed due to a fetch failure from a remote host. This allows
+ * fetch-failure handling to get triggered by the driver, regardless of intervening user code.
+ */
+ private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit
+
}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index c904e08391..dc0d128785 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
+import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._
private[spark] class TaskContextImpl(
@@ -56,6 +57,10 @@ private[spark] class TaskContextImpl(
// Whether the task has failed.
@volatile private var failed: Boolean = false
+ // If there was a fetch failure in the task, we store it here, to make sure user code doesn't
+ // hide the exception. See SPARK-19276.
+ @volatile private var _fetchFailedException: Option[FetchFailedException] = None
+
override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
onCompleteCallbacks += listener
this
@@ -126,4 +131,10 @@ private[spark] class TaskContextImpl(
taskMetrics.registerAccumulator(a)
}
+ private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
+ this._fetchFailedException = Option(fetchFailed)
+ }
+
+ private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
+
}
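
To make the getter/setter contract above concrete, here is a minimal plain-Scala sketch of the new fields (illustrative only, not Spark's actual class): the setter wraps its argument in Option, so a null argument is stored as None and readers never see a null.

// Plain-Scala sketch of the contract added to TaskContextImpl above; illustrative
// only, the real class has many more members.
class FetchFailureRecorder {
  @volatile private var _fetchFailed: Option[Exception] = None

  // Option(...) maps a null argument to None, so readers never observe Some(null).
  def setFetchFailed(e: Exception): Unit = { _fetchFailed = Option(e) }

  def fetchFailed: Option[Exception] = _fetchFailed
}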
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 975a6e4eeb..790c1ae942 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -18,6 +18,7 @@
package org.apache.spark.executor
import java.io.{File, NotSerializableException}
+import java.lang.Thread.UncaughtExceptionHandler
import java.lang.management.ManagementFactory
import java.net.{URI, URL}
import java.nio.ByteBuffer
@@ -52,7 +53,8 @@ private[spark] class Executor(
executorHostname: String,
env: SparkEnv,
userClassPath: Seq[URL] = Nil,
- isLocal: Boolean = false)
+ isLocal: Boolean = false,
+ uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler)
extends Logging {
logInfo(s"Starting executor ID $executorId on host $executorHostname")
@@ -78,7 +80,7 @@ private[spark] class Executor(
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
+ Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler)
}
// Start worker thread pool
@@ -342,6 +344,14 @@ private[spark] class Executor(
}
}
}
+ task.context.fetchFailed.foreach { fetchFailure =>
+ // Uh-oh. It appears the user code has caught the fetch failure without throwing any
+ // other exceptions. It's *possible* this is what the user meant to do (though highly
+ // unlikely). So we will log an error and keep going.
+ logError(s"TID ${taskId} completed successfully though internally it encountered " +
+ s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
+ s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
+ }
val taskFinish = System.currentTimeMillis()
val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
@@ -402,8 +412,17 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
- case ffe: FetchFailedException =>
- val reason = ffe.toTaskFailedReason
+ case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
+ val reason = task.context.fetchFailed.get.toTaskFailedReason
+ if (!t.isInstanceOf[FetchFailedException]) {
+ // there was a fetch failure in the task, but some user code wrapped that exception
+ // and threw something else. Regardless, we treat it as a fetch failure.
+ val fetchFailedCls = classOf[FetchFailedException].getName
+ logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
+ s"failed, but the ${fetchFailedCls} was hidden by another " +
+ s"exception. Spark is handling this like a fetch failure and ignoring the " +
+ s"other exception: $t")
+ }
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
@@ -455,13 +474,17 @@ private[spark] class Executor(
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
- SparkUncaughtExceptionHandler.uncaughtException(t)
+ uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
}
} finally {
runningTasks.remove(taskId)
}
}
+
+ private def hasFetchFailure: Boolean = {
+ task != null && task.context != null && task.context.fetchFailed.isDefined
+ }
}
/**
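
Taken together, the executor-side changes amount to one rule: after the task body finishes, consult the fetch failure recorded in the task context before trusting whatever user code returned or threw. The following self-contained sketch mirrors that decision flow; it is plain Scala, and isFatal, the Outcome types, and the println are simplified stand-ins rather than Spark APIs.

// Standalone sketch of the TaskRunner decision flow above. Simplified types; isFatal
// and the Outcome cases are stand-ins for Utils.isFatalError and Spark's TaskEndReasons.
object FetchFailureFlow {
  sealed trait Outcome
  case object Success extends Outcome
  case class FetchFailed(cause: Exception) extends Outcome
  case class OtherFailure(cause: Throwable) extends Outcome

  // Rough stand-in for Utils.isFatalError.
  def isFatal(t: Throwable): Boolean = t.isInstanceOf[VirtualMachineError]

  def resolve(result: Either[Throwable, Any],
      recordedFetchFailure: Option[Exception]): Outcome = result match {
    case Right(_) =>
      // The task "succeeded" even though a fetch failure was recorded: user code
      // swallowed the exception, so log loudly but keep the successful result.
      recordedFetchFailure.foreach(ffe => println(s"WARN: swallowed fetch failure: $ffe"))
      Success
    case Left(t) if recordedFetchFailure.isDefined && !isFatal(t) =>
      // User code wrapped or replaced the exception; report the fetch failure anyway
      // so the driver resubmits the map stage instead of treating this as a user error.
      FetchFailed(recordedFetchFailure.get)
    case Left(t) =>
      // Fatal errors (e.g. OutOfMemoryError) win: treat any recorded fetch failure as
      // a possible false positive and fail the task normally.
      OtherFailure(t)
  }
}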
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 7b726d5659..70213722aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,19 +17,14 @@
package org.apache.spark.scheduler
-import java.io.{DataInputStream, DataOutputStream}
import java.nio.ByteBuffer
import java.util.Properties
-import scala.collection.mutable
-import scala.collection.mutable.HashMap
-
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.config.APP_CALLER_CONTEXT
import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util._
/**
@@ -137,6 +132,8 @@ private[spark] abstract class Task[T](
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
+ // Though we unset the ThreadLocal here, the context member variable itself is still queried
+ // directly in the TaskRunner to check for FetchFailedExceptions.
TaskContext.unset()
}
}
@@ -156,7 +153,7 @@ private[spark] abstract class Task[T](
var epoch: Long = -1
// Task context, to be initialized in run().
- @transient protected var context: TaskContextImpl = _
+ @transient var context: TaskContextImpl = _
// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 498c12e196..265a8acfa8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -17,7 +17,7 @@
package org.apache.spark.shuffle
-import org.apache.spark.{FetchFailed, TaskFailedReason}
+import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
@@ -26,6 +26,11 @@ import org.apache.spark.util.Utils
* back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage.
*
* Note that bmAddress can be null.
+ *
+ * To prevent user code from hiding this fetch failure, in the constructor we call
+ * [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately
+ * after creating it -- you cannot create it, check some condition, and then decide to ignore it
+ * (or risk triggering any other exceptions). See SPARK-19276.
*/
private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
@@ -45,6 +50,12 @@ private[spark] class FetchFailedException(
this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
}
+ // SPARK-19276. We set the fetch failure in the task context, so that even if there is user code
+ // which intercepts this exception (possibly wrapping it), the Executor can still tell there was
+ // a fetch failure, and send the correct error message back to the driver. We wrap with an Option
+ // because the TaskContext is not defined in some test cases.
+ Option(TaskContext.get()).map(_.setFetchFailed(this))
+
def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
Utils.exceptionString(this))
}
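
For reference, the kind of user code this constructor hook defends against looks roughly like the following. This is a hypothetical example, not taken from the patch; RiskyUserCode and riskyTransform are made-up names. A broad catch inside a transformation wraps Spark's internal FetchFailedException, which previously made the driver see a generic task failure instead of resubmitting the map stage.

import org.apache.spark.rdd.RDD

// Hypothetical user code that hides a fetch failure: a broad catch inside a
// transformation wraps Spark's internal FetchFailedException in a RuntimeException.
// Before this change the driver saw only a generic ExceptionFailure; with the
// constructor hook above, the executor still reports FetchFailed.
object RiskyUserCode {
  def riskyTransform(shuffled: RDD[(Int, String)]): RDD[Int] = {
    shuffled.mapPartitions { iter =>
      try {
        // Materializing inside the try means the shuffle fetch happens here too.
        iter.map { case (k, v) => k + v.length }.toList.iterator
      } catch {
        case e: Exception =>
          throw new RuntimeException("processing failed", e) // hides FetchFailedException
      }
    }
  }
}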
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index b743ff5376..8150fff2d0 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.executor
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.lang.Thread.UncaughtExceptionHandler
import java.nio.ByteBuffer
import java.util.Properties
import java.util.concurrent.{CountDownLatch, TimeUnit}
@@ -27,7 +28,7 @@ import scala.concurrent.duration._
import org.mockito.ArgumentCaptor
import org.mockito.Matchers.{any, eq => meq}
-import org.mockito.Mockito.{inOrder, when}
+import org.mockito.Mockito.{inOrder, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.concurrent.Eventually
@@ -37,9 +38,12 @@ import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.memory.MemoryManager
import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.scheduler.{FakeTask, TaskDescription}
+import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.storage.BlockManagerId
class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
@@ -123,6 +127,75 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}
}
+ test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") {
+ val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
+ sc = new SparkContext(conf)
+ val serializer = SparkEnv.get.closureSerializer.newInstance()
+ val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
+
+ // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides
+ // the fetch failure. The executor should still tell the driver that the task failed due to a
+ // fetch failure, not a generic exception from user code.
+ val inputRDD = new FetchFailureThrowingRDD(sc)
+ val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false)
+ val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
+ val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
+ val task = new ResultTask(
+ stageId = 1,
+ stageAttemptId = 0,
+ taskBinary = taskBinary,
+ partition = secondRDD.partitions(0),
+ locs = Seq(),
+ outputId = 0,
+ localProperties = new Properties(),
+ serializedTaskMetrics = serializedTaskMetrics
+ )
+
+ val serTask = serializer.serialize(task)
+ val taskDescription = createFakeTaskDescription(serTask)
+
+ val failReason = runTaskAndGetFailReason(taskDescription)
+ assert(failReason.isInstanceOf[FetchFailed])
+ }
+
+ test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
+ // When there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
+ // may be a false positive. Instead, we should call the uncaught exception handler.
+ val conf = new SparkConf().setMaster("local").setAppName("executor suite test")
+ sc = new SparkContext(conf)
+ val serializer = SparkEnv.get.closureSerializer.newInstance()
+ val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size
+
+ // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat
+ // the fetch failure as a false positive, and just do normal OOM handling.
+ val inputRDD = new FetchFailureThrowingRDD(sc)
+ val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true)
+ val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array())
+ val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array()
+ val task = new ResultTask(
+ stageId = 1,
+ stageAttemptId = 0,
+ taskBinary = taskBinary,
+ partition = secondRDD.partitions(0),
+ locs = Seq(),
+ outputId = 0,
+ localProperties = new Properties(),
+ serializedTaskMetrics = serializedTaskMetrics
+ )
+
+ val serTask = serializer.serialize(task)
+ val taskDescription = createFakeTaskDescription(serTask)
+
+ val (failReason, uncaughtExceptionHandler) =
+ runTaskGetFailReasonAndExceptionHandler(taskDescription)
+ // make sure the task failure just looks like an OOM, not a fetch failure
+ assert(failReason.isInstanceOf[ExceptionFailure])
+ val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+ verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture())
+ assert(exceptionCaptor.getAllValues.size === 1)
+ assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError])
+ }
+
test("Gracefully handle error in task deserialization") {
val conf = new SparkConf
val serializer = new JavaSerializer(conf)
@@ -169,13 +242,20 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
}
private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = {
+ runTaskGetFailReasonAndExceptionHandler(taskDescription)._1
+ }
+
+ private def runTaskGetFailReasonAndExceptionHandler(
+ taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = {
val mockBackend = mock[ExecutorBackend]
+ val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler]
var executor: Executor = null
try {
- executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true)
+ executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true,
+ uncaughtExceptionHandler = mockUncaughtExceptionHandler)
// the task will be launched in a dedicated worker thread
executor.launchTask(mockBackend, taskDescription)
- eventually(timeout(5 seconds), interval(10 milliseconds)) {
+ eventually(timeout(5.seconds), interval(10.milliseconds)) {
assert(executor.numRunningTasks === 0)
}
} finally {
@@ -193,7 +273,56 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
assert(statusCaptor.getAllValues().get(0).remaining() === 0)
// second update is more interesting
val failureData = statusCaptor.getAllValues.get(1)
- SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
+ val failReason =
+ SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData)
+ (failReason, mockUncaughtExceptionHandler)
+ }
+}
+
+class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
+ override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+ new Iterator[Int] {
+ override def hasNext: Boolean = true
+ override def next(): Int = {
+ throw new FetchFailedException(
+ bmAddress = BlockManagerId("1", "hostA", 1234),
+ shuffleId = 0,
+ mapId = 0,
+ reduceId = 0,
+ message = "fake fetch failure"
+ )
+ }
+ }
+ }
+ override protected def getPartitions: Array[Partition] = {
+ Array(new SimplePartition)
+ }
+}
+
+class SimplePartition extends Partition {
+ override def index: Int = 0
+}
+
+class FetchFailureHidingRDD(
+ sc: SparkContext,
+ val input: FetchFailureThrowingRDD,
+ throwOOM: Boolean) extends RDD[Int](input) {
+ override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+ val inItr = input.compute(split, context)
+ try {
+ Iterator(inItr.size)
+ } catch {
+ case t: Throwable =>
+ if (throwOOM) {
+ throw new OutOfMemoryError("OOM while handling another exception")
+ } else {
+ throw new RuntimeException("User Exception that hides the original exception", t)
+ }
+ }
+ }
+
+ override protected def getPartitions: Array[Partition] = {
+ Array(new SimplePartition)
}
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 511686fb4f..56b8c0b95e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -55,6 +55,9 @@ object MimaExcludes {
// [SPARK-14272][ML] Add logLikelihood in GaussianMixtureSummary
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this"),
+ // [SPARK-19276] Fetch Failure handling robust to user error handling
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed"),
+
// [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.<init>$default$10"),