aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/java/org/apache/spark/TaskContext.java24
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContextImpl.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala46
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala3
-rw-r--r--core/src/test/java/org/apache/spark/JavaAPISuite.java2
-rw-r--r--core/src/test/scala/org/apache/spark/CacheManagerSuite.scala8
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala31
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala2
-rw-r--r--project/MimaExcludes.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala5
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala5
25 files changed, 168 insertions, 48 deletions
diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
index 0d6973203e..095f9fb94f 100644
--- a/core/src/main/java/org/apache/spark/TaskContext.java
+++ b/core/src/main/java/org/apache/spark/TaskContext.java
@@ -62,7 +62,7 @@ public abstract class TaskContext implements Serializable {
*/
public abstract boolean isInterrupted();
- /** @deprecated: use isRunningLocally() */
+ /** @deprecated use {@link #isRunningLocally()} */
@Deprecated
public abstract boolean runningLocally();
@@ -87,19 +87,39 @@ public abstract class TaskContext implements Serializable {
* is for HadoopRDD to register a callback to close the input stream.
* Will be called in any situation - success, failure, or cancellation.
*
- * @deprecated: use addTaskCompletionListener
+ * @deprecated use {@link #addTaskCompletionListener(scala.Function1)}
*
* @param f Callback function.
*/
@Deprecated
public abstract void addOnCompleteCallback(final Function0<Unit> f);
+ /**
+ * The ID of the stage that this task belong to.
+ */
public abstract int stageId();
+ /**
+ * The ID of the RDD partition that is computed by this task.
+ */
public abstract int partitionId();
+ /**
+ * How many times this task has been attempted. The first task attempt will be assigned
+ * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
+ */
+ public abstract int attemptNumber();
+
+ /** @deprecated use {@link #taskAttemptId()}; it was renamed to avoid ambiguity. */
+ @Deprecated
public abstract long attemptId();
+ /**
+ * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts
+ * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID.
+ */
+ public abstract long taskAttemptId();
+
/** ::DeveloperApi:: */
@DeveloperApi
public abstract TaskMetrics taskMetrics();
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index afd2b85d33..9bb0c61e44 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -22,14 +22,19 @@ import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerExce
import scala.collection.mutable.ArrayBuffer
-private[spark] class TaskContextImpl(val stageId: Int,
+private[spark] class TaskContextImpl(
+ val stageId: Int,
val partitionId: Int,
- val attemptId: Long,
+ override val taskAttemptId: Long,
+ override val attemptNumber: Int,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
with Logging {
+ // For backwards-compatibility; this method is now deprecated as of 1.3.0.
+ override def attemptId: Long = taskAttemptId
+
// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index c794a7bc35..9a4adfbbb3 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -71,7 +71,8 @@ private[spark] class CoarseGrainedExecutorBackend(
val ser = env.closureSerializer.newInstance()
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
- executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask)
+ executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
+ taskDesc.name, taskDesc.serializedTask)
}
case KillTask(taskId, _, interruptThread) =>
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 0f99cd9f3b..b75c77b5b4 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -108,8 +108,13 @@ private[spark] class Executor(
startDriverHeartbeater()
def launchTask(
- context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
- val tr = new TaskRunner(context, taskId, taskName, serializedTask)
+ context: ExecutorBackend,
+ taskId: Long,
+ attemptNumber: Int,
+ taskName: String,
+ serializedTask: ByteBuffer) {
+ val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
+ serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}
@@ -134,7 +139,11 @@ private[spark] class Executor(
private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
class TaskRunner(
- execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
+ execBackend: ExecutorBackend,
+ val taskId: Long,
+ val attemptNumber: Int,
+ taskName: String,
+ serializedTask: ByteBuffer)
extends Runnable {
@volatile private var killed = false
@@ -180,7 +189,7 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
- val value = task.run(taskId.toInt)
+ val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
val taskFinish = System.currentTimeMillis()
// If the task has been killed, let's fail it.
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index 2e23ae0a4f..cfd672e1d8 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -28,6 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData}
import org.apache.spark.util.{SignalLogger, Utils}
private[spark] class MesosExecutorBackend
@@ -77,11 +78,13 @@ private[spark] class MesosExecutorBackend
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
val taskId = taskInfo.getTaskId.getValue.toLong
+ val taskData = MesosTaskLaunchData.fromByteString(taskInfo.getData)
if (executor == null) {
logError("Received launchTask but executor was null")
} else {
SparkHadoopUtil.get.runAsSparkUser { () =>
- executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer)
+ executor.launchTask(this, taskId = taskId, attemptNumber = taskData.attemptNumber,
+ taskInfo.getName, taskData.serializedTask)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 7ba1182f0e..1c13e2c372 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -95,7 +95,8 @@ private[spark] object CheckpointRDD extends Logging {
val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
- val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
+ val tempOutputPath =
+ new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber)
if (fs.exists(tempOutputPath)) {
throw new IOException("Checkpoint failed: temporary path " +
@@ -119,7 +120,7 @@ private[spark] object CheckpointRDD extends Logging {
logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
- + ctx.attemptId + " and final output path does not exist")
+ + ctx.attemptNumber + " and final output path does not exist")
} else {
// Some other copy of this task must've finished before us and renamed it
logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 0001c2329c..37e0c13029 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -229,7 +229,7 @@ class HadoopRDD[K, V](
var reader: RecordReader[K, V] = null
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
+ context.stageId, theSplit.index, context.attemptNumber, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 38f8f36a4a..e43e506665 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -978,12 +978,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
val config = wrappedConf.value
- // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
- // around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
- attemptNumber)
+ context.attemptNumber)
val hadoopContext = newTaskAttemptContext(config, attemptId)
val format = outfmt.newInstance
format match {
@@ -1062,11 +1059,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
- writer.setup(context.stageId, context.partitionId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, taskAttemptId)
writer.open()
try {
var recordsWritten = 0L
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 61d09d73e1..8cb15918ba 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -635,8 +635,8 @@ class DAGScheduler(
try {
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
- val taskContext =
- new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true)
+ val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
+ attemptNumber = 0, runningLocally = true)
TaskContextHelper.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index d7dde4fe38..2367f7e2cf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -44,8 +44,16 @@ import org.apache.spark.util.Utils
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
- final def run(attemptId: Long): T = {
- context = new TaskContextImpl(stageId, partitionId, attemptId, runningLocally = false)
+ /**
+ * Called by Executor to run this task.
+ *
+ * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
+ * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
+ * @return the result of the task
+ */
+ final def run(taskAttemptId: Long, attemptNumber: Int): T = {
+ context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
+ taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
TaskContextHelper.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index 4c96b9e5fe..1c7c81c488 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -27,6 +27,7 @@ import org.apache.spark.util.SerializableBuffer
*/
private[spark] class TaskDescription(
val taskId: Long,
+ val attemptNumber: Int,
val executorId: String,
val name: String,
val index: Int, // Index within this task's TaskSet
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 4667850917..5c94c6bbcb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -487,7 +487,8 @@ private[spark] class TaskSetManager(
taskName, taskId, host, taskLocality, serializedTask.limit))
sched.dagScheduler.taskStarted(task, info)
- return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+ return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId,
+ taskName, index, serializedTask))
}
case _ =>
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 10e6886c16..75d8ddf375 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -22,7 +22,7 @@ import java.util.{ArrayList => JArrayList, List => JList}
import java.util.Collections
import scala.collection.JavaConversions._
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable.{HashMap, HashSet}
import org.apache.mesos.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
@@ -296,7 +296,7 @@ private[spark] class MesosSchedulerBackend(
.setExecutor(createExecutorInfo(slaveId))
.setName(task.name)
.addResources(cpuResource)
- .setData(ByteString.copyFrom(task.serializedTask))
+ .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString)
.build()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
new file mode 100644
index 0000000000..4416ce92ad
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import java.nio.ByteBuffer
+
+import org.apache.mesos.protobuf.ByteString
+
+/**
+ * Wrapper for serializing the data sent when launching Mesos tasks.
+ */
+private[spark] case class MesosTaskLaunchData(
+ serializedTask: ByteBuffer,
+ attemptNumber: Int) {
+
+ def toByteString: ByteString = {
+ val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit)
+ dataBuffer.putInt(attemptNumber)
+ dataBuffer.put(serializedTask)
+ ByteString.copyFrom(dataBuffer)
+ }
+}
+
+private[spark] object MesosTaskLaunchData {
+ def fromByteString(byteString: ByteString): MesosTaskLaunchData = {
+ val byteBuffer = byteString.asReadOnlyByteBuffer()
+ val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes
+ val serializedTask = byteBuffer.slice() // subsequence starting at the current position
+ MesosTaskLaunchData(serializedTask, attemptNumber)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index b3bd3110ac..05b6fa5456 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -76,7 +76,8 @@ private[spark] class LocalActor(
val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
for (task <- scheduler.resourceOffers(offers).flatten) {
freeCores -= scheduler.CPUS_PER_TASK
- executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask)
+ executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber,
+ task.name, task.serializedTask)
}
}
}
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 5ce299d058..07b1e44d04 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -820,7 +820,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics());
+ TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index c0735f448d..d7d9dc7b50 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContextImpl(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContextImpl(0, 0, 0, true)
+ val context = new TaskContextImpl(0, 0, 0, 0, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
@@ -102,7 +102,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 271a90c664..1a9a0e857e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContextImpl(0, 0, 0)
+ val tContext = new TaskContextImpl(0, 0, 0, 0)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 561a5e9cd9..057e226916 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -45,13 +45,13 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
val task = new ResultTask[String, String](
0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
intercept[RuntimeException] {
- task.run(0)
+ task.run(0, 0)
}
assert(TaskContextSuite.completed === true)
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContextImpl(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
@@ -63,6 +63,33 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
verify(listener, times(1)).onTaskCompletion(any())
}
+
+ test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") {
+ sc = new SparkContext("local[1,2]", "test") // use maxRetries = 2 because we test failed tasks
+ // Check that attemptIds are 0 for all tasks' initial attempts
+ val attemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
+ Seq(TaskContext.get().attemptNumber).iterator
+ }.collect()
+ assert(attemptIds.toSet === Set(0))
+
+ // Test a job with failed tasks
+ val attemptIdsWithFailedTask = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
+ val attemptId = TaskContext.get().attemptNumber
+ if (iter.next() == 1 && attemptId == 0) {
+ throw new Exception("First execution of task failed")
+ }
+ Seq(attemptId).iterator
+ }.collect()
+ assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
+ }
+
+ test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") {
+ sc = new SparkContext("local", "test")
+ val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter =>
+ Seq(TaskContext.get().attemptId).iterator
+ }.collect()
+ assert(attemptIds.toSet === Set(0, 1, 2, 3))
+ }
}
private object TaskContextSuite {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
index e60e70afd3..48f5e40f50 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
@@ -80,7 +80,7 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea
mesosOffers.get(2).getHostname,
2
))
- val taskDesc = new TaskDescription(1L, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0)))
+ val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0)))
EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(expectedWorkerOffers))).andReturn(Seq(Seq(taskDesc)))
EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes()
EasyMock.replay(taskScheduler)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 1eaabb93ad..37b593b2c5 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0),
+ new TaskContextImpl(0, 0, 0, 0),
transfer,
blockManager,
blocksByAddress,
@@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
@@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
index 787f4c2b5a..e85a436cdb 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
@@ -173,7 +173,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers {
// Simulate fetch failures:
val mappedData = data.map { x =>
val taskContext = TaskContext.get
- if (taskContext.attemptId() == 1) { // Cause this stage to fail on its first attempt.
+ if (taskContext.attemptNumber == 0) { // Cause this stage to fail on its first attempt.
val env = SparkEnv.get
val bmAddress = env.blockManager.blockManagerId
val shuffleId = shuffleHandle.shuffleId
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index f6f9f491f4..d3ea594245 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -72,6 +72,12 @@ object MimaExcludes {
"org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"),
ProblemFilters.exclude[IncompatibleMethTypeProblem](
"org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema")
+ ) ++ Seq(
+ // SPARK-4014
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.TaskContext.taskAttemptId"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.TaskContext.attemptNumber")
)
case v if v.startsWith("1.2") =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index f5487740d3..28cd17fde4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -301,12 +301,9 @@ case class InsertIntoParquetTable(
}
def writeShard(context: TaskContext, iter: Iterator[Row]): Int = {
- // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
- // around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
- attemptNumber)
+ context.attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = new AppendingParquetOutputFormat(taskIdOffset)
val committer = format.getOutputCommitter(hadoopContext)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index ca0ec15139..42bc8a0b67 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -100,10 +100,7 @@ case class InsertIntoHiveTable(
val wrappers = fieldOIs.map(wrapperFor)
val outputData = new Array[Any](fieldOIs.length)
- // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
- // around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
- writerContainer.executorSideSetup(context.stageId, context.partitionId, attemptNumber)
+ writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber)
iterator.foreach { row =>
var i = 0