-rw-r--r-- core/src/main/scala/org/apache/spark/TaskContext.scala | 28
-rw-r--r-- core/src/main/scala/org/apache/spark/TaskContextImpl.scala | 33
-rw-r--r-- core/src/main/scala/org/apache/spark/scheduler/Task.scala | 5
-rw-r--r-- core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala | 34
-rw-r--r-- core/src/main/scala/org/apache/spark/util/taskListeners.scala (renamed from core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala) | 37
-rw-r--r-- core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java | 39
-rw-r--r-- core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java | 30
-rw-r--r-- core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 44
-rw-r--r-- project/MimaExcludes.scala | 4
9 files changed, 169 insertions, 85 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 9f49cf1c4c..bfcacbf229 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
-import org.apache.spark.util.TaskCompletionListener
+import org.apache.spark.util.{TaskCompletionListener, TaskFailureListener}
object TaskContext {
@@ -106,6 +106,8 @@ abstract class TaskContext extends Serializable {
* Adds a (Java friendly) listener to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
* An example use is for HadoopRDD to register a callback to close the input stream.
+ *
+ * Exceptions thrown by the listener will result in failure of the task.
*/
def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
@@ -113,8 +115,30 @@ abstract class TaskContext extends Serializable {
* Adds a listener in the form of a Scala closure to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
* An example use is for HadoopRDD to register a callback to close the input stream.
+ *
+ * Exceptions thrown by the listener will result in failure of the task.
*/
- def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext
+ def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = {
+ addTaskCompletionListener(new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f(context)
+ })
+ }
+
+ /**
+ * Adds a listener to be executed on task failure.
+ * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+ */
+ def addTaskFailureListener(listener: TaskFailureListener): TaskContext
+
+ /**
+ * Adds a listener to be executed on task failure.
+ * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+ */
+ def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
+ addTaskFailureListener(new TaskFailureListener {
+ override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error)
+ })
+ }
/**
* The ID of the stage that this task belongs to.
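
As a usage sketch (not part of this patch), here is how a task body might register both listener kinds the API above exposes. The stream is a stand-in for a real input source such as the one HadoopRDD closes:

    import java.io.{ByteArrayInputStream, InputStream}

    import org.apache.spark.TaskContext

    def registerListeners(context: TaskContext): Unit = {
      val in: InputStream = new ByteArrayInputStream(Array[Byte](1, 2, 3))
      // Completion listeners fire in all situations: success, failure, cancellation.
      context.addTaskCompletionListener(_ => in.close())
      // Failure listeners fire only on error and may fire more than once,
      // so the callback must be idempotent.
      context.addTaskFailureListener { (ctx, error) =>
        System.err.println(s"Task in stage ${ctx.stageId()} failed: $error")
      }
    }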
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 1d228b6b86..65f6f741f7 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -23,7 +23,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
-import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+import org.apache.spark.util._
private[spark] class TaskContextImpl(
val stageId: Int,
@@ -41,9 +41,12 @@ private[spark] class TaskContextImpl(
*/
override val taskMetrics: TaskMetrics = new TaskMetrics(initialAccumulators)
- // List of callback functions to execute when the task completes.
+ /** List of callback functions to execute when the task completes. */
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
+ /** List of callback functions to execute when the task fails. */
+ @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]
+
// Whether the corresponding task has been killed.
@volatile private var interrupted: Boolean = false
@@ -55,14 +58,30 @@ private[spark] class TaskContextImpl(
this
}
- override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
- onCompleteCallbacks += new TaskCompletionListener {
- override def onTaskCompletion(context: TaskContext): Unit = f(context)
- }
+ override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
+ onFailureCallbacks += listener
this
}
- /** Marks the task as completed and triggers the listeners. */
+ /** Marks the task as failed and triggers the failure listeners. */
+ private[spark] def markTaskFailed(error: Throwable): Unit = {
+ val errorMsgs = new ArrayBuffer[String](2)
+ // Process failure callbacks in the reverse order of registration
+ onFailureCallbacks.reverse.foreach { listener =>
+ try {
+ listener.onTaskFailure(this, error)
+ } catch {
+ case e: Throwable =>
+ errorMsgs += e.getMessage
+ logError("Error in TaskFailureListener", e)
+ }
+ }
+ if (errorMsgs.nonEmpty) {
+ throw new TaskCompletionListenerException(errorMsgs, Option(error))
+ }
+ }
+
+ /** Marks the task as completed and triggers the completion listeners. */
private[spark] def markTaskCompleted(): Unit = {
completed = true
val errorMsgs = new ArrayBuffer[String](2)
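
The pattern markTaskFailed follows (run every callback even when earlier ones throw, collect the messages, and rethrow once at the end) can be modeled standalone. This simplified sketch is not the patched class itself:

    import scala.collection.mutable.ArrayBuffer

    def runAllOnFailure(callbacks: Seq[Throwable => Unit], error: Throwable): Unit = {
      val errorMsgs = new ArrayBuffer[String](2)
      // Invoke in the reverse order of registration, matching markTaskFailed.
      callbacks.reverse.foreach { cb =>
        try cb(error) catch {
          case e: Throwable => errorMsgs += e.getMessage
        }
      }
      if (errorMsgs.nonEmpty) {
        // The real code throws TaskCompletionListenerException(errorMsgs, Option(error)).
        throw new RuntimeException(errorMsgs.mkString("\n"))
      }
    }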
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 5c68d001f2..d2b8ca90a9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -80,7 +80,12 @@ private[spark] abstract class Task[T](
}
try {
runTask(context)
+ } catch { case e: Throwable =>
+ // Catch all errors; run task failure callbacks, and rethrow the exception.
+ context.markTaskFailed(e)
+ throw e
} finally {
+ // Call the task completion callbacks.
context.markTaskCompleted()
try {
Utils.tryLogNonFatalError {
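
The catch/finally structure above fixes the callback ordering on error: failure listeners run first, completion listeners second, and the original exception is rethrown. A minimal model in plain Scala, with no Spark types:

    def runModel(body: => Unit, onFailure: Throwable => Unit, onComplete: () => Unit): Unit = {
      try {
        body
      } catch { case e: Throwable =>
        onFailure(e) // markTaskFailed: only on error, before completion callbacks
        throw e      // the task still fails with the original exception
      } finally {
        onComplete() // markTaskCompleted: always runs, success or failure
      }
    }

    // runModel(throw new RuntimeException("boom"),
    //   e => println(s"failed: $e"), () => println("completed"))
    // prints "failed: ..." then "completed", then rethrows "boom".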
diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala
deleted file mode 100644
index f64e069cd1..0000000000
--- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util
-
-/**
- * Exception thrown when there is an exception in
- * executing the callback in TaskCompletionListener.
- */
-private[spark]
-class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception {
-
- override def getMessage: String = {
- if (errorMessages.size == 1) {
- errorMessages.head
- } else {
- errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
index c1b8bf052c..1be31e88ab 100644
--- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
+++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
@@ -29,5 +29,40 @@ import org.apache.spark.annotation.DeveloperApi
*/
@DeveloperApi
trait TaskCompletionListener extends EventListener {
- def onTaskCompletion(context: TaskContext)
+ def onTaskCompletion(context: TaskContext): Unit
+}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Listener providing a callback function to invoke when a task's execution encounters an error.
+ * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+ */
+@DeveloperApi
+trait TaskFailureListener extends EventListener {
+ def onTaskFailure(context: TaskContext, error: Throwable): Unit
+}
+
+
+/**
+ * Exception thrown when there is an exception in executing the callback in TaskCompletionListener.
+ */
+private[spark]
+class TaskCompletionListenerException(
+ errorMessages: Seq[String],
+ val previousError: Option[Throwable] = None)
+ extends RuntimeException {
+
+ override def getMessage: String = {
+ if (errorMessages.size == 1) {
+ errorMessages.head
+ } else {
+ errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
+ } +
+ previousError.map { e =>
+ "\n\nPrevious exception in task: " + e.getMessage + "\n" +
+ e.getStackTrace.mkString("\t", "\n\t", "")
+ }.getOrElse("")
+ }
}
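
Since onTaskFailure may run more than once for the same task, implementations need an idempotency guard. A sketch under that contract; the class name and cleanup hook are illustrative, not part of this patch:

    import java.util.concurrent.atomic.AtomicBoolean

    import org.apache.spark.TaskContext
    import org.apache.spark.util.TaskFailureListener

    class CleanupOnFailure(cleanup: () => Unit) extends TaskFailureListener {
      private val invoked = new AtomicBoolean(false)
      override def onTaskFailure(context: TaskContext, error: Throwable): Unit = {
        // compareAndSet ensures the cleanup side effect happens at most once
        // even if the callback is re-invoked for the same task.
        if (invoked.compareAndSet(false, true)) {
          cleanup()
        }
      }
    }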
diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java
deleted file mode 100644
index e38bc38949..0000000000
--- a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package test.org.apache.spark;
-
-import org.apache.spark.TaskContext;
-import org.apache.spark.util.TaskCompletionListener;
-
-
-/**
- * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and
- * TaskContext is Java friendly.
- */
-public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
-
- @Override
- public void onTaskCompletion(TaskContext context) {
- context.isCompleted();
- context.isInterrupted();
- context.stageId();
- context.partitionId();
- context.isRunningLocally();
- context.addTaskCompletionListener(this);
- }
-}
diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
index 4a918f725d..f914081d7d 100644
--- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
+++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java
@@ -18,6 +18,8 @@
package test.org.apache.spark;
import org.apache.spark.TaskContext;
+import org.apache.spark.util.TaskCompletionListener;
+import org.apache.spark.util.TaskFailureListener;
/**
* Something to make sure that TaskContext can be used in Java.
@@ -32,10 +34,38 @@ public class JavaTaskContextCompileCheck {
tc.isRunningLocally();
tc.addTaskCompletionListener(new JavaTaskCompletionListenerImpl());
+ tc.addTaskFailureListener(new JavaTaskFailureListenerImpl());
tc.attemptNumber();
tc.partitionId();
tc.stageId();
tc.taskAttemptId();
}
+
+ /**
+ * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and
+ * TaskContext are Java friendly.
+ */
+ static class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
+ @Override
+ public void onTaskCompletion(TaskContext context) {
+ context.isCompleted();
+ context.isInterrupted();
+ context.stageId();
+ context.partitionId();
+ context.isRunningLocally();
+ context.addTaskCompletionListener(this);
+ }
+ }
+
+ /**
+ * A simple implementation of TaskFailureListener that makes sure TaskFailureListener and
+ * TaskContext are Java friendly.
+ */
+ static class JavaTaskFailureListenerImpl implements TaskFailureListener {
+ @Override
+ public void onTaskFailure(TaskContext context, Throwable error) {
+ }
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 850e470ca1..c4cf2f9f70 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.JvmSource
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+import org.apache.spark.util._
class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
@@ -66,6 +66,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(TaskContextSuite.completed === true)
}
+ test("calls TaskFailureListeners after failure") {
+ TaskContextSuite.lastError = null
+ sc = new SparkContext("local", "test")
+ val rdd = new RDD[String](sc, List()) {
+ override def getPartitions = Array[Partition](StubPartition(0))
+ override def compute(split: Partition, context: TaskContext) = {
+ context.addTaskFailureListener((context, error) => TaskContextSuite.lastError = error)
+ sys.error("damn error")
+ }
+ }
+ val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
+ val func = (c: TaskContext, i: Iterator[String]) => i.next()
+ val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
+ val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0)
+ intercept[RuntimeException] {
+ task.run(0, 0, null)
+ }
+ assert(TaskContextSuite.lastError.getMessage == "damn error")
+ }
+
test("all TaskCompletionListeners should be called even if some fail") {
val context = TaskContext.empty()
val listener = mock(classOf[TaskCompletionListener])
@@ -80,6 +100,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
verify(listener, times(1)).onTaskCompletion(any())
}
+ test("all TaskFailureListeners should be called even if some fail") {
+ val context = TaskContext.empty()
+ val listener = mock(classOf[TaskFailureListener])
+ context.addTaskFailureListener((_, _) => throw new Exception("exception in listener1"))
+ context.addTaskFailureListener(listener)
+ context.addTaskFailureListener((_, _) => throw new Exception("exception in listener3"))
+
+ val e = intercept[TaskCompletionListenerException] {
+ context.markTaskFailed(new Exception("exception in task"))
+ }
+
+ // Make sure listener 2 was called.
+ verify(listener, times(1)).onTaskFailure(any(), any())
+
+ // also need to check failure in TaskFailureListener does not mask earlier exception
+ assert(e.getMessage.contains("exception in listener1"))
+ assert(e.getMessage.contains("exception in listener3"))
+ assert(e.getMessage.contains("exception in task"))
+ }
+
test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") {
sc = new SparkContext("local[1,2]", "test") // use maxRetries = 2 because we test failed tasks
// Check that attemptIds are 0 for all tasks' initial attempts
@@ -153,6 +193,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
private object TaskContextSuite {
@volatile var completed = false
+
+ @volatile var lastError: Throwable = _
}
private case class StubPartition(index: Int) extends Partition
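
For reference, the aggregated message the last assertions check can be reproduced directly against the new exception class. This sketch assumes it runs from code inside the org.apache.spark package, since the class is private[spark]:

    import org.apache.spark.util.TaskCompletionListenerException

    // Listener messages appear in reverse registration order (listener3 ran
    // first), followed by the wrapped task error supplied via previousError.
    val e = new TaskCompletionListenerException(
      Seq("exception in listener3", "exception in listener1"),
      Some(new Exception("exception in task")))
    println(e.getMessage)
    // Exception 0: exception in listener3
    // Exception 1: exception in listener1
    //
    // Previous exception in task: exception in task
    //     <tab-indented stack trace of the wrapped exception>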
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 14e3c90f1b..165280a1b2 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -271,7 +271,9 @@ object MimaExcludes {
) ++ Seq(
// SPARK-13220 Deprecate yarn-client and yarn-cluster mode
ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler")
+ "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler"),
+ // SPARK-13465 TaskContext.addTaskFailureListener
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener")
) ++ Seq (
// SPARK-7729 Executor which has been killed should also be displayed on Executor Tab
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this")