author     Reynold Xin <rxin@apache.org>    2014-08-14 18:37:02 -0700
committer  Reynold Xin <rxin@apache.org>    2014-08-14 18:37:02 -0700
commit     655699f8b7156e8216431393436368e80626cdb2 (patch)
tree       ad027910e7e3e2cde524f4ccadb3324d284ab95a /core
parent     fa5a08e67d1086045ac249c2090c5e4d0a17b828 (diff)
[SPARK-3027] TaskContext: tighten visibility and provide Java friendly callback API
Note this also passes the TaskContext itself to the TaskCompletionListener. In the future we can mark TaskContext with the exception object if an exception occurs during task execution.

Author: Reynold Xin <rxin@apache.org>

Closes #1938 from rxin/TaskContext and squashes the following commits:

145de43 [Reynold Xin] Added JavaTaskCompletionListenerImpl for Java API friendly guarantee.
f435ea5 [Reynold Xin] Added license header for TaskCompletionListener.
dc4ed27 [Reynold Xin] [SPARK-3027] TaskContext: tighten the visibility and provide Java friendly callback API
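To make the new API concrete, here is a minimal, illustrative sketch of registering task-completion cleanup with the two addTaskCompletionListener overloads this patch introduces. The TaskCleanupExample object, the registerCleanup method, and the stream parameter are hypothetical and not part of the commit; both overloads are shown side by side purely for illustration.

import java.io.InputStream

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

object TaskCleanupExample {
  // `context` is the TaskContext handed to an RDD's compute method; `stream`
  // stands in for any task-scoped resource that must be released at the end.
  def registerCleanup(context: TaskContext, stream: InputStream): Unit = {
    // Scala-friendly overload: a closure that receives the TaskContext.
    context.addTaskCompletionListener { ctx =>
      stream.close()
    }
    // Java-friendly overload: an explicit TaskCompletionListener instance.
    context.addTaskCompletionListener(new TaskCompletionListener {
      override def onTaskCompletion(ctx: TaskContext): Unit = stream.close()
    })
  }
}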
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/InterruptibleIterator.scala              2
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala                       63
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala              12
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala                  2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala                      2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala                        2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala                   2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala             2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala               2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala           2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala                     2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala       33
-rw-r--r--  core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java 39
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala         2
14 files changed, 144 insertions, 23 deletions
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
index f40baa8e43..5c262bcbdd 100644
--- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -33,7 +33,7 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator
// is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
// (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
// introduces an expensive read fence.
- if (context.interrupted) {
+ if (context.isInterrupted) {
throw new TaskKilledException
} else {
delegate.hasNext
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 51f40c339d..2b99b8a5af 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,10 +21,18 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.TaskCompletionListener
+
/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
+ *
+ * @param stageId stage id
+ * @param partitionId index of the partition
+ * @param attemptId the number of attempts to execute this task
+ * @param runningLocally whether the task is running locally in the driver JVM
+ * @param taskMetrics performance metrics of the task
*/
@DeveloperApi
class TaskContext(
@@ -39,13 +47,45 @@ class TaskContext(
def splitId = partitionId
// List of callback functions to execute when the task completes.
- @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
// Whether the corresponding task has been killed.
- @volatile var interrupted: Boolean = false
+ @volatile private var interrupted: Boolean = false
+
+ // Whether the task has completed.
+ @volatile private var completed: Boolean = false
+
+ /** Checks whether the task has completed. */
+ def isCompleted: Boolean = completed
- // Whether the task has completed, before the onCompleteCallbacks are executed.
- @volatile var completed: Boolean = false
+ /** Checks whether the task has been killed. */
+ def isInterrupted: Boolean = interrupted
+
+ // TODO: Also track whether the task has completed successfully or with exception.
+
+ /**
+ * Add a (Java friendly) listener to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ *
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
+ onCompleteCallbacks += listener
+ this
+ }
+
+ /**
+ * Add a listener in the form of a Scala closure to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ *
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
+ onCompleteCallbacks += new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f(context)
+ }
+ this
+ }
/**
* Add a callback function to be executed on task completion. An example use
@@ -53,13 +93,22 @@ class TaskContext(
* Will be called in any situation - success, failure, or cancellation.
* @param f Callback function.
*/
+ @deprecated("use addTaskCompletionListener", "1.1.0")
def addOnCompleteCallback(f: () => Unit) {
- onCompleteCallbacks += f
+ onCompleteCallbacks += new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f()
+ }
}
- def executeOnCompleteCallbacks() {
+ /** Marks the task as completed and triggers the listeners. */
+ private[spark] def markTaskCompleted(): Unit = {
completed = true
// Process complete callbacks in the reverse order of registration
- onCompleteCallbacks.reverse.foreach { _() }
+ onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) }
+ }
+
+ /** Marks the task for interruption, i.e. cancellation. */
+ private[spark] def markInterrupted(): Unit = {
+ interrupted = true
}
}
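For existing call sites the migration is mechanical, as the hunks below for CheckpointRDD, HadoopRDD, JdbcRDD, and NewHadoopRDD show. A hedged sketch of the pattern follows; the TaskCallbackMigration object, the closeOnTaskCompletion helper, and the Closeable parameter are illustrative only.

import java.io.Closeable

import org.apache.spark.TaskContext

object TaskCallbackMigration {
  // Illustrative helper mirroring the call-site changes in the hunks below.
  def closeOnTaskCompletion(context: TaskContext, resource: Closeable): Unit = {
    // Deprecated by this patch (zero-argument callback):
    //   context.addOnCompleteCallback(() => resource.close())
    // Replacement: the closure receives the TaskContext itself, so it can also
    // consult ctx.isCompleted or ctx.isInterrupted if needed.
    context.addTaskCompletionListener(ctx => resource.close())
  }
}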
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0b5322c6fb..fefe1cb6f1 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -68,7 +68,7 @@ private[spark] class PythonRDD(
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
- context.addOnCompleteCallback { () =>
+ context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
// Cleanup the worker socket. This will also cause the Python worker to exit.
@@ -137,7 +137,7 @@ private[spark] class PythonRDD(
}
} catch {
- case e: Exception if context.interrupted =>
+ case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException
@@ -176,7 +176,7 @@ private[spark] class PythonRDD(
/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
def shutdownOnTaskCompletion() {
- assert(context.completed)
+ assert(context.isCompleted)
this.interrupt()
}
@@ -209,7 +209,7 @@ private[spark] class PythonRDD(
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
dataOut.flush()
} catch {
- case e: Exception if context.completed || context.interrupted =>
+ case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
case e: Exception =>
@@ -235,10 +235,10 @@ private[spark] class PythonRDD(
override def run() {
// Kill the worker if it is interrupted, checking until task completion.
// TODO: This has a race condition if interruption occurs, as completed may still become true.
- while (!context.interrupted && !context.completed) {
+ while (!context.isInterrupted && !context.isCompleted) {
Thread.sleep(2000)
}
- if (!context.completed) {
+ if (!context.isCompleted) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 34c51b8330..20938781ac 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -141,7 +141,7 @@ private[spark] object CheckpointRDD extends Logging {
val deserializeStream = serializer.deserializeStream(fileInputStream)
// Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback(() => deserializeStream.close())
+ context.addTaskCompletionListener(context => deserializeStream.close())
deserializeStream.asIterator.asInstanceOf[Iterator[T]]
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 8d92ea01d9..c8623314c9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -197,7 +197,7 @@ class HadoopRDD[K, V](
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback{ () => closeIfNeeded() }
+ context.addTaskCompletionListener{ context => closeIfNeeded() }
val key: K = reader.createKey()
val value: V = reader.createValue()
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 8947e66f45..0e38f224ac 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -68,7 +68,7 @@ class JdbcRDD[T: ClassTag](
}
override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
- context.addOnCompleteCallback{ () => closeIfNeeded() }
+ context.addTaskCompletionListener{ context => closeIfNeeded() }
val part = thePart.asInstanceOf[JdbcPartition]
val conn = getConnection()
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 7dfec9a18e..58f707b9b4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -129,7 +129,7 @@ class NewHadoopRDD[K, V](
context.taskMetrics.inputMetrics = Some(inputMetrics)
// Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback(() => close())
+ context.addTaskCompletionListener(context => close())
var havePair = false
var finished = false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 36bbaaa3f1..b86cfbfa48 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -634,7 +634,7 @@ class DAGScheduler(
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
- taskContext.executeOnCompleteCallbacks()
+ taskContext.markTaskCompleted()
}
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index d09fd7aa57..2ccbd8edeb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -61,7 +61,7 @@ private[spark] class ResultTask[T, U](
try {
func(context, rdd.iterator(partition, context))
} finally {
- context.executeOnCompleteCallbacks()
+ context.markTaskCompleted()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 11255c0746..381eff2147 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask(
}
throw e
} finally {
- context.executeOnCompleteCallbacks()
+ context.markTaskCompleted()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index cbe0bc0bcb..6aa0cca068 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -87,7 +87,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
def kill(interruptThread: Boolean) {
_killed = true
if (context != null) {
- context.interrupted = true
+ context.markInterrupted()
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()
diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
new file mode 100644
index 0000000000..c1b8bf052c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.EventListener
+
+import org.apache.spark.TaskContext
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Listener providing a callback function to invoke when a task's execution completes.
+ */
+@DeveloperApi
+trait TaskCompletionListener extends EventListener {
+ def onTaskCompletion(context: TaskContext)
+}
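A Scala counterpart to the Java test class added next might look like the following sketch; the LoggingTaskCompletionListener name and its println output are made up for illustration and are not part of the commit.

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Illustrative listener: report which partition finished and whether it was killed.
class LoggingTaskCompletionListener extends TaskCompletionListener {
  override def onTaskCompletion(context: TaskContext): Unit = {
    println(s"Task for stage ${context.stageId}, partition ${context.partitionId} " +
      s"completed (interrupted = ${context.isInterrupted})")
  }
}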
diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
new file mode 100644
index 0000000000..af34cdb03e
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util;
+
+import org.apache.spark.TaskContext;
+
+
+/**
+ * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and
+ * TaskContext are Java friendly.
+ */
+public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
+
+ @Override
+ public void onTaskCompletion(TaskContext context) {
+ context.isCompleted();
+ context.isInterrupted();
+ context.stageId();
+ context.partitionId();
+ context.runningLocally();
+ context.taskMetrics();
+ context.addTaskCompletionListener(this);
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 270f7e6610..db2ad829a4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -32,7 +32,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
val rdd = new RDD[String](sc, List()) {
override def getPartitions = Array[Partition](StubPartition(0))
override def compute(split: Partition, context: TaskContext) = {
- context.addOnCompleteCallback(() => TaskContextSuite.completed = true)
+ context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
sys.error("failed")
}
}