diff options
Diffstat (limited to 'core/src/main/scala')
4 files changed, 8 insertions, 10 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 32cf29ed14..70c235dfff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -641,7 +641,7 @@ class DAGScheduler( job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContext.remove() + TaskContext.unset() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 2ccbd8edeb..4a9ff918af 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -58,11 +58,7 @@ private[spark] class ResultTask[T, U]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) metrics = Some(context.taskMetrics) - try { - func(context, rdd.iterator(partition, context)) - } finally { - context.markTaskCompleted() - } + func(context, rdd.iterator(partition, context)) } // This is only callable on the driver side. diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index a98ee11825..79709089c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -78,8 +78,6 @@ private[spark] class ShuffleMapTask( log.debug("Could not stop writer", e) } throw e - } finally { - context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index bf73f6f7bd..c6e47c84a0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -52,7 +52,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (_killed) { kill(interruptThread = false) } - runTask(context) + try { + runTask(context) + } finally { + context.markTaskCompleted() + TaskContext.unset() + } } def runTask(context: TaskContext): T @@ -93,7 +98,6 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - TaskContext.remove() } } |