Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/InterruptibleIterator.scala            |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala                      | 63
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala             | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala                |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala                    |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala                      |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala                 |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala           |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala             |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala         |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala                   |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala      | 33
-rw-r--r--  core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java | 39
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala       |  2
14 files changed, 144 insertions, 23 deletions
diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
index f40baa8e43..5c262bcbdd 100644
--- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
+++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala
@@ -33,7 +33,7 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator
// is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
// (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
// introduces an expensive read fence.
- if (context.interrupted) {
+ if (context.isInterrupted) {
throw new TaskKilledException
} else {
delegate.hasNext
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 51f40c339d..2b99b8a5af 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,10 +21,18 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.TaskCompletionListener
+
/**
* :: DeveloperApi ::
* Contextual information about a task which can be read or mutated during execution.
+ *
+ * @param stageId stage id
+ * @param partitionId index of the partition
+ * @param attemptId the number of attempts to execute this task
+ * @param runningLocally whether the task is running locally in the driver JVM
+ * @param taskMetrics performance metrics of the task
*/
@DeveloperApi
class TaskContext(
@@ -39,13 +47,45 @@ class TaskContext(
def splitId = partitionId
// List of callback functions to execute when the task completes.
- @transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
+ @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
// Whether the corresponding task has been killed.
- @volatile var interrupted: Boolean = false
+ @volatile private var interrupted: Boolean = false
+
+ // Whether the task has completed.
+ @volatile private var completed: Boolean = false
+
+ /** Checks whether the task has completed. */
+ def isCompleted: Boolean = completed
- // Whether the task has completed, before the onCompleteCallbacks are executed.
- @volatile var completed: Boolean = false
+ /** Checks whether the task has been killed. */
+ def isInterrupted: Boolean = interrupted
+
+ // TODO: Also track whether the task has completed successfully or with exception.
+
+ /**
+ * Add a (Java friendly) listener to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ *
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
+ onCompleteCallbacks += listener
+ this
+ }
+
+ /**
+ * Add a listener in the form of a Scala closure to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ *
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
+ onCompleteCallbacks += new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f(context)
+ }
+ this
+ }
/**
* Add a callback function to be executed on task completion. An example use
@@ -53,13 +93,22 @@ class TaskContext(
* Will be called in any situation - success, failure, or cancellation.
* @param f Callback function.
*/
+ @deprecated("use addTaskCompletionListener", "1.1.0")
def addOnCompleteCallback(f: () => Unit) {
- onCompleteCallbacks += f
+ onCompleteCallbacks += new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f()
+ }
}
- def executeOnCompleteCallbacks() {
+ /** Marks the task as completed and triggers the listeners. */
+ private[spark] def markTaskCompleted(): Unit = {
completed = true
// Process complete callbacks in the reverse order of registration
- onCompleteCallbacks.reverse.foreach { _() }
+ onCompleteCallbacks.reverse.foreach { _.onTaskCompletion(this) }
+ }
+
+ /** Marks the task for interruption, i.e. cancellation. */
+ private[spark] def markInterrupted(): Unit = {
+ interrupted = true
}
}
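
The remaining hunks migrate callers from the deprecated addOnCompleteCallback to the listener API introduced above. As a minimal sketch of how calling code can use both overloads (the reader, path, and helper name are hypothetical, for illustration only, not part of this change):

import java.io.{BufferedReader, FileReader}

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Hypothetical helper: open a reader and guarantee it is closed when the
// task finishes, whether it succeeds, fails, or is cancelled.
def openWithCleanup(context: TaskContext, path: String): BufferedReader = {
  val reader = new BufferedReader(new FileReader(path))
  // Scala-closure overload, mirroring the HadoopRDD/JdbcRDD changes below.
  context.addTaskCompletionListener { ctx => reader.close() }
  // Java-friendly overload, taking an explicit TaskCompletionListener.
  context.addTaskCompletionListener(new TaskCompletionListener {
    override def onTaskCompletion(ctx: TaskContext): Unit = {
      // markTaskCompleted() sets the flag before invoking listeners.
      assert(ctx.isCompleted)
    }
  })
  reader
}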
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0b5322c6fb..fefe1cb6f1 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -68,7 +68,7 @@ private[spark] class PythonRDD(
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
- context.addOnCompleteCallback { () =>
+ context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
// Cleanup the worker socket. This will also cause the Python worker to exit.
@@ -137,7 +137,7 @@ private[spark] class PythonRDD(
}
} catch {
- case e: Exception if context.interrupted =>
+ case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException
@@ -176,7 +176,7 @@ private[spark] class PythonRDD(
/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
def shutdownOnTaskCompletion() {
- assert(context.completed)
+ assert(context.isCompleted)
this.interrupt()
}
@@ -209,7 +209,7 @@ private[spark] class PythonRDD(
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
dataOut.flush()
} catch {
- case e: Exception if context.completed || context.interrupted =>
+ case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
case e: Exception =>
@@ -235,10 +235,10 @@ private[spark] class PythonRDD(
override def run() {
// Kill the worker if it is interrupted, checking until task completion.
// TODO: This has a race condition if interruption occurs, as completed may still become true.
- while (!context.interrupted && !context.completed) {
+ while (!context.isInterrupted && !context.isCompleted) {
Thread.sleep(2000)
}
- if (!context.completed) {
+ if (!context.isCompleted) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 34c51b8330..20938781ac 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -141,7 +141,7 @@ private[spark] object CheckpointRDD extends Logging {
val deserializeStream = serializer.deserializeStream(fileInputStream)
// Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback(() => deserializeStream.close())
+ context.addTaskCompletionListener(context => deserializeStream.close())
deserializeStream.asIterator.asInstanceOf[Iterator[T]]
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 8d92ea01d9..c8623314c9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -197,7 +197,7 @@ class HadoopRDD[K, V](
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback{ () => closeIfNeeded() }
+ context.addTaskCompletionListener{ context => closeIfNeeded() }
val key: K = reader.createKey()
val value: V = reader.createValue()
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 8947e66f45..0e38f224ac 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -68,7 +68,7 @@ class JdbcRDD[T: ClassTag](
}
override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
- context.addOnCompleteCallback{ () => closeIfNeeded() }
+ context.addTaskCompletionListener{ context => closeIfNeeded() }
val part = thePart.asInstanceOf[JdbcPartition]
val conn = getConnection()
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 7dfec9a18e..58f707b9b4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -129,7 +129,7 @@ class NewHadoopRDD[K, V](
context.taskMetrics.inputMetrics = Some(inputMetrics)
// Register an on-task-completion callback to close the input stream.
- context.addOnCompleteCallback(() => close())
+ context.addTaskCompletionListener(context => close())
var havePair = false
var finished = false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 36bbaaa3f1..b86cfbfa48 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -634,7 +634,7 @@ class DAGScheduler(
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
- taskContext.executeOnCompleteCallbacks()
+ taskContext.markTaskCompleted()
}
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index d09fd7aa57..2ccbd8edeb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -61,7 +61,7 @@ private[spark] class ResultTask[T, U](
try {
func(context, rdd.iterator(partition, context))
} finally {
- context.executeOnCompleteCallbacks()
+ context.markTaskCompleted()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 11255c0746..381eff2147 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -74,7 +74,7 @@ private[spark] class ShuffleMapTask(
}
throw e
} finally {
- context.executeOnCompleteCallbacks()
+ context.markTaskCompleted()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index cbe0bc0bcb..6aa0cca068 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -87,7 +87,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
def kill(interruptThread: Boolean) {
_killed = true
if (context != null) {
- context.interrupted = true
+ context.markInterrupted()
}
if (interruptThread && taskThread != null) {
taskThread.interrupt()
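
markInterrupted() above sets the flag that InterruptibleIterator.isInterrupted (first hunk of this diff) observes. A hypothetical long-running loop can cooperate with kill() the same way; the function below is illustrative only, not part of the change:

import org.apache.spark.TaskContext

// Hypothetical cooperative-cancellation loop: stop consuming input once
// Task.kill() has called markInterrupted() on this task's context.
def sumUntilKilled(context: TaskContext, values: Iterator[Long]): Long = {
  var total = 0L
  while (values.hasNext && !context.isInterrupted) {
    total += values.next()
  }
  total
}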
diff --git a/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
new file mode 100644
index 0000000000..c1b8bf052c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/TaskCompletionListener.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.EventListener
+
+import org.apache.spark.TaskContext
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Listener providing a callback function to invoke when a task's execution completes.
+ */
+@DeveloperApi
+trait TaskCompletionListener extends EventListener {
+ def onTaskCompletion(context: TaskContext)
+}
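
Since the trait has a single abstract method and extends java.util.EventListener, standalone implementations stay short in either language. A hypothetical Scala listener, analogous to the Java test class added below:

import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Hypothetical listener that records which partition just finished and
// whether the task was killed before it completed.
class LoggingTaskCompletionListener extends TaskCompletionListener {
  override def onTaskCompletion(context: TaskContext): Unit = {
    println(s"partition ${context.partitionId} done, " +
      s"interrupted = ${context.isInterrupted}")
  }
}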
diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
new file mode 100644
index 0000000000..af34cdb03e
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util;
+
+import org.apache.spark.TaskContext;
+
+
+/**
+ * A simple implementation of TaskCompletionListener that makes sure TaskCompletionListener and
+ * TaskContext are Java friendly.
+ */
+public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
+
+ @Override
+ public void onTaskCompletion(TaskContext context) {
+ context.isCompleted();
+ context.isInterrupted();
+ context.stageId();
+ context.partitionId();
+ context.runningLocally();
+ context.taskMetrics();
+ context.addTaskCompletionListener(this);
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 270f7e6610..db2ad829a4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -32,7 +32,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
val rdd = new RDD[String](sc, List()) {
override def getPartitions = Array[Partition](StubPartition(0))
override def compute(split: Partition, context: TaskContext) = {
- context.addOnCompleteCallback(() => TaskContextSuite.completed = true)
+ context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
sys.error("failed")
}
}