diff options
-rw-r--r-- | core/src/main/scala/org/apache/spark/TaskContext.scala | 28 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/TaskContextImpl.scala | 33 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/scheduler/Task.scala | 5 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala | 34 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/util/taskListeners.scala (renamed from core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala) | 37 | ||||
-rw-r--r-- | core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java | 39 | ||||
-rw-r--r-- | core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java | 30 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 44 | ||||
-rw-r--r-- | project/MimaExcludes.scala | 4 |
9 files changed, 169 insertions, 85 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 9f49cf1c4c..bfcacbf229 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source -import org.apache.spark.util.TaskCompletionListener +import org.apache.spark.util.{TaskCompletionListener, TaskFailureListener} object TaskContext { @@ -106,6 +106,8 @@ abstract class TaskContext extends Serializable { * Adds a (Java friendly) listener to be executed on task completion. * This will be called in all situation - success, failure, or cancellation. * An example use is for HadoopRDD to register a callback to close the input stream. + * + * Exceptions thrown by the listener will result in failure of the task. */ def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext @@ -113,8 +115,30 @@ abstract class TaskContext extends Serializable { * Adds a listener in the form of a Scala closure to be executed on task completion. * This will be called in all situations - success, failure, or cancellation. * An example use is for HadoopRDD to register a callback to close the input stream. + * + * Exceptions thrown by the listener will result in failure of the task. */ - def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext + def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext = { + addTaskCompletionListener(new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + }) + } + + /** + * Adds a listener to be executed on task failure. + * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + */ + def addTaskFailureListener(listener: TaskFailureListener): TaskContext + + /** + * Adds a listener to be executed on task failure. + * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + */ + def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = { + addTaskFailureListener(new TaskFailureListener { + override def onTaskFailure(context: TaskContext, error: Throwable): Unit = f(context, error) + }) + } /** * The ID of the stage that this task belong to. diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 1d228b6b86..65f6f741f7 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -23,7 +23,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source -import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.util._ private[spark] class TaskContextImpl( val stageId: Int, @@ -41,9 +41,12 @@ private[spark] class TaskContextImpl( */ override val taskMetrics: TaskMetrics = new TaskMetrics(initialAccumulators) - // List of callback functions to execute when the task completes. + /** List of callback functions to execute when the task completes. */ @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] + /** List of callback functions to execute when the task fails. */ + @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener] + // Whether the corresponding task has been killed. @volatile private var interrupted: Boolean = false @@ -55,14 +58,30 @@ private[spark] class TaskContextImpl( this } - override def addTaskCompletionListener(f: TaskContext => Unit): this.type = { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f(context) - } + override def addTaskFailureListener(listener: TaskFailureListener): this.type = { + onFailureCallbacks += listener this } - /** Marks the task as completed and triggers the listeners. */ + /** Marks the task as completed and triggers the failure listeners. */ + private[spark] def markTaskFailed(error: Throwable): Unit = { + val errorMsgs = new ArrayBuffer[String](2) + // Process complete callbacks in the reverse order of registration + onFailureCallbacks.reverse.foreach { listener => + try { + listener.onTaskFailure(this, error) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskFailureListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs, Option(error)) + } + } + + /** Marks the task as completed and triggers the completion listeners. */ private[spark] def markTaskCompleted(): Unit = { completed = true val errorMsgs = new ArrayBuffer[String](2) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 5c68d001f2..d2b8ca90a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -80,7 +80,12 @@ private[spark] abstract class Task[T]( } try { runTask(context) + } catch { case e: Throwable => + // Catch all errors; run task failure callbacks, and rethrow the exception. + context.markTaskFailed(e) + throw e } finally { + // Call the task completion callbacks. context.markTaskCompleted() try { Utils.tryLogNonFatalError { diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala deleted file mode 100644 index f64e069cd1..0000000000 --- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListenerException.scala +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -/** - * Exception thrown when there is an exception in - * executing the callback in TaskCompletionListener. - */ -private[spark] -class TaskCompletionListenerException(errorMessages: Seq[String]) extends Exception { - - override def getMessage: String = { - if (errorMessages.size == 1) { - errorMessages.head - } else { - errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") - } - } -} diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala index c1b8bf052c..1be31e88ab 100644 --- a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala +++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala @@ -29,5 +29,40 @@ import org.apache.spark.annotation.DeveloperApi */ @DeveloperApi trait TaskCompletionListener extends EventListener { - def onTaskCompletion(context: TaskContext) + def onTaskCompletion(context: TaskContext): Unit +} + + +/** + * :: DeveloperApi :: + * + * Listener providing a callback function to invoke when a task's execution encounters an error. + * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + */ +@DeveloperApi +trait TaskFailureListener extends EventListener { + def onTaskFailure(context: TaskContext, error: Throwable): Unit +} + + +/** + * Exception thrown when there is an exception in executing the callback in TaskCompletionListener. + */ +private[spark] +class TaskCompletionListenerException( + errorMessages: Seq[String], + val previousError: Option[Throwable] = None) + extends RuntimeException { + + override def getMessage: String = { + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + + previousError.map { e => + "\n\nPrevious exception in task: " + e.getMessage + "\n" + + e.getStackTrace.mkString("\t", "\n\t", "") + }.getOrElse("") + } } diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java b/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java deleted file mode 100644 index e38bc38949..0000000000 --- a/core/src/test/java/test/org/apache/spark/JavaTaskCompletionListenerImpl.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark; - -import org.apache.spark.TaskContext; -import org.apache.spark.util.TaskCompletionListener; - - -/** - * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and - * TaskContext is Java friendly. - */ -public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { - - @Override - public void onTaskCompletion(TaskContext context) { - context.isCompleted(); - context.isInterrupted(); - context.stageId(); - context.partitionId(); - context.isRunningLocally(); - context.addTaskCompletionListener(this); - } -} diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 4a918f725d..f914081d7d 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -18,6 +18,8 @@ package test.org.apache.spark; import org.apache.spark.TaskContext; +import org.apache.spark.util.TaskCompletionListener; +import org.apache.spark.util.TaskFailureListener; /** * Something to make sure that TaskContext can be used in Java. @@ -32,10 +34,38 @@ public class JavaTaskContextCompileCheck { tc.isRunningLocally(); tc.addTaskCompletionListener(new JavaTaskCompletionListenerImpl()); + tc.addTaskFailureListener(new JavaTaskFailureListenerImpl()); tc.attemptNumber(); tc.partitionId(); tc.stageId(); tc.taskAttemptId(); } + + /** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. + */ + static class JavaTaskCompletionListenerImpl implements TaskCompletionListener { + @Override + public void onTaskCompletion(TaskContext context) { + context.isCompleted(); + context.isInterrupted(); + context.stageId(); + context.partitionId(); + context.isRunningLocally(); + context.addTaskCompletionListener(this); + } + } + + /** + * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and + * TaskContext is Java friendly. + */ + static class JavaTaskFailureListenerImpl implements TaskFailureListener { + @Override + public void onTaskFailure(TaskContext context, Throwable error) { + } + } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 850e470ca1..c4cf2f9f70 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD -import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -66,6 +66,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(TaskContextSuite.completed === true) } + test("calls TaskFailureListeners after failure") { + TaskContextSuite.lastError = null + sc = new SparkContext("local", "test") + val rdd = new RDD[String](sc, List()) { + override def getPartitions = Array[Partition](StubPartition(0)) + override def compute(split: Partition, context: TaskContext) = { + context.addTaskFailureListener((context, error) => TaskContextSuite.lastError = error) + sys.error("damn error") + } + } + val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + val func = (c: TaskContext, i: Iterator[String]) => i.next() + val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) + val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0) + intercept[RuntimeException] { + task.run(0, 0, null) + } + assert(TaskContextSuite.lastError.getMessage == "damn error") + } + test("all TaskCompletionListeners should be called even if some fail") { val context = TaskContext.empty() val listener = mock(classOf[TaskCompletionListener]) @@ -80,6 +100,26 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark verify(listener, times(1)).onTaskCompletion(any()) } + test("all TaskFailureListeners should be called even if some fail") { + val context = TaskContext.empty() + val listener = mock(classOf[TaskFailureListener]) + context.addTaskFailureListener((_, _) => throw new Exception("exception in listener1")) + context.addTaskFailureListener(listener) + context.addTaskFailureListener((_, _) => throw new Exception("exception in listener3")) + + val e = intercept[TaskCompletionListenerException] { + context.markTaskFailed(new Exception("exception in task")) + } + + // Make sure listener 2 was called. + verify(listener, times(1)).onTaskFailure(any(), any()) + + // also need to check failure in TaskFailureListener does not mask earlier exception + assert(e.getMessage.contains("exception in listener1")) + assert(e.getMessage.contains("exception in listener3")) + assert(e.getMessage.contains("exception in task")) + } + test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") { sc = new SparkContext("local[1,2]", "test") // use maxRetries = 2 because we test failed tasks // Check that attemptIds are 0 for all tasks' initial attempts @@ -153,6 +193,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark private object TaskContextSuite { @volatile var completed = false + + @volatile var lastError: Throwable = _ } private case class StubPartition(index: Int) extends Partition diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 14e3c90f1b..165280a1b2 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -271,7 +271,9 @@ object MimaExcludes { ) ++ Seq( // SPARK-13220 Deprecate yarn-client and yarn-cluster mode ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler"), + // SPARK-13465 TaskContext. + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener") ) ++ Seq ( // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") |