author    Herman van Hovell <hvanhovell@databricks.com>    2017-03-15 10:46:05 +0100
committer Herman van Hovell <hvanhovell@databricks.com>    2017-03-15 10:46:05 +0100
commit    9ff85be3bd6bf3a782c0e52fa9c2598d79f310bb (patch)
tree      7824feab4c6da92cec27d0e1c3e5f2aba5c9fffd /core
parent    ee36bc1c9043ead3c3ba4fba7e68c6c47ad7ae7a (diff)
[SPARK-19889][SQL] Make TaskContext callbacks thread safe
## What changes were proposed in this pull request?

It is sometimes useful to use multiple threads in a task to parallelize work. These threads might register completion/failure listeners to clean up when the task completes or fails. We currently cannot register such a callback and be sure that it will get called, because the context might be in the process of invoking its callbacks when the callback gets registered.

This PR improves this by making sure that you cannot add a completion/failure listener from one thread while the context is being marked as completed/failed in another thread. This is done by synchronizing these methods on the task context itself.

Failure listeners were already called only once. Completion listeners now follow the same pattern; this lifts the idempotency requirement for completion listeners and makes them easier to implement. In some cases we can (accidentally) add a completion/failure listener after the fact; such listeners are now called immediately, in order to make sure we can safely clean up after a task.

As a result of this change we could make the `failure` and `completed` flags non-volatile. The `isCompleted()` method now uses synchronization to ensure that updates are visible across threads.

## How was this patch tested?

Added tests to `TaskContextSuite` that add listeners to an already completed/failed context.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #17244 from hvanhovell/SPARK-19889.
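To make the locking scheme concrete, here is a minimal stand-alone sketch of the register-or-invoke-immediately pattern described above. The `SketchContext` and `Listener` names are hypothetical simplifications for illustration, not Spark's actual classes (those appear in the diff below):

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for a completion callback.
trait Listener {
  def onDone(): Unit
}

// Hypothetical stand-in for TaskContextImpl, showing only the locking.
class SketchContext {
  // Both fields are guarded by the context's own monitor (`this`),
  // so registration and completion can never interleave.
  private var completed = false
  private val listeners = new ArrayBuffer[Listener]

  def addListener(listener: Listener): this.type = synchronized {
    if (completed) {
      // Registered after the fact: invoke immediately so cleanup
      // still happens.
      listener.onDone()
    } else {
      listeners += listener
    }
    this
  }

  def markCompleted(): Unit = synchronized {
    if (completed) return // listeners run at most once
    completed = true
    // Invoke in reverse order of registration, mirroring the patch.
    listeners.reverse.foreach(_.onDone())
  }
}
```

Because `addListener` and `markCompleted` synchronize on the same monitor, a listener is either queued before completion or invoked immediately after it; it can never be silently dropped.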
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContext.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContextImpl.scala85
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala26
3 files changed, 93 insertions, 34 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index f0867ecb16..5acfce1759 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -105,7 +105,9 @@ abstract class TaskContext extends Serializable {
/**
* Adds a (Java friendly) listener to be executed on task completion.
- * This will be called in all situation - success, failure, or cancellation.
+ * This will be called in all situations - success, failure, or cancellation. Adding a listener
+ * to an already completed task will result in that listener being called immediately.
+ *
* An example use is for HadoopRDD to register a callback to close the input stream.
*
* Exceptions thrown by the listener will result in failure of the task.
@@ -114,7 +116,9 @@ abstract class TaskContext extends Serializable {
/**
* Adds a listener in the form of a Scala closure to be executed on task completion.
- * This will be called in all situations - success, failure, or cancellation.
+ * This will be called in all situations - success, failure, or cancellation. Adding a listener
+ * to an already completed task will result in that listener being called immediately.
+ *
* An example use is for HadoopRDD to register a callback to close the input stream.
*
* Exceptions thrown by the listener will result in failure of the task.
@@ -126,14 +130,14 @@ abstract class TaskContext extends Serializable {
}
/**
- * Adds a listener to be executed on task failure.
- * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+ * Adds a listener to be executed on task failure. Adding a listener to an already failed task
+ * will result in that listener being called immediately.
*/
def addTaskFailureListener(listener: TaskFailureListener): TaskContext
/**
- * Adds a listener to be executed on task failure.
- * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times.
+ * Adds a listener to be executed on task failure. Adding a listener to an already failed task
+ * will result in that listener being called immediately.
*/
def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = {
addTaskFailureListener(new TaskFailureListener {
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index dc0d128785..ea8dcdfd5d 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.util.Properties
+import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ArrayBuffer
@@ -29,6 +30,16 @@ import org.apache.spark.metrics.source.Source
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._
+/**
+ * A [[TaskContext]] implementation.
+ *
+ * A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes
+ * sure that updates are always visible across threads. The complete & failed flags and their
+ * callbacks are protected by locking on the context instance. For instance, this ensures
+ * that you cannot add a completion listener in one thread while we are completing (and calling
+ * the completion listeners) in another thread. Other state is immutable, however the exposed
+ * [[TaskMetrics]] & [[MetricsSystem]] objects are not thread safe.
+ */
private[spark] class TaskContextImpl(
val stageId: Int,
val partitionId: Int,
@@ -52,62 +63,79 @@ private[spark] class TaskContextImpl(
@volatile private var interrupted: Boolean = false
// Whether the task has completed.
- @volatile private var completed: Boolean = false
+ private var completed: Boolean = false
// Whether the task has failed.
- @volatile private var failed: Boolean = false
+ private var failed: Boolean = false
+
+ // Throwable that caused the task to fail
+ private var failure: Throwable = _
// If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
// hide the exception. See SPARK-19276
@volatile private var _fetchFailedException: Option[FetchFailedException] = None
- override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
- onCompleteCallbacks += listener
+ @GuardedBy("this")
+ override def addTaskCompletionListener(listener: TaskCompletionListener)
+ : this.type = synchronized {
+ if (completed) {
+ listener.onTaskCompletion(this)
+ } else {
+ onCompleteCallbacks += listener
+ }
this
}
- override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
- onFailureCallbacks += listener
+ @GuardedBy("this")
+ override def addTaskFailureListener(listener: TaskFailureListener)
+ : this.type = synchronized {
+ if (failed) {
+ listener.onTaskFailure(this, failure)
+ } else {
+ onFailureCallbacks += listener
+ }
this
}
/** Marks the task as failed and triggers the failure listeners. */
- private[spark] def markTaskFailed(error: Throwable): Unit = {
- // failure callbacks should only be called once
+ @GuardedBy("this")
+ private[spark] def markTaskFailed(error: Throwable): Unit = synchronized {
if (failed) return
failed = true
- val errorMsgs = new ArrayBuffer[String](2)
- // Process failure callbacks in the reverse order of registration
- onFailureCallbacks.reverse.foreach { listener =>
- try {
- listener.onTaskFailure(this, error)
- } catch {
- case e: Throwable =>
- errorMsgs += e.getMessage
- logError("Error in TaskFailureListener", e)
- }
- }
- if (errorMsgs.nonEmpty) {
- throw new TaskCompletionListenerException(errorMsgs, Option(error))
+ failure = error
+ invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) {
+ _.onTaskFailure(this, error)
}
}
/** Marks the task as completed and triggers the completion listeners. */
- private[spark] def markTaskCompleted(): Unit = {
+ @GuardedBy("this")
+ private[spark] def markTaskCompleted(): Unit = synchronized {
+ if (completed) return
completed = true
+ invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) {
+ _.onTaskCompletion(this)
+ }
+ }
+
+ private def invokeListeners[T](
+ listeners: Seq[T],
+ name: String,
+ error: Option[Throwable])(
+ callback: T => Unit): Unit = {
val errorMsgs = new ArrayBuffer[String](2)
- // Process complete callbacks in the reverse order of registration
- onCompleteCallbacks.reverse.foreach { listener =>
+ // Process callbacks in the reverse order of registration
+ listeners.reverse.foreach { listener =>
try {
- listener.onTaskCompletion(this)
+ callback(listener)
} catch {
case e: Throwable =>
errorMsgs += e.getMessage
- logError("Error in TaskCompletionListener", e)
+ logError(s"Error in $name", e)
}
}
if (errorMsgs.nonEmpty) {
- throw new TaskCompletionListenerException(errorMsgs)
+ throw new TaskCompletionListenerException(errorMsgs, error)
}
}
@@ -116,7 +144,8 @@ private[spark] class TaskContextImpl(
interrupted = true
}
- override def isCompleted(): Boolean = completed
+ @GuardedBy("this")
+ override def isCompleted(): Boolean = synchronized(completed)
override def isRunningLocally(): Boolean = false
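A sketch of the kind of user code this change is meant to make safe: a helper thread registers a completion listener while the task may already be completing. `ListenerDemo`, `closeOnCompletion`, and `resource` are hypothetical names for illustration; the listener API is the one shown in the `TaskContext` diff above.

```scala
import java.io.Closeable

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

object ListenerDemo {
  // Hypothetical helper: a task spawns a worker thread that registers a
  // cleanup callback. Before this patch, registration could race with
  // markTaskCompleted(); now both synchronize on the context, and a
  // listener added after completion fires immediately.
  def closeOnCompletion(ctx: TaskContext, resource: Closeable): Unit = {
    val worker = new Thread(new Runnable {
      override def run(): Unit = {
        ctx.addTaskCompletionListener(new TaskCompletionListener {
          override def onTaskCompletion(context: TaskContext): Unit = {
            resource.close()
          }
        })
      }
    })
    worker.start()
    worker.join()
  }
}
```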
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 7004128308..8f576daa77 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -228,6 +228,32 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(res === Array("testPropValue,testPropValue"))
}
+ test("immediately call a completion listener if the context is completed") {
+ var invocations = 0
+ val context = TaskContext.empty()
+ context.markTaskCompleted()
+ context.addTaskCompletionListener(_ => invocations += 1)
+ assert(invocations == 1)
+ context.markTaskCompleted()
+ assert(invocations == 1)
+ }
+
+ test("immediately call a failure listener if the context has failed") {
+ var invocations = 0
+ var lastError: Throwable = null
+ val error = new RuntimeException
+ val context = TaskContext.empty()
+ context.markTaskFailed(error)
+ context.addTaskFailureListener { (_, e) =>
+ lastError = e
+ invocations += 1
+ }
+ assert(lastError == error)
+ assert(invocations == 1)
+ context.markTaskFailed(error)
+ assert(lastError == error)
+ assert(invocations == 1)
+ }
}
private object TaskContextSuite {