author     Prashant Sharma <prashant.s@imaginea.com>    2014-09-26 21:29:54 -0700
committer  Reynold Xin <rxin@apache.org>                2014-09-26 21:29:54 -0700
commit     5e34855cf04145cc3b7bae996c2a6e668f144a11 (patch)
tree       654cd526ff24b2cc6465198c45acd0c78e7cd950 /core
parent     f872e4fb80b8429800daa9c44c0cac620c1ff303 (diff)
[SPARK-3543] Write TaskContext in Java and expose it through a static accessor.
Author: Prashant Sharma <prashant.s@imaginea.com>
Author: Shashank Sharma <shashank21j@gmail.com>

Closes #2425 from ScrapCodes/SPARK-3543/withTaskContext and squashes the following commits:

8ae414c [Shashank Sharma] CR
ee8bd00 [Prashant Sharma] Added internal API in docs comments.
ddb8cbe [Prashant Sharma] Moved setting the thread local to where TaskContext is instantiated.
a7d5e23 [Prashant Sharma] Added doc comments.
edf945e [Prashant Sharma] Code review git add -A
f716fd1 [Prashant Sharma] introduced thread local for getting the task context.
333c7d6 [Prashant Sharma] Translated Task context from scala to java.
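In practical terms, the static accessor means task-side code can reach its TaskContext from anywhere inside the task body instead of threading it through mapPartitionsWithContext (deprecated further down in this patch). A minimal sketch of the new usage, assuming an existing SparkContext named sc; the RDD and values are illustrative, not part of this change:

    import org.apache.spark.TaskContext

    val rdd = sc.parallelize(1 to 10, 2)   // `sc` is an assumed, pre-existing SparkContext

    // Before: the context had to be passed in explicitly, e.g.
    //   rdd.mapPartitionsWithContext((ctx, iter) => iter.map(x => (ctx.partitionId, x)))
    // After: any closure running inside the task can ask for it via the thread local.
    val tagged = rdd.map { x =>
      val ctx = TaskContext.get()          // installed by the scheduler before the task runs
      (ctx.getPartitionId(), x)            // non-deprecated getter added by this patch
    }

Since the thread local is only set by the scheduler around a running task, TaskContext.get() returns null on the driver or on any other thread, so the call is only meaningful inside task-side closures.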
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/java/org/apache/spark/TaskContext.java               | 274
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala             | 126
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                 |   1
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala  |   4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala          |   6
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java              |   2
-rw-r--r--  core/src/test/scala/org/apache/spark/CacheManagerSuite.scala       |   2
7 files changed, 284 insertions(+), 131 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
new file mode 100644
index 0000000000..09b8ce02bd
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/TaskContext.java
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import scala.Function0;
+import scala.Function1;
+import scala.Unit;
+import scala.collection.JavaConversions;
+
+import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.util.TaskCompletionListener;
+import org.apache.spark.util.TaskCompletionListenerException;
+
+/**
+* :: DeveloperApi ::
+* Contextual information about a task which can be read or mutated during execution.
+*/
+@DeveloperApi
+public class TaskContext implements Serializable {
+
+ private int stageId;
+ private int partitionId;
+ private long attemptId;
+ private boolean runningLocally;
+ private TaskMetrics taskMetrics;
+
+ /**
+ * :: DeveloperApi ::
+ * Contextual information about a task which can be read or mutated during execution.
+ *
+ * @param stageId stage id
+ * @param partitionId index of the partition
+ * @param attemptId the number of attempts to execute this task
+ * @param runningLocally whether the task is running locally in the driver JVM
+ * @param taskMetrics performance metrics of the task
+ */
+ @DeveloperApi
+ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally,
+ TaskMetrics taskMetrics) {
+ this.attemptId = attemptId;
+ this.partitionId = partitionId;
+ this.runningLocally = runningLocally;
+ this.stageId = stageId;
+ this.taskMetrics = taskMetrics;
+ }
+
+
+ /**
+ * :: DeveloperApi ::
+ * Contextual information about a task which can be read or mutated during execution.
+ *
+ * @param stageId stage id
+ * @param partitionId index of the partition
+ * @param attemptId the number of attempts to execute this task
+ * @param runningLocally whether the task is running locally in the driver JVM
+ */
+ @DeveloperApi
+ public TaskContext(Integer stageId, Integer partitionId, Long attemptId,
+ Boolean runningLocally) {
+ this.attemptId = attemptId;
+ this.partitionId = partitionId;
+ this.runningLocally = runningLocally;
+ this.stageId = stageId;
+ this.taskMetrics = TaskMetrics.empty();
+ }
+
+
+ /**
+ * :: DeveloperApi ::
+ * Contextual information about a task which can be read or mutated during execution.
+ *
+ * @param stageId stage id
+ * @param partitionId index of the partition
+ * @param attemptId the number of attempts to execute this task
+ */
+ @DeveloperApi
+ public TaskContext(Integer stageId, Integer partitionId, Long attemptId) {
+ this.attemptId = attemptId;
+ this.partitionId = partitionId;
+ this.runningLocally = false;
+ this.stageId = stageId;
+ this.taskMetrics = TaskMetrics.empty();
+ }
+
+ private static ThreadLocal<TaskContext> taskContext =
+ new ThreadLocal<TaskContext>();
+
+ /**
+ * :: Internal API ::
+ * This is a Spark-internal API, not intended to be called from user programs.
+ */
+ public static void setTaskContext(TaskContext tc) {
+ taskContext.set(tc);
+ }
+
+ public static TaskContext get() {
+ return taskContext.get();
+ }
+
+ /**
+ * :: Internal API ::
+ */
+ public static void remove() {
+ taskContext.remove();
+ }
+
+ // List of callback functions to execute when the task completes.
+ private transient List<TaskCompletionListener> onCompleteCallbacks =
+ new ArrayList<TaskCompletionListener>();
+
+ // Whether the corresponding task has been killed.
+ private volatile Boolean interrupted = false;
+
+ // Whether the task has completed.
+ private volatile Boolean completed = false;
+
+ /**
+ * Checks whether the task has completed.
+ */
+ public Boolean isCompleted() {
+ return completed;
+ }
+
+ /**
+ * Checks whether the task has been killed.
+ */
+ public Boolean isInterrupted() {
+ return interrupted;
+ }
+
+ /**
+ * Add a (Java friendly) listener to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ * <p/>
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
+ onCompleteCallbacks.add(listener);
+ return this;
+ }
+
+ /**
+ * Add a listener in the form of a Scala closure to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ * <p/>
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ public TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f) {
+ onCompleteCallbacks.add(new TaskCompletionListener() {
+ @Override
+ public void onTaskCompletion(TaskContext context) {
+ f.apply(context);
+ }
+ });
+ return this;
+ }
+
+ /**
+ * Add a callback function to be executed on task completion. An example use
+ * is for HadoopRDD to register a callback to close the input stream.
+ * Will be called in any situation - success, failure, or cancellation.
+ *
+ * Deprecated: use addTaskCompletionListener
+ *
+ * @param f Callback function.
+ */
+ @Deprecated
+ public void addOnCompleteCallback(final Function0<Unit> f) {
+ onCompleteCallbacks.add(new TaskCompletionListener() {
+ @Override
+ public void onTaskCompletion(TaskContext context) {
+ f.apply();
+ }
+ });
+ }
+
+ /**
+ * ::Internal API::
+ * Marks the task as completed and triggers the listeners.
+ */
+ public void markTaskCompleted() throws TaskCompletionListenerException {
+ completed = true;
+ List<String> errorMsgs = new ArrayList<String>(2);
+ // Process complete callbacks in the reverse order of registration
+ List<TaskCompletionListener> revlist =
+ new ArrayList<TaskCompletionListener>(onCompleteCallbacks);
+ Collections.reverse(revlist);
+ for (TaskCompletionListener tcl: revlist) {
+ try {
+ tcl.onTaskCompletion(this);
+ } catch (Throwable e) {
+ errorMsgs.add(e.getMessage());
+ }
+ }
+
+ if (!errorMsgs.isEmpty()) {
+ throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs));
+ }
+ }
+
+ /**
+ * ::Internal API::
+ * Marks the task for interruption, i.e. cancellation.
+ */
+ public void markInterrupted() {
+ interrupted = true;
+ }
+
+ @Deprecated
+ /** Deprecated: use getStageId() */
+ public int stageId() {
+ return stageId;
+ }
+
+ @Deprecated
+ /** Deprecated: use getPartitionId() */
+ public int partitionId() {
+ return partitionId;
+ }
+
+ @Deprecated
+ /** Deprecated: use getAttemptId() */
+ public long attemptId() {
+ return attemptId;
+ }
+
+ @Deprecated
+ /** Deprecated: use getRunningLocally() */
+ public boolean runningLocally() {
+ return runningLocally;
+ }
+
+ public boolean getRunningLocally() {
+ return runningLocally;
+ }
+
+ public int getStageId() {
+ return stageId;
+ }
+
+ public int getPartitionId() {
+ return partitionId;
+ }
+
+ public long getAttemptId() {
+ return attemptId;
+ }
+
+ /** ::Internal API:: */
+ public TaskMetrics taskMetrics() {
+ return taskMetrics;
+ }
+}
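The Javadoc above repeatedly cites resource cleanup (HadoopRDD closing its input stream) as the motivating use of the listener API. A hedged sketch of that pattern from task-side code follows; the stream-opening helper is a placeholder for this write-up, and only TaskContext.get() and addTaskCompletionListener come from the class shown above:

    import java.io.InputStream
    import org.apache.spark.TaskContext
    import org.apache.spark.util.TaskCompletionListener

    // Illustrative only: open a per-partition resource and guarantee it is closed
    // whether the task succeeds, fails, or is cancelled.
    def readPartition(openStream: () => InputStream): Iterator[Int] = {
      val stream = openStream()
      TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
        override def onTaskCompletion(context: TaskContext): Unit = stream.close()
      })
      Iterator.continually(stream.read()).takeWhile(_ != -1)
    }

Listeners registered this way run in reverse registration order when markTaskCompleted() fires, as the implementation above shows.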
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
deleted file mode 100644
index 51b3e4d5e0..0000000000
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
-
-
-/**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- * @param taskMetrics performance metrics of the task
- */
-@DeveloperApi
-class TaskContext(
- val stageId: Int,
- val partitionId: Int,
- val attemptId: Long,
- val runningLocally: Boolean = false,
- private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
- extends Serializable with Logging {
-
- @deprecated("use partitionId", "0.8.1")
- def splitId = partitionId
-
- // List of callback functions to execute when the task completes.
- @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
-
- // Whether the corresponding task has been killed.
- @volatile private var interrupted: Boolean = false
-
- // Whether the task has completed.
- @volatile private var completed: Boolean = false
-
- /** Checks whether the task has completed. */
- def isCompleted: Boolean = completed
-
- /** Checks whether the task has been killed. */
- def isInterrupted: Boolean = interrupted
-
- // TODO: Also track whether the task has completed successfully or with exception.
-
- /**
- * Add a (Java friendly) listener to be executed on task completion.
- * This will be called in all situation - success, failure, or cancellation.
- *
- * An example use is for HadoopRDD to register a callback to close the input stream.
- */
- def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
- onCompleteCallbacks += listener
- this
- }
-
- /**
- * Add a listener in the form of a Scala closure to be executed on task completion.
- * This will be called in all situation - success, failure, or cancellation.
- *
- * An example use is for HadoopRDD to register a callback to close the input stream.
- */
- def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
- onCompleteCallbacks += new TaskCompletionListener {
- override def onTaskCompletion(context: TaskContext): Unit = f(context)
- }
- this
- }
-
- /**
- * Add a callback function to be executed on task completion. An example use
- * is for HadoopRDD to register a callback to close the input stream.
- * Will be called in any situation - success, failure, or cancellation.
- * @param f Callback function.
- */
- @deprecated("use addTaskCompletionListener", "1.1.0")
- def addOnCompleteCallback(f: () => Unit) {
- onCompleteCallbacks += new TaskCompletionListener {
- override def onTaskCompletion(context: TaskContext): Unit = f()
- }
- }
-
- /** Marks the task as completed and triggers the listeners. */
- private[spark] def markTaskCompleted(): Unit = {
- completed = true
- val errorMsgs = new ArrayBuffer[String](2)
- // Process complete callbacks in the reverse order of registration
- onCompleteCallbacks.reverse.foreach { listener =>
- try {
- listener.onTaskCompletion(this)
- } catch {
- case e: Throwable =>
- errorMsgs += e.getMessage
- logError("Error in TaskCompletionListener", e)
- }
- }
- if (errorMsgs.nonEmpty) {
- throw new TaskCompletionListenerException(errorMsgs)
- }
- }
-
- /** Marks the task for interruption, i.e. cancellation. */
- private[spark] def markInterrupted(): Unit = {
- interrupted = true
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0e90caa5c9..ba712c9d77 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -619,6 +619,7 @@ abstract class RDD[T: ClassTag](
* should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
@DeveloperApi
+ @deprecated("use TaskContext.get", "1.2.0")
def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b2774dfc47..32cf29ed14 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -634,12 +634,14 @@ class DAGScheduler(
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
val taskContext =
- new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
+ new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
+ TaskContext.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.markTaskCompleted()
+ TaskContext.remove()
}
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 6aa0cca068..bf73f6f7bd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,8 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
final def run(attemptId: Long): T = {
- context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+ context = new TaskContext(stageId, partitionId, attemptId, false)
+ TaskContext.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
if (_killed) {
@@ -92,7 +93,8 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
if (interruptThread && taskThread != null) {
taskThread.interrupt()
}
- }
+ TaskContext.remove()
+ }
}
/**
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index b8c23d524e..4a07843544 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -776,7 +776,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics());
+ TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index 90dcadcffd..d735010d7c 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, runningLocally = true)
+ val context = new TaskContext(0, 0, 0, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}