Diffstat (limited to 'core/src/main/scala')
5 files changed, 61 insertions, 12 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 0fd777ed12..f0867ecb16 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.Source
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener}
 
@@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable {
    */
   private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit
 
+  /**
+   * Record that this task has failed due to a fetch failure from a remote host. This allows
+   * fetch-failure handling to get triggered by the driver, regardless of intervening user-code.
+   */
+  private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit
+
 }
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index c904e08391..dc0d128785 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.metrics.source.Source
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._
 
 private[spark] class TaskContextImpl(
@@ -56,6 +57,10 @@ private[spark] class TaskContextImpl(
   // Whether the task has failed.
   @volatile private var failed: Boolean = false
 
+  // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
+  // hide the exception. See SPARK-19276
+  @volatile private var _fetchFailedException: Option[FetchFailedException] = None
+
   override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
     onCompleteCallbacks += listener
     this
@@ -126,4 +131,10 @@ private[spark] class TaskContextImpl(
     taskMetrics.registerAccumulator(a)
   }
 
+  private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = {
+    this._fetchFailedException = Option(fetchFailed)
+  }
+
+  private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException
+
 }
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 975a6e4eeb..790c1ae942 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -18,6 +18,7 @@ package org.apache.spark.executor
 
 import java.io.{File, NotSerializableException}
+import java.lang.Thread.UncaughtExceptionHandler
 import java.lang.management.ManagementFactory
 import java.net.{URI, URL}
 import java.nio.ByteBuffer
@@ -52,7 +53,8 @@ private[spark] class Executor(
     executorHostname: String,
     env: SparkEnv,
     userClassPath: Seq[URL] = Nil,
-    isLocal: Boolean = false)
+    isLocal: Boolean = false,
+    uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler)
   extends Logging {
 
   logInfo(s"Starting executor ID $executorId on host $executorHostname")
@@ -78,7 +80,7 @@ private[spark] class Executor(
     // Setup an uncaught exception handler for non-local mode.
     // Make any thread terminations due to uncaught exceptions kill the entire
    // executor process to avoid surprising stalls.
-    Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
+    Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler)
   }
 
   // Start worker thread pool
@@ -342,6 +344,14 @@ private[spark] class Executor(
            }
          }
        }
+        task.context.fetchFailed.foreach { fetchFailure =>
+          // uh-oh. it appears the user code has caught the fetch-failure without throwing any
+          // other exceptions. Its *possible* this is what the user meant to do (though highly
+          // unlikely). So we will log an error and keep going.
+          logError(s"TID ${taskId} completed successfully though internally it encountered " +
+            s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
+            s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
+        }
        val taskFinish = System.currentTimeMillis()
        val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
          threadMXBean.getCurrentThreadCpuTime
@@ -402,8 +412,17 @@ private[spark] class Executor(
        execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
 
      } catch {
-        case ffe: FetchFailedException =>
-          val reason = ffe.toTaskFailedReason
+        case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
+          val reason = task.context.fetchFailed.get.toTaskFailedReason
+          if (!t.isInstanceOf[FetchFailedException]) {
+            // there was a fetch failure in the task, but some user code wrapped that exception
+            // and threw something else. Regardless, we treat it as a fetch failure.
+            val fetchFailedCls = classOf[FetchFailedException].getName
+            logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " +
+              s"failed, but the ${fetchFailedCls} was hidden by another " +
+              s"exception. Spark is handling this like a fetch failure and ignoring the " +
+              s"other exception: $t")
+          }
          setTaskFinishedAndClearInterruptStatus()
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
@@ -455,13 +474,17 @@ private[spark] class Executor(
          // Don't forcibly exit unless the exception was inherently fatal, to avoid
          // stopping other tasks unnecessarily.
          if (Utils.isFatalError(t)) {
-            SparkUncaughtExceptionHandler.uncaughtException(t)
+            uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t)
          }
      } finally {
        runningTasks.remove(taskId)
      }
    }
+
+    private def hasFetchFailure: Boolean = {
+      task != null && task.context != null && task.context.fetchFailed.isDefined
+    }
  }
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 7b726d5659..70213722aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -17,19 +17,14 @@ package org.apache.spark.scheduler
 
-import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
 import java.util.Properties
 
-import scala.collection.mutable
-import scala.collection.mutable.HashMap
-
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.config.APP_CALLER_CONTEXT
 import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.util._
 
 /**
@@ -137,6 +132,8 @@ private[spark] abstract class Task[T](
        memoryManager.synchronized { memoryManager.notifyAll() }
      }
    } finally {
+      // Though we unset the ThreadLocal here, the context member variable itself is still queried
+      // directly in the TaskRunner to check for FetchFailedExceptions.
      TaskContext.unset()
    }
  }
@@ -156,7 +153,7 @@ private[spark] abstract class Task[T](
  var epoch: Long = -1
 
  // Task context, to be initialized in run().
-  @transient protected var context: TaskContextImpl = _
+  @transient var context: TaskContextImpl = _
 
  // The actual Thread on which the task is running, if any. Initialized in run().
  @volatile @transient private var taskThread: Thread = _
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 498c12e196..265a8acfa8 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -17,7 +17,7 @@ package org.apache.spark.shuffle
 
-import org.apache.spark.{FetchFailed, TaskFailedReason}
+import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason}
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.Utils
@@ -26,6 +26,11 @@ import org.apache.spark.util.Utils
  * back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage.
  *
  * Note that bmAddress can be null.
+ *
+ * To prevent user code from hiding this fetch failure, in the constructor we call
+ * [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately
+ * after creating it -- you cannot create it, check some condition, and then decide to ignore it
+ * (or risk triggering any other exceptions). See SPARK-19276.
  */
 private[spark] class FetchFailedException(
     bmAddress: BlockManagerId,
@@ -45,6 +50,12 @@ private[spark] class FetchFailedException(
     this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
   }
 
+  // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code
+  // which intercepts this exception (possibly wrapping it), the Executor can still tell there was
+  // a fetch failure, and send the correct error msg back to the driver. We wrap with an Option
+  // because the TaskContext is not defined in some test cases.
+  Option(TaskContext.get()).map(_.setFetchFailed(this))
+
   def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
     Utils.exceptionString(this))
 }
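
For readers tracing SPARK-19276, the user-code pattern this commit defends against looks roughly like the sketch below. The sketch is illustrative and not part of the commit: the job, the wrapping RuntimeException, and the local master are invented for the example (a local master only makes the snippet runnable; real fetch failures arise when shuffle blocks are fetched from a remote executor). The point is the catch-and-rethrow around a shuffle read, which previously hid the FetchFailedException from the Executor and now no longer does, because the exception registers itself with the TaskContext in its constructor.

import org.apache.spark.{SparkConf, SparkContext}

object FetchFailureMaskingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("fetch-failure-masking-sketch").setMaster("local[2]"))

    val summary = sc.parallelize(1 to 1000)
      .map(i => (i % 10, 1))
      .reduceByKey(_ + _)   // shuffle boundary: reading its output can throw FetchFailedException
      .mapPartitions { iter =>
        try {
          // Force evaluation inside the try block so a fetch failure would surface here.
          iter.map { case (k, n) => s"$k=$n" }.toVector.iterator
        } catch {
          // Catching Exception (or Throwable) here also catches FetchFailedException. Before this
          // patch, the rethrown RuntimeException reached the Executor as an ordinary task failure,
          // so the driver retried the task instead of regenerating the lost shuffle output. With
          // the patch, the FetchFailedException has already called TaskContext.setFetchFailed()
          // in its constructor, so the Executor still reports a fetch failure to the driver and
          // the DAGScheduler can rerun the map stage.
          case e: Exception =>
            throw new RuntimeException("partition processing failed", e)
        }
      }

    summary.collect().foreach(println)
    sc.stop()
  }
}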