author     Tom White <tom@cloudera.com>    2015-08-12 10:06:27 -0500
committer  Imran Rashid <irashid@cloudera.com>    2015-08-12 10:07:11 -0500
commit     2e680668f7b6fc158aa068aedd19c1878ecf759e (patch)
tree       29f1e49a5a52394f3fcf8fef64458558d454b359 /core/src/main
parent     3ecb3794302dc12d0989f8d725483b2cc37762cf (diff)
[SPARK-8625] [CORE] Propagate user exceptions in tasks back to driver
This allows clients to retrieve the original exception from the cause field of the
SparkException that is thrown by the driver. If the original exception is not in fact
Serializable then it will not be returned, but the message and stack trace will be.
(All Java Throwables implement the Serializable interface, but there is no guarantee
that a particular implementation can actually be serialized.)

Author: Tom White <tom@cloudera.com>

Closes #7014 from tomwhite/propagate-user-exceptions.
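For driver-side code, the practical effect is that a failed job's SparkException now
carries the user's exception as its cause whenever it could be serialized. A minimal
sketch of the client pattern this enables (MyUserException is an illustrative user
exception, and sc an assumed active SparkContext):

    import org.apache.spark.SparkException

    class MyUserException(msg: String) extends Exception(msg)

    try {
      sc.parallelize(1 to 10).map { i =>
        if (i == 5) throw new MyUserException("bad record")
        i
      }.collect()
    } catch {
      case e: SparkException =>
        e.getCause match {
          case user: MyUserException =>
            // The original exception survived serialization and is the cause.
            println("task failed with: " + user.getMessage)
          case null =>
            // The exception could not be serialized; only the message and
            // stack trace made it back in e's own message.
            println("task failed: " + e.getMessage)
          case other =>
            println("task failed with some other cause: " + other)
        }
    }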
Diffstat (limited to 'core/src/main')
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskEndReason.scala               | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala           | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala      | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala    | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/util/JsonProtocol.scala           |  2
6 files changed, 91 insertions(+), 28 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 48fd3e7e23..934d00dc70 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -17,6 +17,8 @@
package org.apache.spark
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
@@ -90,6 +92,10 @@ case class FetchFailed(
*
* `fullStackTrace` is a better representation of the stack trace because it contains the whole
* stack trace including the exception and its causes
+ *
+ * `exception` is the actual exception that caused the task to fail. It may be `None` in
+ * the case that the exception is not in fact serializable. If a task fails more than
+ * once (due to retries), `exception` is the one that caused the last failure.
*/
@DeveloperApi
case class ExceptionFailure(
@@ -97,11 +103,26 @@ case class ExceptionFailure(
description: String,
stackTrace: Array[StackTraceElement],
fullStackTrace: String,
- metrics: Option[TaskMetrics])
+ metrics: Option[TaskMetrics],
+ private val exceptionWrapper: Option[ThrowableSerializationWrapper])
extends TaskFailedReason {
+ /**
+ * `preserveCause` is used to keep the exception itself so it is available to the
+ * driver. This may be set to `false` in the event that the exception is not in fact
+ * serializable.
+ */
+ private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) {
+ this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics,
+ if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None)
+ }
+
private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
- this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics)
+ this(e, metrics, preserveCause = true)
+ }
+
+ def exception: Option[Throwable] = exceptionWrapper.flatMap {
+ (w: ThrowableSerializationWrapper) => Option(w.exception)
}
override def toErrorString: String =
@@ -128,6 +149,25 @@ case class ExceptionFailure(
}
/**
+ * A class for recovering from exceptions when deserializing a Throwable that was
+ * thrown in user task code. If the Throwable cannot be deserialized it will be null,
+ * but the stacktrace and message will be preserved correctly in SparkException.
+ */
+private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends
+ Serializable with Logging {
+ private def writeObject(out: ObjectOutputStream): Unit = {
+ out.writeObject(exception)
+ }
+ private def readObject(in: ObjectInputStream): Unit = {
+ try {
+ exception = in.readObject().asInstanceOf[Throwable]
+ } catch {
+ case e: Exception => log.warn("Task exception could not be deserialized", e)
+ }
+ }
+}
+
+/**
* :: DeveloperApi ::
* The task finished successfully, but the result was lost from the executor's block manager before
* it was fetched.
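The wrapper above uses Java serialization's private writeObject/readObject hooks so that
a Throwable which fails to deserialize on the driver degrades to null instead of failing
the whole status update. A self-contained sketch of the same pattern (SafeThrowable and
roundTrip are illustrative names, not part of the patch):

    import java.io._

    // Mirrors ThrowableSerializationWrapper: a failed readObject leaves the
    // field null rather than propagating, so the enclosing object still arrives.
    class SafeThrowable(var t: Throwable) extends Serializable {
      def get: Option[Throwable] = Option(t)
      private def writeObject(out: ObjectOutputStream): Unit = out.writeObject(t)
      private def readObject(in: ObjectInputStream): Unit = {
        try {
          t = in.readObject().asInstanceOf[Throwable]
        } catch {
          case _: Exception => // leave t null; callers see None
        }
      }
    }

    def roundTrip[A](a: A): A = {
      val buf = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(buf)
      out.writeObject(a)
      out.close()
      new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
        .readObject().asInstanceOf[A]
    }

    assert(roundTrip(new SafeThrowable(new RuntimeException("boom")))
      .get.exists(_.getMessage == "boom"))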
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 5d78a9dc88..42a85e42ea 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import java.io.File
+import java.io.{File, NotSerializableException}
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
@@ -305,8 +305,16 @@ private[spark] class Executor(
m
}
}
- val taskEndReason = new ExceptionFailure(t, metrics)
- execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason))
+ val serializedTaskEndReason = {
+ try {
+ ser.serialize(new ExceptionFailure(t, metrics))
+ } catch {
+ case _: NotSerializableException =>
+ // t is not serializable so just send the stacktrace
+ ser.serialize(new ExceptionFailure(t, metrics, false))
+ }
+ }
+ execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
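The executor-side guard is a plain try-to-serialize-then-degrade idiom: attempt the rich
failure report first, and on NotSerializableException rebuild it without the throwable.
Sketched independently of Spark's serializer (Report and javaSerialize are illustrative
stand-ins, not the patch's `ser` instance):

    import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

    case class Report(message: String, cause: Option[Throwable])

    // Stand-in for the executor's `ser.serialize`: plain Java serialization.
    def javaSerialize(obj: AnyRef): Array[Byte] = {
      val buf = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(buf)
      out.writeObject(obj)
      out.close()
      buf.toByteArray
    }

    def reportFailure(t: Throwable): Array[Byte] =
      try {
        javaSerialize(Report(t.getMessage, Some(t)))  // keep the cause if we can
      } catch {
        case _: NotSerializableException =>
          javaSerialize(Report(t.getMessage, None))   // degrade: message only
      }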
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index bb489c6b6e..7ab5ccf50a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -200,8 +200,8 @@ class DAGScheduler(
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
- def taskSetFailed(taskSet: TaskSet, reason: String): Unit = {
- eventProcessLoop.post(TaskSetFailed(taskSet, reason))
+ def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = {
+ eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception))
}
private[scheduler]
@@ -677,8 +677,11 @@ class DAGScheduler(
submitWaitingStages()
}
- private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) {
- stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) }
+ private[scheduler] def handleTaskSetFailed(
+ taskSet: TaskSet,
+ reason: String,
+ exception: Option[Throwable]): Unit = {
+ stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) }
submitWaitingStages()
}
@@ -762,7 +765,7 @@ class DAGScheduler(
}
}
} else {
- abortStage(stage, "No active job for stage " + stage.id)
+ abortStage(stage, "No active job for stage " + stage.id, None)
}
}
@@ -816,7 +819,7 @@ class DAGScheduler(
case NonFatal(e) =>
stage.makeNewStageAttempt(partitionsToCompute.size)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
- abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+ abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
runningStages -= stage
return
}
@@ -845,13 +848,13 @@ class DAGScheduler(
} catch {
// In the case of a failure during serialization, abort the stage.
case e: NotSerializableException =>
- abortStage(stage, "Task not serializable: " + e.toString)
+ abortStage(stage, "Task not serializable: " + e.toString, Some(e))
runningStages -= stage
// Abort execution
return
case NonFatal(e) =>
- abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
+ abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e))
runningStages -= stage
return
}
@@ -878,7 +881,7 @@ class DAGScheduler(
}
} catch {
case NonFatal(e) =>
- abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+ abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
runningStages -= stage
return
}
@@ -1098,7 +1101,8 @@ class DAGScheduler(
}
if (disallowStageRetryForTest) {
- abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config",
+ None)
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
@@ -1126,7 +1130,7 @@ class DAGScheduler(
case commitDenied: TaskCommitDenied =>
// Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
- case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
+ case exceptionFailure: ExceptionFailure =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
case TaskResultLost =>
@@ -1235,7 +1239,10 @@ class DAGScheduler(
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
- private[scheduler] def abortStage(failedStage: Stage, reason: String) {
+ private[scheduler] def abortStage(
+ failedStage: Stage,
+ reason: String,
+ exception: Option[Throwable]): Unit = {
if (!stageIdToStage.contains(failedStage.id)) {
// Skip all the actions if the stage has been removed.
return
@@ -1244,7 +1251,7 @@ class DAGScheduler(
activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
failedStage.latestInfo.completionTime = Some(clock.getTimeMillis())
for (job <- dependentJobs) {
- failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
+ failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception)
}
if (dependentJobs.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -1252,8 +1259,11 @@ class DAGScheduler(
}
/** Fails a job and all stages that are only used by that job, and cleans up relevant state. */
- private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
- val error = new SparkException(failureReason)
+ private def failJobAndIndependentStages(
+ job: ActiveJob,
+ failureReason: String,
+ exception: Option[Throwable] = None): Unit = {
+ val error = new SparkException(failureReason, exception.getOrElse(null))
var ableToCancelStages = true
val shouldInterruptThread =
@@ -1462,8 +1472,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
dagScheduler.handleTaskCompletion(completion)
- case TaskSetFailed(taskSet, reason) =>
- dagScheduler.handleTaskSetFailed(taskSet, reason)
+ case TaskSetFailed(taskSet, reason, exception) =>
+ dagScheduler.handleTaskSetFailed(taskSet, reason, exception)
case ResubmitFailedStages =>
dagScheduler.resubmitFailedStages()
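At the end of the chain, failJobAndIndependentStages attaches the optional exception as
the SparkException's cause via exception.getOrElse(null); Throwable's two-argument
constructor treats a null cause as "no cause", so getCause simply returns null when the
original exception did not survive. A compact illustration (plain RuntimeException
standing in for SparkException):

    def abort(reason: String, exception: Option[Throwable]): Nothing =
      // orNull is equivalent to getOrElse(null) for a reference type; a null
      // cause means getCause returns null, matching the pre-patch behaviour.
      throw new RuntimeException(reason, exception.orNull)

    // abort("stage failed", Some(new IllegalStateException("user bug")))
    //   -> getCause is the IllegalStateException
    // abort("stage failed", None)
    //   -> getCause is null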
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index a213d419cf..f72a52e85d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -73,6 +73,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend
private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[scheduler]
-case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable])
+ extends DAGSchedulerEvent
private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 82455b0426..818b95d67f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -662,7 +662,7 @@ private[spark] class TaskSetManager(
val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " +
reason.asInstanceOf[TaskFailedReason].toErrorString
- reason match {
+ val failureException: Option[Throwable] = reason match {
case fetchFailed: FetchFailed =>
logWarning(failureReason)
if (!successful(index)) {
@@ -671,6 +671,7 @@ private[spark] class TaskSetManager(
}
// Not adding to failed executors for FetchFailed.
isZombie = true
+ None
case ef: ExceptionFailure =>
taskMetrics = ef.metrics.orNull
@@ -706,12 +707,15 @@ private[spark] class TaskSetManager(
s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " +
s"${ef.className} (${ef.description}) [duplicate $dupCount]")
}
+ ef.exception
case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others
logWarning(failureReason)
+ None
case e: TaskEndReason =>
logError("Unknown TaskEndReason: " + e)
+ None
}
// always add to failed executors
failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()).
@@ -728,16 +732,16 @@ private[spark] class TaskSetManager(
logError("Task %d in stage %s failed %d times; aborting job".format(
index, taskSet.id, maxTaskFailures))
abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:"
- .format(index, taskSet.id, maxTaskFailures, failureReason))
+ .format(index, taskSet.id, maxTaskFailures, failureReason), failureException)
return
}
}
maybeFinishTaskSet()
}
- def abort(message: String): Unit = sched.synchronized {
+ def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized {
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.dagScheduler.taskSetFailed(taskSet, message)
+ sched.dagScheduler.taskSetFailed(taskSet, message, exception)
isZombie = true
maybeFinishTaskSet()
}
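The TaskSetManager change turns the side-effecting reason match into an expression: every
branch still logs, but now also yields an Option[Throwable], and only the ExceptionFailure
branch can produce Some. Reduced to its shape (the types here are simplified stand-ins for
Spark's TaskEndReason hierarchy):

    sealed trait FailureReason
    case class UserException(exception: Option[Throwable]) extends FailureReason
    case object FetchFailure extends FailureReason
    case object OtherFailure extends FailureReason

    def causeOf(reason: FailureReason): Option[Throwable] = reason match {
      case UserException(e) => e   // only user-code failures carry a cause
      case FetchFailure     => None
      case OtherFailure     => None
    }

The result is then threaded into abort(message, failureException) once the task has
exhausted its retries.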
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index c600319d9d..cbc94fd6d5 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -790,7 +790,7 @@ private[spark] object JsonProtocol {
val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace").
map(_.extract[String]).orNull
val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson)
- ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics)
+ ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None)
case `taskResultLost` => TaskResultLost
case `taskKilled` => TaskKilled
case `executorLostFailure` =>
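Note that the event-log path deliberately reconstructs ExceptionFailure with the
exception set to None: an arbitrary Throwable has no JSON representation, so only the
string fields round-trip. A sketch of that asymmetry (Persisted, toJsonish and
fromJsonish are hypothetical helpers, not JsonProtocol's API):

    import org.apache.spark.ExceptionFailure

    // Hypothetical: only string-typed fields are written out...
    case class Persisted(className: String, description: String, fullStackTrace: String)

    def toJsonish(ef: ExceptionFailure): Persisted =
      Persisted(ef.className, ef.description, ef.fullStackTrace)

    // ...so reading back can never recover the wrapped throwable.
    def fromJsonish(p: Persisted): ExceptionFailure =
      ExceptionFailure(p.className, p.description, Array.empty[StackTraceElement],
        p.fullStackTrace, None, None)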