Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskEndReason.scala                    | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala                | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala           | 44
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala      |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala         | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/util/JsonProtocol.scala                |  2
-rw-r--r--  core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala   |  2
-rw-r--r--  core/src/test/scala/org/apache/spark/FailureSuite.scala                     | 66
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala      |  2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala    |  5
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala |  2
-rw-r--r--  core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala           |  3
12 files changed, 165 insertions, 34 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 48fd3e7e23..934d00dc70 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -17,6 +17,8 @@
package org.apache.spark
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
@@ -90,6 +92,10 @@ case class FetchFailed(
*
* `fullStackTrace` is a better representation of the stack trace because it contains the whole
* stack trace including the exception and its causes
+ *
+ * `exception` is the actual exception that caused the task to fail. It may be `None` in
+ * the case that the exception is not in fact serializable. If a task fails more than
+ * once (due to retries), `exception` is the one that caused the last failure.
*/
@DeveloperApi
case class ExceptionFailure(
@@ -97,11 +103,26 @@ case class ExceptionFailure(
description: String,
stackTrace: Array[StackTraceElement],
fullStackTrace: String,
- metrics: Option[TaskMetrics])
+ metrics: Option[TaskMetrics],
+ private val exceptionWrapper: Option[ThrowableSerializationWrapper])
extends TaskFailedReason {
+ /**
+ * `preserveCause` is used to keep the exception itself so it is available to the
+ * driver. This may be set to `false` in the event that the exception is not in fact
+ * serializable.
+ */
+ private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) {
+ this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics,
+ if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None)
+ }
+
private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
- this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics)
+ this(e, metrics, preserveCause = true)
+ }
+
+ def exception: Option[Throwable] = exceptionWrapper.flatMap {
+ (w: ThrowableSerializationWrapper) => Option(w.exception)
}
override def toErrorString: String =
@@ -128,6 +149,25 @@ case class ExceptionFailure(
}
/**
+ * A class for recovering from exceptions when deserializing a Throwable that was
+ * thrown in user task code. If the Throwable cannot be deserialized it will be null,
+ * but the stacktrace and message will be preserved correctly in SparkException.
+ */
+private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends
+ Serializable with Logging {
+ private def writeObject(out: ObjectOutputStream): Unit = {
+ out.writeObject(exception)
+ }
+ private def readObject(in: ObjectInputStream): Unit = {
+ try {
+ exception = in.readObject().asInstanceOf[Throwable]
+ } catch {
+ case e : Exception => log.warn("Task exception could not be deserialized", e)
+ }
+ }
+}
+
+/**
* :: DeveloperApi ::
* The task finished successfully, but the result was lost from the executor's block manager before
* it was fetched.
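
The ThrowableSerializationWrapper added above relies on Java serialization's private writeObject/readObject hooks to degrade gracefully: the Throwable is written normally, and a failure while reading it back leaves the field null rather than failing the whole status update. Below is a minimal standalone sketch of that pattern; the class and demo names are illustrative and are not the Spark classes themselves.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Illustrative wrapper: write the Throwable as-is, but never let a
// deserialization failure on the receiving side propagate.
class ThrowableWrapperSketch(private var t: Throwable) extends Serializable {
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.writeObject(t)
  }
  private def readObject(in: ObjectInputStream): Unit = {
    try {
      t = in.readObject().asInstanceOf[Throwable]
    } catch {
      case _: Exception => () // leave t as null; exposed as None below
    }
  }
  def exception: Option[Throwable] = Option(t)
}

object ThrowableWrapperSketchDemo {
  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(new ThrowableWrapperSketch(new IllegalStateException("boom")))
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    val roundTripped = in.readObject().asInstanceOf[ThrowableWrapperSketch]
    println(roundTripped.exception.map(_.getMessage)) // Some(boom)
  }
}

Callers never see the null directly; the Option-returning accessor turns a failed deserialization into None, which is how ExceptionFailure.exception behaves after this change.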
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 5d78a9dc88..42a85e42ea 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -17,7 +17,7 @@
package org.apache.spark.executor
-import java.io.File
+import java.io.{File, NotSerializableException}
import java.lang.management.ManagementFactory
import java.net.URL
import java.nio.ByteBuffer
@@ -305,8 +305,16 @@ private[spark] class Executor(
m
}
}
- val taskEndReason = new ExceptionFailure(t, metrics)
- execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason))
+ val serializedTaskEndReason = {
+ try {
+ ser.serialize(new ExceptionFailure(t, metrics))
+ } catch {
+ case _: NotSerializableException =>
+ // t is not serializable so just send the stacktrace
+ ser.serialize(new ExceptionFailure(t, metrics, false))
+ }
+ }
+ execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
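
The executor-side change retries serialization without the cause when the user exception turns out not to be serializable, so the failure description still reaches the driver. Here is a standalone sketch of that fallback using plain JDK serialization in place of Spark's SerializerInstance; FailureReport and its fields are made-up names for illustration only.

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object SerializeWithFallbackSketch {
  // Stand-in for ExceptionFailure: keep the stack-trace text always, the cause only if it fits.
  case class FailureReport(stackTrace: String, cause: Option[Throwable])

  private def toBytes(obj: AnyRef): Array[Byte] = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(obj)
    out.close()
    buffer.toByteArray
  }

  def serializeFailure(t: Throwable): Array[Byte] = {
    val trace = t.toString + "\n" + t.getStackTrace.mkString("\n")
    try {
      toBytes(FailureReport(trace, Some(t)))    // preferred: keep the cause
    } catch {
      case _: NotSerializableException =>
        toBytes(FailureReport(trace, None))     // fallback: text only
    }
  }
}

As in the patch, the decision is made at serialization time on the executor rather than by inspecting the exception's class up front.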
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index bb489c6b6e..7ab5ccf50a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -200,8 +200,8 @@ class DAGScheduler(
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
- def taskSetFailed(taskSet: TaskSet, reason: String): Unit = {
- eventProcessLoop.post(TaskSetFailed(taskSet, reason))
+ def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = {
+ eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception))
}
private[scheduler]
@@ -677,8 +677,11 @@ class DAGScheduler(
submitWaitingStages()
}
- private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) {
- stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) }
+ private[scheduler] def handleTaskSetFailed(
+ taskSet: TaskSet,
+ reason: String,
+ exception: Option[Throwable]): Unit = {
+ stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) }
submitWaitingStages()
}
@@ -762,7 +765,7 @@ class DAGScheduler(
}
}
} else {
- abortStage(stage, "No active job for stage " + stage.id)
+ abortStage(stage, "No active job for stage " + stage.id, None)
}
}
@@ -816,7 +819,7 @@ class DAGScheduler(
case NonFatal(e) =>
stage.makeNewStageAttempt(partitionsToCompute.size)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
- abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+ abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
runningStages -= stage
return
}
@@ -845,13 +848,13 @@ class DAGScheduler(
} catch {
// In the case of a failure during serialization, abort the stage.
case e: NotSerializableException =>
- abortStage(stage, "Task not serializable: " + e.toString)
+ abortStage(stage, "Task not serializable: " + e.toString, Some(e))
runningStages -= stage
// Abort execution
return
case NonFatal(e) =>
- abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
+ abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e))
runningStages -= stage
return
}
@@ -878,7 +881,7 @@ class DAGScheduler(
}
} catch {
case NonFatal(e) =>
- abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+ abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
runningStages -= stage
return
}
@@ -1098,7 +1101,8 @@ class DAGScheduler(
}
if (disallowStageRetryForTest) {
- abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config",
+ None)
} else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled.
@@ -1126,7 +1130,7 @@ class DAGScheduler(
case commitDenied: TaskCommitDenied =>
// Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
- case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
+ case exceptionFailure: ExceptionFailure =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
case TaskResultLost =>
@@ -1235,7 +1239,10 @@ class DAGScheduler(
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
- private[scheduler] def abortStage(failedStage: Stage, reason: String) {
+ private[scheduler] def abortStage(
+ failedStage: Stage,
+ reason: String,
+ exception: Option[Throwable]): Unit = {
if (!stageIdToStage.contains(failedStage.id)) {
// Skip all the actions if the stage has been removed.
return
@@ -1244,7 +1251,7 @@ class DAGScheduler(
activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
failedStage.latestInfo.completionTime = Some(clock.getTimeMillis())
for (job <- dependentJobs) {
- failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
+ failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception)
}
if (dependentJobs.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -1252,8 +1259,11 @@ class DAGScheduler(
}
/** Fails a job and all stages that are only used by that job, and cleans up relevant state. */
- private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
- val error = new SparkException(failureReason)
+ private def failJobAndIndependentStages(
+ job: ActiveJob,
+ failureReason: String,
+ exception: Option[Throwable] = None): Unit = {
+ val error = new SparkException(failureReason, exception.getOrElse(null))
var ableToCancelStages = true
val shouldInterruptThread =
@@ -1462,8 +1472,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
dagScheduler.handleTaskCompletion(completion)
- case TaskSetFailed(taskSet, reason) =>
- dagScheduler.handleTaskSetFailed(taskSet, reason)
+ case TaskSetFailed(taskSet, reason, exception) =>
+ dagScheduler.handleTaskSetFailed(taskSet, reason, exception)
case ResubmitFailedStages =>
dagScheduler.resubmitFailedStages()
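
On the driver, the optional cause threaded through abortStage and failJobAndIndependentStages ends up as the cause of the SparkException handed to the job. The following small sketch is not Spark code (JobFailedException stands in for SparkException); it only shows how the Option is unwrapped and later read back.

object CauseChainingSketch {
  final class JobFailedException(message: String, cause: Throwable)
    extends RuntimeException(message, cause)

  def failJob(reason: String, exception: Option[Throwable]): JobFailedException =
    // orNull mirrors the exception.getOrElse(null) in the patch; Throwable accepts a null cause.
    new JobFailedException(s"Job aborted due to stage failure: $reason", exception.orNull)

  def main(args: Array[String]): Unit = {
    val withCause = failJob("task threw", Some(new IllegalArgumentException("failed=2")))
    println(Option(withCause.getCause).map(_.getMessage))        // Some(failed=2)
    println(Option(failJob("not serializable", None).getCause))  // None
  }
}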
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index a213d419cf..f72a52e85d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -73,6 +73,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend
private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[scheduler]
-case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable])
+ extends DAGSchedulerEvent
private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 82455b0426..818b95d67f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -662,7 +662,7 @@ private[spark] class TaskSetManager(
val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " +
reason.asInstanceOf[TaskFailedReason].toErrorString
- reason match {
+ val failureException: Option[Throwable] = reason match {
case fetchFailed: FetchFailed =>
logWarning(failureReason)
if (!successful(index)) {
@@ -671,6 +671,7 @@ private[spark] class TaskSetManager(
}
// Not adding to failed executors for FetchFailed.
isZombie = true
+ None
case ef: ExceptionFailure =>
taskMetrics = ef.metrics.orNull
@@ -706,12 +707,15 @@ private[spark] class TaskSetManager(
s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " +
s"${ef.className} (${ef.description}) [duplicate $dupCount]")
}
+ ef.exception
case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others
logWarning(failureReason)
+ None
case e: TaskEndReason =>
logError("Unknown TaskEndReason: " + e)
+ None
}
// always add to failed executors
failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()).
@@ -728,16 +732,16 @@ private[spark] class TaskSetManager(
logError("Task %d in stage %s failed %d times; aborting job".format(
index, taskSet.id, maxTaskFailures))
abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:"
- .format(index, taskSet.id, maxTaskFailures, failureReason))
+ .format(index, taskSet.id, maxTaskFailures, failureReason), failureException)
return
}
}
maybeFinishTaskSet()
}
- def abort(message: String): Unit = sched.synchronized {
+ def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized {
// TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.dagScheduler.taskSetFailed(taskSet, message)
+ sched.dagScheduler.taskSetFailed(taskSet, message, exception)
isZombie = true
maybeFinishTaskSet()
}
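
The TaskSetManager change turns the existing reason match into an expression: every branch now yields an Option[Throwable], and only ExceptionFailure can contribute a user exception, which is then passed to abort. A simplified sketch of that idiom follows, using a made-up failure hierarchy rather than the real Spark one.

object FailureExceptionSketch {
  sealed trait Reason
  case class Exceptional(exception: Option[Throwable]) extends Reason  // like ExceptionFailure
  case object FetchFailure extends Reason                              // like FetchFailed
  case object ResultLost extends Reason                                // like TaskResultLost

  def exceptionOf(reason: Reason): Option[Throwable] = reason match {
    case Exceptional(e) => e      // the only branch that carries a user exception
    case FetchFailure   => None   // handled via stage retry, nothing to propagate
    case ResultLost     => None
  }

  def main(args: Array[String]): Unit = {
    println(exceptionOf(Exceptional(Some(new RuntimeException("oops")))).map(_.getMessage))
    println(exceptionOf(FetchFailure))
  }
}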
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index c600319d9d..cbc94fd6d5 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -790,7 +790,7 @@ private[spark] object JsonProtocol {
val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace").
map(_.extract[String]).orNull
val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson)
- ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics)
+ ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None)
case `taskResultLost` => TaskResultLost
case `taskKilled` => TaskKilled
case `executorLostFailure` =>
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index f374f97f87..116f027a0f 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -800,7 +800,7 @@ class ExecutorAllocationManagerSuite
assert(maxNumExecutorsNeeded(manager) === 1)
// If the task is failed, we expect it to be resubmitted later.
- val taskEndReason = ExceptionFailure(null, null, null, null, null)
+ val taskEndReason = ExceptionFailure(null, null, null, null, null, None)
sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null))
assert(maxNumExecutorsNeeded(manager) === 1)
}
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 69cb4b44cf..aa50a49c50 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark
import org.apache.spark.util.NonSerializable
-import java.io.NotSerializableException
+import java.io.{IOException, NotSerializableException, ObjectInputStream}
// Common state shared by FailureSuite-launched tasks. We use a global object
// for this because any local variables used in the task closures will rightfully
@@ -166,5 +166,69 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
assert(thrownDueToMemoryLeak.getMessage.contains("memory leak"))
}
+ // Run a 3-task map job in which one task always fails with an exception message that
+ // depends on the failure number, and check that we get the last failure.
+ test("last failure cause is sent back to driver") {
+ sc = new SparkContext("local[1,2]", "test")
+ val data = sc.makeRDD(1 to 3, 3).map { x =>
+ FailureSuiteState.synchronized {
+ FailureSuiteState.tasksRun += 1
+ if (x == 3) {
+ FailureSuiteState.tasksFailed += 1
+ throw new UserException("oops",
+ new IllegalArgumentException("failed=" + FailureSuiteState.tasksFailed))
+ }
+ }
+ x * x
+ }
+ val thrown = intercept[SparkException] {
+ data.collect()
+ }
+ FailureSuiteState.synchronized {
+ assert(FailureSuiteState.tasksRun === 4)
+ }
+ assert(thrown.getClass === classOf[SparkException])
+ assert(thrown.getCause.getClass === classOf[UserException])
+ assert(thrown.getCause.getMessage === "oops")
+ assert(thrown.getCause.getCause.getClass === classOf[IllegalArgumentException])
+ assert(thrown.getCause.getCause.getMessage === "failed=2")
+ FailureSuiteState.clear()
+ }
+
+ test("failure cause stacktrace is sent back to driver if exception is not serializable") {
+ sc = new SparkContext("local", "test")
+ val thrown = intercept[SparkException] {
+ sc.makeRDD(1 to 3).foreach { _ => throw new NonSerializableUserException }
+ }
+ assert(thrown.getClass === classOf[SparkException])
+ assert(thrown.getCause === null)
+ assert(thrown.getMessage.contains("NonSerializableUserException"))
+ FailureSuiteState.clear()
+ }
+
+ test("failure cause stacktrace is sent back to driver if exception is not deserializable") {
+ sc = new SparkContext("local", "test")
+ val thrown = intercept[SparkException] {
+ sc.makeRDD(1 to 3).foreach { _ => throw new NonDeserializableUserException }
+ }
+ assert(thrown.getClass === classOf[SparkException])
+ assert(thrown.getCause === null)
+ assert(thrown.getMessage.contains("NonDeserializableUserException"))
+ FailureSuiteState.clear()
+ }
+
// TODO: Need to add tests with shuffle fetch failures.
}
+
+class UserException(message: String, cause: Throwable)
+ extends RuntimeException(message, cause)
+
+class NonSerializableUserException extends RuntimeException {
+ val nonSerializableInstanceVariable = new NonSerializable
+}
+
+class NonDeserializableUserException extends RuntimeException {
+ private def readObject(in: ObjectInputStream): Unit = {
+ throw new IOException("Intentional exception during deserialization.")
+ }
+}
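
The new tests above exercise the end-to-end behaviour; for a user-facing view, a driver program might observe the propagated cause roughly as sketched below (assuming a local master and a task exception that is serializable; the object and app names are illustrative).

import org.apache.spark.{SparkConf, SparkContext, SparkException}

object DriverSideCauseSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("cause-demo"))
    try {
      sc.parallelize(1 to 3).foreach { _ =>
        throw new IllegalStateException("boom from the task")
      }
    } catch {
      case e: SparkException =>
        // Before this change getCause was always null; with it, the task's own
        // exception is attached whenever it could be serialized back to the driver.
        println(Option(e.getCause).map(_.getMessage)) // Some(boom from the task)
    } finally {
      sc.stop()
    }
  }
}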
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 86dff8fb57..b0ca49cbea 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -242,7 +242,7 @@ class DAGSchedulerSuite
/** Sends TaskSetFailed to the scheduler. */
private def failed(taskSet: TaskSet, message: String) {
- runEvent(TaskSetFailed(taskSet, message))
+ runEvent(TaskSetFailed(taskSet, message, None))
}
/** Sends JobCancelled to the DAG scheduler. */
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index f7cc4bb61d..edbdb485c5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -48,7 +48,10 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
override def executorLost(execId: String) {}
- override def taskSetFailed(taskSet: TaskSet, reason: String) {
+ override def taskSetFailed(
+ taskSet: TaskSet,
+ reason: String,
+ exception: Option[Throwable]): Unit = {
taskScheduler.taskSetsFailed += taskSet.id
}
}
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 56f7b9cf1f..b140387d30 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -240,7 +240,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
val taskFailedReasons = Seq(
Resubmitted,
new FetchFailed(null, 0, 0, 0, "ignored"),
- ExceptionFailure("Exception", "description", null, null, None),
+ ExceptionFailure("Exception", "description", null, null, None, None),
TaskResultLost,
TaskKilled,
ExecutorLostFailure("0"),
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index dde95f3778..343a4139b0 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -163,7 +163,8 @@ class JsonProtocolSuite extends SparkFunSuite {
}
test("ExceptionFailure backward compatibility") {
- val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None)
+ val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null,
+ None, None)
val oldEvent = JsonProtocol.taskEndReasonToJson(exceptionFailure)
.removeField({ _._1 == "Full Stack Trace" })
assertEquals(exceptionFailure, JsonProtocol.taskEndReasonFromJson(oldEvent))