author     Prashant Sharma <prashant.s@imaginea.com>    2014-09-26 21:29:54 -0700
committer  Reynold Xin <rxin@apache.org>                2014-09-26 21:29:54 -0700
commit     5e34855cf04145cc3b7bae996c2a6e668f144a11 (patch)
tree       654cd526ff24b2cc6465198c45acd0c78e7cd950 /core/src/main/scala
parent     f872e4fb80b8429800daa9c44c0cac620c1ff303 (diff)
[SPARK-3543] Write TaskContext in Java and expose it through a static accessor.
Author: Prashant Sharma <prashant.s@imaginea.com>
Author: Shashank Sharma <shashank21j@gmail.com>

Closes #2425 from ScrapCodes/SPARK-3543/withTaskContext and squashes the following commits:

8ae414c [Shashank Sharma] CR
ee8bd00 [Prashant Sharma] Added internal API in docs comments.
ddb8cbe [Prashant Sharma] Moved setting the thread local to where TaskContext is instantiated.
a7d5e23 [Prashant Sharma] Added doc comments.
edf945e [Prashant Sharma] Code review git add -A
f716fd1 [Prashant Sharma] introduced thread local for getting the task context.
333c7d6 [Prashant Sharma] Translated Task context from scala to java.
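The patch deletes the Scala TaskContext shown below and re-implements it in Java, publishing the per-task instance through a thread local so code running inside a task can reach it statically. A minimal Scala sketch of that accessor pattern follows; the object name TaskContextSketch is hypothetical and the real implementation is the new Java TaskContext class, but the three entry points (setTaskContext, get, remove) match the calls visible in the hunks below.

    package org.apache.spark

    // Sketch only: a thread-local slot holding the context of the task that is
    // currently running on this executor thread.
    object TaskContextSketch {
      private val taskContext = new ThreadLocal[TaskContext]

      // Installed by the scheduler before user code runs (see Task.run and the
      // DAGScheduler local-execution hunk below). Internal API.
      private[spark] def setTaskContext(tc: TaskContext): Unit = taskContext.set(tc)

      // Static accessor: the context of the task on the current thread.
      def get(): TaskContext = taskContext.get()

      // Cleared once the task finishes so a pooled thread keeps no stale context.
      private[spark] def remove(): Unit = taskContext.remove()
    }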
Diffstat (limited to 'core/src/main/scala')
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContext.scala              126
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                    1
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala     4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala             6
4 files changed, 8 insertions(+), 129 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
deleted file mode 100644
index 51b3e4d5e0..0000000000
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}
-
-
-/**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- * @param taskMetrics performance metrics of the task
- */
-@DeveloperApi
-class TaskContext(
- val stageId: Int,
- val partitionId: Int,
- val attemptId: Long,
- val runningLocally: Boolean = false,
- private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty)
- extends Serializable with Logging {
-
- @deprecated("use partitionId", "0.8.1")
- def splitId = partitionId
-
- // List of callback functions to execute when the task completes.
- @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
-
- // Whether the corresponding task has been killed.
- @volatile private var interrupted: Boolean = false
-
- // Whether the task has completed.
- @volatile private var completed: Boolean = false
-
- /** Checks whether the task has completed. */
- def isCompleted: Boolean = completed
-
- /** Checks whether the task has been killed. */
- def isInterrupted: Boolean = interrupted
-
- // TODO: Also track whether the task has completed successfully or with exception.
-
- /**
- * Add a (Java friendly) listener to be executed on task completion.
- * This will be called in all situation - success, failure, or cancellation.
- *
- * An example use is for HadoopRDD to register a callback to close the input stream.
- */
- def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
- onCompleteCallbacks += listener
- this
- }
-
- /**
- * Add a listener in the form of a Scala closure to be executed on task completion.
- * This will be called in all situation - success, failure, or cancellation.
- *
- * An example use is for HadoopRDD to register a callback to close the input stream.
- */
- def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
- onCompleteCallbacks += new TaskCompletionListener {
- override def onTaskCompletion(context: TaskContext): Unit = f(context)
- }
- this
- }
-
- /**
- * Add a callback function to be executed on task completion. An example use
- * is for HadoopRDD to register a callback to close the input stream.
- * Will be called in any situation - success, failure, or cancellation.
- * @param f Callback function.
- */
- @deprecated("use addTaskCompletionListener", "1.1.0")
- def addOnCompleteCallback(f: () => Unit) {
- onCompleteCallbacks += new TaskCompletionListener {
- override def onTaskCompletion(context: TaskContext): Unit = f()
- }
- }
-
- /** Marks the task as completed and triggers the listeners. */
- private[spark] def markTaskCompleted(): Unit = {
- completed = true
- val errorMsgs = new ArrayBuffer[String](2)
- // Process complete callbacks in the reverse order of registration
- onCompleteCallbacks.reverse.foreach { listener =>
- try {
- listener.onTaskCompletion(this)
- } catch {
- case e: Throwable =>
- errorMsgs += e.getMessage
- logError("Error in TaskCompletionListener", e)
- }
- }
- if (errorMsgs.nonEmpty) {
- throw new TaskCompletionListenerException(errorMsgs)
- }
- }
-
- /** Marks the task for interruption, i.e. cancellation. */
- private[spark] def markInterrupted(): Unit = {
- interrupted = true
- }
-}
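The completion-listener API removed here is carried over unchanged into the Java rewrite. As a hedged usage sketch of the pattern the doc comments above describe (HadoopRDD registering a callback to close its input stream), with openStream standing in for whatever per-task resource is acquired:

    import java.io.InputStream
    import org.apache.spark.TaskContext
    import org.apache.spark.util.TaskCompletionListener

    // Hypothetical caller: acquire a per-task resource and guarantee cleanup.
    def readPartition(openStream: () => InputStream): Unit = {
      val in = openStream()
      TaskContext.get().addTaskCompletionListener(new TaskCompletionListener {
        // Runs on success, failure, or cancellation.
        override def onTaskCompletion(context: TaskContext): Unit = in.close()
      })
      // ... read records from `in` while computing the partition ...
    }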
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0e90caa5c9..ba712c9d77 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -619,6 +619,7 @@ abstract class RDD[T: ClassTag](
* should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
*/
@DeveloperApi
+ @deprecated("use TaskContext.get", "1.2.0")
def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
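A migration sketch for the deprecation above: code that only used mapPartitionsWithContext to reach the TaskContext can switch to plain mapPartitions and read the context from the new static accessor (the RDD[Int] element type here is just for illustration).

    import org.apache.spark.TaskContext
    import org.apache.spark.rdd.RDD

    // Tag each record with the id of the partition that produced it.
    def tagWithPartition(rdd: RDD[Int]): RDD[(Int, Int)] =
      rdd.mapPartitions { iter =>
        val ctx = TaskContext.get()   // replaces the explicit (context, iterator) pair
        iter.map(record => (ctx.partitionId, record))
      }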
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b2774dfc47..32cf29ed14 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -634,12 +634,14 @@ class DAGScheduler(
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
val taskContext =
- new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
+ new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
+ TaskContext.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.markTaskCompleted()
+ TaskContext.remove()
}
} catch {
case e: Exception =>
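The hunk above establishes the lifecycle the rest of the patch repeats: install the context on the current thread, run the user function, then mark completion and clear the thread local in a finally block so TaskContext.get() never returns a stale instance. A generalized sketch of that pattern follows; runWithContext is an illustrative helper, not part of the patch, and it assumes spark-internal access to setTaskContext and markTaskCompleted.

    def runWithContext[T](ctx: TaskContext)(body: => T): T = {
      TaskContext.setTaskContext(ctx)   // internal API: makes TaskContext.get() work inside body
      try {
        body
      } finally {
        ctx.markTaskCompleted()         // fires the registered completion listeners
        TaskContext.remove()            // internal API: clear the thread-local slot
      }
    }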
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 6aa0cca068..bf73f6f7bd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,8 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
final def run(attemptId: Long): T = {
- context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
+ context = new TaskContext(stageId, partitionId, attemptId, false)
+ TaskContext.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
if (_killed) {
@@ -92,7 +93,8 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
if (interruptThread && taskThread != null) {
taskThread.interrupt()
}
- }
+ TaskContext.remove()
+ }
}
/**