aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContext.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContextImpl.scala85
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala26
3 files changed, 93 insertions, 34 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index f0867ecb16..5acfce1759 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -105,7 +105,9 @@ abstract class TaskContext extends Serializable {
/**
* Adds a (Java friendly) listener to be executed on task completion.
- * This will be called in all situation - success, failure, or cancellation.
+ * This will be called in all situations - success, failure, or cancellation. Adding a listener
+ * to an already completed task will result in that listener being called immediately.
+ *
* An example use is for HadoopRDD to register a callback to close the input stream.
*
* Exceptions thrown by the listener will result in failure of the task.
@@ -114,7 +116,9 @@ abstract class TaskContext extends Serializable {
/**
* Adds a listener in the form of a Scala closure to be executed on task completion.
- * This will be called in all situations - success, failure, or cancellation.
+ * This will be called in all situations - success, failure, or cancellation. Adding a listener
+ * to an already completed task will result in that listener being called immediately.
+ *
* An example use is for HadoopRDD to register a callback to close the input stream.
*
* Exceptions thrown by the listener will result in failure of the task.
@@ -126,14 +130,14 @@ abstract class TaskContext extends Serializable {
}
/**
- * Adds a listener to be executed on task failure.
- * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+ * Adds a listener to be executed on task failure. Adding a listener to an already failed task
+ * will result in that listener being called immediately.
*/
def addTaskFailureListener(listener: TaskFailureListener): TaskContext
/**
- * Adds a listener to be executed on task failure.
- * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+ * Adds a listener to be executed on task failure. Adding a listener to an already failed task
+ * will result in that listener being called immediately.
*/
def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
addTaskFailureListener(new TaskFailureListener {
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index dc0d128785..ea8dcdfd5d 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.util.Properties
+import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ArrayBuffer
@@ -29,6 +30,16 @@ import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._
+/**
+ * A [[TaskContext]] implementation.
+ *
+ * A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes
+ * sure that updates are always visible across threads. The complete & failed flags and their
+ * callbacks are protected by locking on the context instance. For instance, this ensures
+ * that you cannot add a completion listener in one thread while we are completing (and calling
+ * the completion listeners) in another thread. Other state is immutable, however the exposed
+ * [[TaskMetrics]] & [[MetricsSystem]] objects are not thread safe.
+ */
private[spark] class TaskContextImpl(
val stageId: Int,
val partitionId: Int,
@@ -52,62 +63,79 @@ private[spark] class TaskContextImpl(
@volatile private var interrupted: Boolean = false
// Whether the task has completed.
- @volatile private var completed: Boolean = false
+ private var completed: Boolean = false
// Whether the task has failed.
- @volatile private var failed: Boolean = false
+ private var failed: Boolean = false
+
+ // Throwable that caused the task to fail
+ private var failure: Throwable = _
// If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
// hide the exception. See SPARK-19276
@volatile private var _fetchFailedException: Option[FetchFailedException] = None
- override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
- onCompleteCallbacks += listener
+ @GuardedBy("this")
+ override def addTaskCompletionListener(listener: TaskCompletionListener)
+ : this.type = synchronized {
+ if (completed) {
+ listener.onTaskCompletion(this)
+ } else {
+ onCompleteCallbacks += listener
+ }
this
}
- override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
- onFailureCallbacks += listener
+ @GuardedBy("this")
+ override def addTaskFailureListener(listener: TaskFailureListener)
+ : this.type = synchronized {
+ if (failed) {
+ listener.onTaskFailure(this, failure)
+ } else {
+ onFailureCallbacks += listener
+ }
this
}
/** Marks the task as failed and triggers the failure listeners. */
- private[spark] def markTaskFailed(error: Throwable): Unit = {
- // failure callbacks should only be called once
+ @GuardedBy("this")
+ private[spark] def markTaskFailed(error: Throwable): Unit = synchronized {
if (failed) return
failed = true
- val errorMsgs = new ArrayBuffer[String](2)
- // Process failure callbacks in the reverse order of registration
- onFailureCallbacks.reverse.foreach { listener =>
- try {
- listener.onTaskFailure(this, error)
- } catch {
- case e: Throwable =>
- errorMsgs += e.getMessage
- logError("Error in TaskFailureListener", e)
- }
- }
- if (errorMsgs.nonEmpty) {
- throw new TaskCompletionListenerException(errorMsgs, Option(error))
+ failure = error
+ invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) {
+ _.onTaskFailure(this, error)
}
}
/** Marks the task as completed and triggers the completion listeners. */
- private[spark] def markTaskCompleted(): Unit = {
+ @GuardedBy("this")
+ private[spark] def markTaskCompleted(): Unit = synchronized {
+ if (completed) return
completed = true
+ invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) {
+ _.onTaskCompletion(this)
+ }
+ }
+
+ private def invokeListeners[T](
+ listeners: Seq[T],
+ name: String,
+ error: Option[Throwable])(
+ callback: T => Unit): Unit = {
val errorMsgs = new ArrayBuffer[String](2)
- // Process complete callbacks in the reverse order of registration
- onCompleteCallbacks.reverse.foreach { listener =>
+ // Process callbacks in the reverse order of registration
+ listeners.reverse.foreach { listener =>
try {
- listener.onTaskCompletion(this)
+ callback(listener)
} catch {
case e: Throwable =>
errorMsgs += e.getMessage
- logError("Error in TaskCompletionListener", e)
+ logError(s"Error in $name", e)
}
}
if (errorMsgs.nonEmpty) {
- throw new TaskCompletionListenerException(errorMsgs)
+ throw new TaskCompletionListenerException(errorMsgs, error)
}
}
@@ -116,7 +144,8 @@ private[spark] class TaskContextImpl(
interrupted = true
}
- override def isCompleted(): Boolean = completed
+ @GuardedBy("this")
+ override def isCompleted(): Boolean = synchronized(completed)
override def isRunningLocally(): Boolean = false
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 7004128308..8f576daa77 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -228,6 +228,32 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(res === Array("testPropValue,testPropValue"))
}
+ test("immediately call a completion listener if the context is completed") {
+ var invocations = 0
+ val context = TaskContext.empty()
+ context.markTaskCompleted()
+ context.addTaskCompletionListener(_ => invocations += 1)
+ assert(invocations == 1)
+ context.markTaskCompleted()
+ assert(invocations == 1)
+ }
+
+ test("immediately call a failure listener if the context has failed") {
+ var invocations = 0
+ var lastError: Throwable = null
+ val error = new RuntimeException
+ val context = TaskContext.empty()
+ context.markTaskFailed(error)
+ context.addTaskFailureListener { (_, e) =>
+ lastError = e
+ invocations += 1
+ }
+ assert(lastError == error)
+ assert(invocations == 1)
+ context.markTaskFailed(error)
+ assert(lastError == error)
+ assert(invocations == 1)
+ }
}
private object TaskContextSuite {