-rw-r--r--  core/src/main/java/org/apache/spark/TaskContext.java  225
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextHelper.scala  29
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextImpl.scala  91
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala  2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala  8
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala  6
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala  10
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java  2
-rw-r--r--  core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java  4
-rw-r--r--  core/src/test/scala/org/apache/spark/CacheManagerSuite.scala  8
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala  2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala  2
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala  8
-rw-r--r--  project/MimaBuild.scala  2
-rw-r--r--  project/MimaExcludes.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala  4
16 files changed, 186 insertions, 223 deletions
diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
index 4e6d708af0..2d998d4c7a 100644
--- a/core/src/main/java/org/apache/spark/TaskContext.java
+++ b/core/src/main/java/org/apache/spark/TaskContext.java
@@ -18,131 +18,55 @@
package org.apache.spark;
import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
import scala.Function0;
import scala.Function1;
import scala.Unit;
-import scala.collection.JavaConversions;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.util.TaskCompletionListener;
-import org.apache.spark.util.TaskCompletionListenerException;
/**
-* :: DeveloperApi ::
-* Contextual information about a task which can be read or mutated during execution.
-*/
-@DeveloperApi
-public class TaskContext implements Serializable {
-
- private int stageId;
- private int partitionId;
- private long attemptId;
- private boolean runningLocally;
- private TaskMetrics taskMetrics;
-
- /**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- * @param taskMetrics performance metrics of the task
- */
- @DeveloperApi
- public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally,
- TaskMetrics taskMetrics) {
- this.attemptId = attemptId;
- this.partitionId = partitionId;
- this.runningLocally = runningLocally;
- this.stageId = stageId;
- this.taskMetrics = taskMetrics;
- }
-
- /**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
- * @param runningLocally whether the task is running locally in the driver JVM
- */
- @DeveloperApi
- public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) {
- this.attemptId = attemptId;
- this.partitionId = partitionId;
- this.runningLocally = runningLocally;
- this.stageId = stageId;
- this.taskMetrics = TaskMetrics.empty();
- }
-
+ * Contextual information about a task which can be read or mutated during
+ * execution. To access the TaskContext for a running task use
+ * TaskContext.get().
+ */
+public abstract class TaskContext implements Serializable {
/**
- * :: DeveloperApi ::
- * Contextual information about a task which can be read or mutated during execution.
- *
- * @param stageId stage id
- * @param partitionId index of the partition
- * @param attemptId the number of attempts to execute this task
+ * Return the currently active TaskContext. This can be called inside of
+ * user functions to access contextual information about running tasks.
*/
- @DeveloperApi
- public TaskContext(int stageId, int partitionId, long attemptId) {
- this.attemptId = attemptId;
- this.partitionId = partitionId;
- this.runningLocally = false;
- this.stageId = stageId;
- this.taskMetrics = TaskMetrics.empty();
+ public static TaskContext get() {
+ return taskContext.get();
}
private static ThreadLocal<TaskContext> taskContext =
new ThreadLocal<TaskContext>();
- /**
- * :: Internal API ::
- * This is spark internal API, not intended to be called from user programs.
- */
- public static void setTaskContext(TaskContext tc) {
+ static void setTaskContext(TaskContext tc) {
taskContext.set(tc);
}
- public static TaskContext get() {
- return taskContext.get();
- }
-
- /** :: Internal API :: */
- public static void unset() {
+ static void unset() {
taskContext.remove();
}
- // List of callback functions to execute when the task completes.
- private transient List<TaskCompletionListener> onCompleteCallbacks =
- new ArrayList<TaskCompletionListener>();
-
- // Whether the corresponding task has been killed.
- private volatile boolean interrupted = false;
-
- // Whether the task has completed.
- private volatile boolean completed = false;
-
/**
- * Checks whether the task has completed.
+ * Whether the task has completed.
*/
- public boolean isCompleted() {
- return completed;
- }
+ public abstract boolean isCompleted();
/**
- * Checks whether the task has been killed.
+ * Whether the task has been killed.
*/
- public boolean isInterrupted() {
- return interrupted;
- }
+ public abstract boolean isInterrupted();
+
+ /** @deprecated: use isRunningLocally() */
+ @Deprecated
+ public abstract boolean runningLocally();
+
+ public abstract boolean isRunningLocally();
/**
* Add a (Java friendly) listener to be executed on task completion.
@@ -150,10 +74,7 @@ public class TaskContext implements Serializable {
* <p/>
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
- public TaskContext addTaskCompletionListener(TaskCompletionListener listener) {
- onCompleteCallbacks.add(listener);
- return this;
- }
+ public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener);
/**
* Add a listener in the form of a Scala closure to be executed on task completion.
@@ -161,109 +82,27 @@ public class TaskContext implements Serializable {
* <p/>
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
- public TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f) {
- onCompleteCallbacks.add(new TaskCompletionListener() {
- @Override
- public void onTaskCompletion(TaskContext context) {
- f.apply(context);
- }
- });
- return this;
- }
+ public abstract TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f);
/**
* Add a callback function to be executed on task completion. An example use
* is for HadoopRDD to register a callback to close the input stream.
* Will be called in any situation - success, failure, or cancellation.
*
- * Deprecated: use addTaskCompletionListener
- *
+ * @deprecated: use addTaskCompletionListener
+ *
* @param f Callback function.
*/
@Deprecated
- public void addOnCompleteCallback(final Function0<Unit> f) {
- onCompleteCallbacks.add(new TaskCompletionListener() {
- @Override
- public void onTaskCompletion(TaskContext context) {
- f.apply();
- }
- });
- }
-
- /**
- * ::Internal API::
- * Marks the task as completed and triggers the listeners.
- */
- public void markTaskCompleted() throws TaskCompletionListenerException {
- completed = true;
- List<String> errorMsgs = new ArrayList<String>(2);
- // Process complete callbacks in the reverse order of registration
- List<TaskCompletionListener> revlist =
- new ArrayList<TaskCompletionListener>(onCompleteCallbacks);
- Collections.reverse(revlist);
- for (TaskCompletionListener tcl: revlist) {
- try {
- tcl.onTaskCompletion(this);
- } catch (Throwable e) {
- errorMsgs.add(e.getMessage());
- }
- }
-
- if (!errorMsgs.isEmpty()) {
- throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs));
- }
- }
-
- /**
- * ::Internal API::
- * Marks the task for interruption, i.e. cancellation.
- */
- public void markInterrupted() {
- interrupted = true;
- }
-
- @Deprecated
- /** Deprecated: use getStageId() */
- public int stageId() {
- return stageId;
- }
-
- @Deprecated
- /** Deprecated: use getPartitionId() */
- public int partitionId() {
- return partitionId;
- }
-
- @Deprecated
- /** Deprecated: use getAttemptId() */
- public long attemptId() {
- return attemptId;
- }
-
- @Deprecated
- /** Deprecated: use isRunningLocally() */
- public boolean runningLocally() {
- return runningLocally;
- }
-
- public boolean isRunningLocally() {
- return runningLocally;
- }
+ public abstract void addOnCompleteCallback(final Function0<Unit> f);
- public int getStageId() {
- return stageId;
- }
+ public abstract int stageId();
- public int getPartitionId() {
- return partitionId;
- }
+ public abstract int partitionId();
- public long getAttemptId() {
- return attemptId;
- }
+ public abstract long attemptId();
- /** ::Internal API:: */
- public TaskMetrics taskMetrics() {
- return taskMetrics;
- }
+ /** ::DeveloperApi:: */
+ @DeveloperApi
+ public abstract TaskMetrics taskMetrics();
}
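Note on the new user-facing pattern: TaskContext instances are now created only by Spark internals, and user code reaches the active context through the static TaskContext.get() accessor documented above. A minimal sketch of that usage (the application setup, partition counts, and per-element work are assumptions, not part of this commit):

    import org.apache.spark.{SparkConf, SparkContext, TaskContext}

    object TaskContextGetExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("tc-example").setMaster("local[2]"))
        sc.parallelize(1 to 100, numSlices = 4).foreachPartition { iter =>
          // The static accessor replaces constructing a TaskContext by hand.
          val ctx = TaskContext.get()
          ctx.addTaskCompletionListener((_: TaskContext) =>
            println(s"stage ${ctx.stageId()} partition ${ctx.partitionId()} finished"))
          iter.foreach(_ => ())  // placeholder work
        }
        sc.stop()
      }
    }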
diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala
new file mode 100644
index 0000000000..4636c4600a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * This class exists to restrict the visibility of TaskContext setters.
+ */
+private [spark] object TaskContextHelper {
+
+ def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc)
+
+ def unset(): Unit = TaskContext.unset()
+
+}
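The helper's role is clearer next to its call sites: Spark-internal code (see the DAGScheduler and Task changes below) installs a TaskContextImpl on the current thread, runs the task body, and always tears the context down again. A condensed sketch of that pattern, which compiles only inside Spark's own private[spark] scope because TaskContextImpl, markTaskCompleted and the helper are all internal:

    import org.apache.spark.{TaskContextHelper, TaskContextImpl}

    val context = new TaskContextImpl(stageId = 0, partitionId = 0, attemptId = 0L)
    TaskContextHelper.setTaskContext(context)  // makes TaskContext.get() return this context
    try {
      // runTask(context) -- the task body would go here
    } finally {
      context.markTaskCompleted()              // fire registered completion listeners
      TaskContextHelper.unset()                // clear the thread-local again
    }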
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
new file mode 100644
index 0000000000..afd2b85d33
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
+
+import scala.collection.mutable.ArrayBuffer
+
+private[spark] class TaskContextImpl(val stageId: Int,
+ val partitionId: Int,
+ val attemptId: Long,
+ val runningLocally: Boolean = false,
+ val taskMetrics: TaskMetrics = TaskMetrics.empty)
+ extends TaskContext
+ with Logging {
+
+ // List of callback functions to execute when the task completes.
+ @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
+
+ // Whether the corresponding task has been killed.
+ @volatile private var interrupted: Boolean = false
+
+ // Whether the task has completed.
+ @volatile private var completed: Boolean = false
+
+ override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
+ onCompleteCallbacks += listener
+ this
+ }
+
+ override def addTaskCompletionListener(f: TaskContext => Unit): this.type = {
+ onCompleteCallbacks += new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f(context)
+ }
+ this
+ }
+
+ @deprecated("use addTaskCompletionListener", "1.1.0")
+ override def addOnCompleteCallback(f: () => Unit) {
+ onCompleteCallbacks += new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = f()
+ }
+ }
+
+ /** Marks the task as completed and triggers the listeners. */
+ private[spark] def markTaskCompleted(): Unit = {
+ completed = true
+ val errorMsgs = new ArrayBuffer[String](2)
+ // Process complete callbacks in the reverse order of registration
+ onCompleteCallbacks.reverse.foreach { listener =>
+ try {
+ listener.onTaskCompletion(this)
+ } catch {
+ case e: Throwable =>
+ errorMsgs += e.getMessage
+ logError("Error in TaskCompletionListener", e)
+ }
+ }
+ if (errorMsgs.nonEmpty) {
+ throw new TaskCompletionListenerException(errorMsgs)
+ }
+ }
+
+ /** Marks the task for interruption, i.e. cancellation. */
+ private[spark] def markInterrupted(): Unit = {
+ interrupted = true
+ }
+
+ override def isCompleted: Boolean = completed
+
+ override def isRunningLocally: Boolean = runningLocally
+
+ override def isInterrupted: Boolean = interrupted
+}
+
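Two behaviours of markTaskCompleted above are worth spelling out: listeners run in reverse registration order, and a failing listener does not abort the rest, its message is collected and rethrown afterwards as a TaskCompletionListenerException. A conceptual sketch (again only valid inside Spark's private[spark] scope):

    import org.apache.spark.{TaskContext, TaskContextImpl}
    import org.apache.spark.util.TaskCompletionListenerException

    val ctx = new TaskContextImpl(stageId = 0, partitionId = 0, attemptId = 0L)
    ctx.addTaskCompletionListener((_: TaskContext) => println("registered first, runs last"))
    ctx.addTaskCompletionListener((_: TaskContext) => sys.error("boom"))  // logged, collected
    ctx.addTaskCompletionListener((_: TaskContext) => println("registered last, runs first"))
    try {
      ctx.markTaskCompleted()
    } catch {
      case e: TaskCompletionListenerException =>
        println(e.getMessage)  // summarises the single "boom" failure
    }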
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 6b63eb23e9..8010dd9008 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -196,7 +196,7 @@ class HadoopRDD[K, V](
val jobConf = getJobConf()
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf)
+ context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
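The "on-task-completion callback" mentioned in the surviving comment is the listener API shown earlier. A generic sketch of that pattern (the helper name and Closeable reader are illustrative, not HadoopRDD's actual code):

    import java.io.Closeable
    import org.apache.spark.TaskContext

    // Ensure an input resource is closed whether the task succeeds, fails or is cancelled.
    def registerCleanup(context: TaskContext, reader: Closeable): Unit =
      context.addTaskCompletionListener((_: TaskContext) => reader.close())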
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 0d97506450..929ded58a3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -956,9 +956,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId,
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outfmt.newInstance
@@ -1027,9 +1027,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
- writer.setup(context.getStageId, context.getPartitionId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
try {
var count = 0
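The modulo in both hunks above exists because Hadoop wants a 32-bit task attempt number while Spark's attempt id is a Long. A quick worked example with an assumed oversized id:

    // Hypothetical attempt id larger than Int.MaxValue (2147483647):
    val longAttemptId = 3000000000L
    val attemptNumber = (longAttemptId % Int.MaxValue).toInt
    // 3000000000 % 2147483647 == 852516353, which fits safely in an Int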
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 788eb1ff4e..f81fa6d808 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -633,14 +633,14 @@ class DAGScheduler(
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
val taskContext =
- new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
- TaskContext.setTaskContext(taskContext)
+ new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true)
+ TaskContextHelper.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
job.listener.taskSucceeded(0, result)
} finally {
taskContext.markTaskCompleted()
- TaskContext.unset()
+ TaskContextHelper.unset()
}
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index c6e47c84a0..2552d03d18 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -22,7 +22,7 @@ import java.nio.ByteBuffer
import scala.collection.mutable.HashMap
-import org.apache.spark.TaskContext
+import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.ByteBufferInputStream
@@ -45,8 +45,8 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
final def run(attemptId: Long): T = {
- context = new TaskContext(stageId, partitionId, attemptId, false)
- TaskContext.setTaskContext(context)
+ context = new TaskContextImpl(stageId, partitionId, attemptId, false)
+ TaskContextHelper.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
if (_killed) {
@@ -56,7 +56,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
runTask(context)
} finally {
context.markTaskCompleted()
- TaskContext.unset()
+ TaskContextHelper.unset()
}
}
@@ -70,7 +70,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
var metrics: Option[TaskMetrics] = None
// Task context, to be initialized in run().
- @transient protected var context: TaskContext = _
+ @transient protected var context: TaskContextImpl = _
// The actual Thread on which the task is running, if any. Initialized in run().
@volatile @transient private var taskThread: Thread = _
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 4a07843544..b8fa822ae4 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -776,7 +776,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics());
+ TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
index 0944bf8cd5..e9ec700e32 100644
--- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
+++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java
@@ -30,8 +30,8 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener {
public void onTaskCompletion(TaskContext context) {
context.isCompleted();
context.isInterrupted();
- context.getStageId();
- context.getPartitionId();
+ context.stageId();
+ context.partitionId();
context.isRunningLocally();
context.addTaskCompletionListener(this);
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index d735010d7c..c0735f448d 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContext(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContext(0, 0, 0, true)
+ val context = new TaskContextImpl(0, 0, 0, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
@@ -102,7 +102,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContext(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index be972c5e97..271a90c664 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContext(0, 0, 0)
+ val tContext = new TaskContextImpl(0, 0, 0)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index faba5508c9..561a5e9cd9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -51,7 +51,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContext(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 809bd70929..a8c049d749 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import org.apache.spark.TaskContext
+import org.apache.spark.{TaskContextImpl, TaskContext}
import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
import org.mockito.Mockito._
@@ -62,7 +62,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContext(0, 0, 0),
+ new TaskContextImpl(0, 0, 0),
transfer,
blockManager,
blocksByAddress,
@@ -120,7 +120,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContext(0, 0, 0),
+ new TaskContextImpl(0, 0, 0),
transfer,
blockManager,
blocksByAddress,
@@ -169,7 +169,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
(bmId, Seq((blId1, 1L), (blId2, 1L))))
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContext(0, 0, 0),
+ new TaskContextImpl(0, 0, 0),
transfer,
blockManager,
blocksByAddress,
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index 39f8ba4745..d919b18e09 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -32,7 +32,7 @@ object MimaBuild {
ProblemFilters.exclude[MissingMethodProblem](fullName),
// Sometimes excluded methods have default arguments and
// they are translated into public methods/fields($default$) in generated
- // bytecode. It is not possible to exhustively list everything.
+ // bytecode. It is not possible to exhaustively list everything.
// But this should be okay.
ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$2"),
ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$1"),
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index d499302124..350aad4773 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -50,7 +50,11 @@ object MimaExcludes {
"org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"),
// MapStatus should be private[spark]
ProblemFilters.exclude[IncompatibleTemplateDefProblem](
- "org.apache.spark.scheduler.MapStatus")
+ "org.apache.spark.scheduler.MapStatus"),
+ // TaskContext was promoted to Abstract class
+ ProblemFilters.exclude[AbstractClassProblem](
+ "org.apache.spark.TaskContext")
+
)
case v if v.startsWith("1.1") =>
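For readers unfamiliar with MimaExcludes: each intentional binary-compatibility break is silenced with an explicit ProblemFilter against the affected class or member. A minimal illustration of the entry added above (the surrounding Seq and version match are simplified from the real file):

    import com.typesafe.tools.mima.core._

    // TaskContext changed from a concrete class to an abstract class, which MiMa
    // reports as an AbstractClassProblem unless it is explicitly excluded.
    val taskContextExcludes = Seq(
      ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.TaskContext")
    )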
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 1f4237d7ed..5c6fa78ae3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -289,9 +289,9 @@ case class InsertIntoParquetTable(
def writeShard(context: TaskContext, iter: Iterator[Row]): Int = {
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt
+ val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
- val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId,
+ val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = new AppendingParquetOutputFormat(taskIdOffset)