Diffstat (limited to 'core')
3 files changed, 93 insertions, 34 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index f0867ecb16..5acfce1759 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -105,7 +105,9 @@ abstract class TaskContext extends Serializable {

   /**
    * Adds a (Java friendly) listener to be executed on task completion.
-   * This will be called in all situation - success, failure, or cancellation.
+   * This will be called in all situations - success, failure, or cancellation. Adding a listener
+   * to an already completed task will result in that listener being called immediately.
+   *
    * An example use is for HadoopRDD to register a callback to close the input stream.
    *
    * Exceptions thrown by the listener will result in failure of the task.
@@ -114,7 +116,9 @@ abstract class TaskContext extends Serializable {

   /**
    * Adds a listener in the form of a Scala closure to be executed on task completion.
-   * This will be called in all situations - success, failure, or cancellation.
+   * This will be called in all situations - success, failure, or cancellation. Adding a listener
+   * to an already completed task will result in that listener being called immediately.
+   *
    * An example use is for HadoopRDD to register a callback to close the input stream.
    *
    * Exceptions thrown by the listener will result in failure of the task.
@@ -126,14 +130,14 @@ abstract class TaskContext extends Serializable {
   }

   /**
-   * Adds a listener to be executed on task failure.
-   * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+   * Adds a listener to be executed on task failure. Adding a listener to an already failed task
+   * will result in that listener being called immediately.
    */
   def addTaskFailureListener(listener: TaskFailureListener): TaskContext

   /**
-   * Adds a listener to be executed on task failure.
-   * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+   * Adds a listener to be executed on task failure. Adding a listener to an already failed task
+   * will result in that listener being called immediately.
    */
   def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
     addTaskFailureListener(new TaskFailureListener {
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index dc0d128785..ea8dcdfd5d 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -18,6 +18,7 @@
 package org.apache.spark

 import java.util.Properties
+import javax.annotation.concurrent.GuardedBy

 import scala.collection.mutable.ArrayBuffer

@@ -29,6 +30,16 @@ import org.apache.spark.metrics.source.Source
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._

+/**
+ * A [[TaskContext]] implementation.
+ *
+ * A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes
+ * sure that updates are always visible across threads. The complete & failed flags and their
+ * callbacks are protected by locking on the context instance. For instance, this ensures
+ * that you cannot add a completion listener in one thread while we are completing (and calling
+ * the completion listeners) in another thread. Other state is immutable, however the exposed
+ * [[TaskMetrics]] & [[MetricsSystem]] objects are not thread safe.
+ */
 private[spark] class TaskContextImpl(
     val stageId: Int,
     val partitionId: Int,
@@ -52,62 +63,79 @@ private[spark] class TaskContextImpl(
   @volatile private var interrupted: Boolean = false

   // Whether the task has completed.
-  @volatile private var completed: Boolean = false
+  private var completed: Boolean = false

   // Whether the task has failed.
-  @volatile private var failed: Boolean = false
+  private var failed: Boolean = false
+
+  // Throwable that caused the task to fail
+  private var failure: Throwable = _

   // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
   // hide the exception. See SPARK-19276
   @volatile private var _fetchFailedException: Option[FetchFailedException] = None

-  override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
-    onCompleteCallbacks += listener
+  @GuardedBy("this")
+  override def addTaskCompletionListener(listener: TaskCompletionListener)
+      : this.type = synchronized {
+    if (completed) {
+      listener.onTaskCompletion(this)
+    } else {
+      onCompleteCallbacks += listener
+    }
     this
   }

-  override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
-    onFailureCallbacks += listener
+  @GuardedBy("this")
+  override def addTaskFailureListener(listener: TaskFailureListener)
+      : this.type = synchronized {
+    if (failed) {
+      listener.onTaskFailure(this, failure)
+    } else {
+      onFailureCallbacks += listener
+    }
     this
   }

   /** Marks the task as failed and triggers the failure listeners. */
-  private[spark] def markTaskFailed(error: Throwable): Unit = {
-    // failure callbacks should only be called once
+  @GuardedBy("this")
+  private[spark] def markTaskFailed(error: Throwable): Unit = synchronized {
     if (failed) return
     failed = true
-    val errorMsgs = new ArrayBuffer[String](2)
-    // Process failure callbacks in the reverse order of registration
-    onFailureCallbacks.reverse.foreach { listener =>
-      try {
-        listener.onTaskFailure(this, error)
-      } catch {
-        case e: Throwable =>
-          errorMsgs += e.getMessage
-          logError("Error in TaskFailureListener", e)
-      }
-    }
-    if (errorMsgs.nonEmpty) {
-      throw new TaskCompletionListenerException(errorMsgs, Option(error))
+    failure = error
+    invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) {
+      _.onTaskFailure(this, error)
     }
   }

   /** Marks the task as completed and triggers the completion listeners. */
-  private[spark] def markTaskCompleted(): Unit = {
+  @GuardedBy("this")
+  private[spark] def markTaskCompleted(): Unit = synchronized {
+    if (completed) return
     completed = true
+    invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) {
+      _.onTaskCompletion(this)
+    }
+  }
+
+  private def invokeListeners[T](
+      listeners: Seq[T],
+      name: String,
+      error: Option[Throwable])(
+      callback: T => Unit): Unit = {
     val errorMsgs = new ArrayBuffer[String](2)
-    // Process complete callbacks in the reverse order of registration
-    onCompleteCallbacks.reverse.foreach { listener =>
+    // Process callbacks in the reverse order of registration
+    listeners.reverse.foreach { listener =>
       try {
-        listener.onTaskCompletion(this)
+        callback(listener)
       } catch {
         case e: Throwable =>
           errorMsgs += e.getMessage
-          logError("Error in TaskCompletionListener", e)
+          logError(s"Error in $name", e)
       }
     }
     if (errorMsgs.nonEmpty) {
-      throw new TaskCompletionListenerException(errorMsgs)
+      throw new TaskCompletionListenerException(errorMsgs, error)
     }
   }

@@ -116,7 +144,8 @@ private[spark] class TaskContextImpl(
     interrupted = true
   }

-  override def isCompleted(): Boolean = completed
+  @GuardedBy("this")
+  override def isCompleted(): Boolean = synchronized(completed)

   override def isRunningLocally(): Boolean = false

diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 7004128308..8f576daa77 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -228,6 +228,32 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(res === Array("testPropValue,testPropValue"))
   }

+  test("immediately call a completion listener if the context is completed") {
+    var invocations = 0
+    val context = TaskContext.empty()
+    context.markTaskCompleted()
+    context.addTaskCompletionListener(_ => invocations += 1)
+    assert(invocations == 1)
+    context.markTaskCompleted()
+    assert(invocations == 1)
+  }
+
+  test("immediately call a failure listener if the context has failed") {
+    var invocations = 0
+    var lastError: Throwable = null
+    val error = new RuntimeException
+    val context = TaskContext.empty()
+    context.markTaskFailed(error)
+    context.addTaskFailureListener { (_, e) =>
+      lastError = e
+      invocations += 1
+    }
+    assert(lastError == error)
+    assert(invocations == 1)
+    context.markTaskFailed(error)
+    assert(lastError == error)
+    assert(invocations == 1)
+  }
 }

 private object TaskContextSuite {
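
For context on how this contract is exercised from user code: partition-scoped resources are typically released through a completion listener registered inside the task. A minimal sketch under the patched semantics, where `rdd`, `openReader` and `process` are illustrative placeholders, not part of this change:

import org.apache.spark.TaskContext

rdd.mapPartitions { iter =>
  val reader = openReader()  // hypothetical resource tied to this partition
  // Registration is now safe at any point in the task's lifecycle: if the task
  // has already completed, the listener fires immediately instead of being dropped.
  TaskContext.get().addTaskCompletionListener { _ => reader.close() }
  iter.map(process)          // hypothetical per-record transformation
}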
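The design choice worth noting is the "run-or-register" idiom: registration and completion synchronize on the same monitor, so a listener is either enqueued before the callbacks run or invoked immediately, and the completed/failed flags no longer need to be volatile because every access happens under the lock. A self-contained sketch of the idiom, reduced from the patch (class and method names are illustrative):

import scala.collection.mutable.ArrayBuffer

class Completer {
  private val listeners = ArrayBuffer.empty[() => Unit]
  private var completed = false

  // Run the callback immediately if already completed, otherwise register it.
  def onComplete(f: () => Unit): Unit = synchronized {
    if (completed) f() else listeners += f
  }

  // Idempotent: second and later calls are no-ops, mirroring markTaskCompleted().
  def complete(): Unit = synchronized {
    if (completed) return
    completed = true
    listeners.reverse.foreach(_())  // reverse order of registration, as in the patch
  }
}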