author    Josh Rosen <joshrosen@databricks.com>  2015-01-14 11:45:40 -0800
committer Josh Rosen <joshrosen@databricks.com>  2015-01-14 11:45:40 -0800
commit    259936be710f367e1d1e019ee7b93fb68dfc33d0
tree      a9637d8e3cdb9f98c748e8155efa9eaadb99437b
parent    9d4449c4b3d0f5e60e71fef3a9b2e1c58a80b9e8
[SPARK-4014] Add TaskContext.attemptNumber and deprecate TaskContext.attemptId
`TaskContext.attemptId` is misleadingly named, since it currently returns a taskId, which uniquely identifies a particular task attempt within a particular SparkContext, rather than an attempt number, which conveys how many times a task has been attempted. This patch deprecates `TaskContext.attemptId` and adds `TaskContext.taskAttemptId` and `TaskContext.attemptNumber` fields. Prior to this change, it was impossible to determine whether a task was being re-attempted (or was a speculative copy), which made it difficult to write unit tests for tasks that fail on early attempts or for speculative tasks that complete faster than the original tasks.

Earlier versions of the TaskContext docs suggest that `attemptId` behaves like `attemptNumber`, so there's an argument to be made in favor of changing this method's implementation. Since we've decided against making that change in maintenance branches, I think it's simpler to add better-named methods and retain the old behavior for `attemptId`; if `attemptId` behaved differently in different branches, then this would cause confusing build-breaks when backporting regression tests that rely on the new `attemptId` behavior.

Most of this patch is fairly straightforward, but there is a bit of trickiness related to Mesos tasks: since there's no field in MesosTaskInfo to encode the attemptId, I packed it into the `data` field alongside the task binary.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #3849 from JoshRosen/SPARK-4014 and squashes the following commits:

89d03e0 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4014
5cfff05 [Josh Rosen] Introduce wrapper for serializing Mesos task launch data.
38574d4 [Josh Rosen] attemptId -> taskAttemptId in PairRDDFunctions
a180b88 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4014
1d43aa6 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4014
eee6a45 [Josh Rosen] Merge remote-tracking branch 'origin/master' into SPARK-4014
0b10526 [Josh Rosen] Use putInt instead of putLong (silly mistake)
8c387ce [Josh Rosen] Use local with maxRetries instead of local-cluster.
cbe4d76 [Josh Rosen] Preserve attemptId behavior and deprecate it
b2dffa3 [Josh Rosen] Address some of Reynold's minor comments
9d8d4d1 [Josh Rosen] Doc typo
1e7a933 [Josh Rosen] [SPARK-4014] Change TaskContext.attemptId to return attempt number instead of task ID.
fd515a5 [Josh Rosen] Add failing test for SPARK-4014
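To make the distinction concrete, here is a minimal usage sketch against the post-patch API; the job, app name, and object name are illustrative only and are not part of this commit:

    import org.apache.spark.{SparkConf, SparkContext, TaskContext}

    object AttemptNumberDemo {
      def main(args: Array[String]): Unit = {
        // local[1,2] = one core with maxFailures = 2, so a failed first attempt is retried
        // (the same master string used by the new regression test in TaskContextSuite).
        val sc = new SparkContext(new SparkConf().setMaster("local[1,2]").setAppName("attempt-number-demo"))
        try {
          val info = sc.parallelize(1 to 4, 4).mapPartitions { iter =>
            val ctx = TaskContext.get()
            // attemptNumber: 0 for a task's first attempt, 1 for its first re-attempt, and so on.
            // taskAttemptId: unique across all task attempts in this SparkContext (the old attemptId behavior).
            Iterator((ctx.partitionId(), ctx.attemptNumber(), ctx.taskAttemptId()))
          }.collect()
          info.foreach(println)
        } finally {
          sc.stop()
        }
      }
    }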
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/java/org/apache/spark/TaskContext.java | 24
-rw-r--r--  core/src/main/scala/org/apache/spark/TaskContextImpl.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/Executor.scala | 17
-rw-r--r--  core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 9
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala | 1
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala | 46
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala | 3
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/CacheManagerSuite.scala | 8
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 31
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 6
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala | 2
22 files changed, 160 insertions(+), 40 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
index 0d6973203e..095f9fb94f 100644
--- a/core/src/main/java/org/apache/spark/TaskContext.java
+++ b/core/src/main/java/org/apache/spark/TaskContext.java
@@ -62,7 +62,7 @@ public abstract class TaskContext implements Serializable {
*/
public abstract boolean isInterrupted();
- /** @deprecated: use isRunningLocally() */
+ /** @deprecated use {@link #isRunningLocally()} */
@Deprecated
public abstract boolean runningLocally();
@@ -87,19 +87,39 @@ public abstract class TaskContext implements Serializable {
* is for HadoopRDD to register a callback to close the input stream.
* Will be called in any situation - success, failure, or cancellation.
*
- * @deprecated: use addTaskCompletionListener
+ * @deprecated use {@link #addTaskCompletionListener(scala.Function1)}
*
* @param f Callback function.
*/
@Deprecated
public abstract void addOnCompleteCallback(final Function0<Unit> f);
+ /**
+ * The ID of the stage that this task belongs to.
+ */
public abstract int stageId();
+ /**
+ * The ID of the RDD partition that is computed by this task.
+ */
public abstract int partitionId();
+ /**
+ * How many times this task has been attempted. The first task attempt will be assigned
+ * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
+ */
+ public abstract int attemptNumber();
+
+ /** @deprecated use {@link #taskAttemptId()}; it was renamed to avoid ambiguity. */
+ @Deprecated
public abstract long attemptId();
+ /**
+ * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts
+ * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID.
+ */
+ public abstract long taskAttemptId();
+
/** ::DeveloperApi:: */
@DeveloperApi
public abstract TaskMetrics taskMetrics();
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index afd2b85d33..9bb0c61e44 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -22,14 +22,19 @@ import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerExce
import scala.collection.mutable.ArrayBuffer
-private[spark] class TaskContextImpl(val stageId: Int,
+private[spark] class TaskContextImpl(
+ val stageId: Int,
val partitionId: Int,
- val attemptId: Long,
+ override val taskAttemptId: Long,
+ override val attemptNumber: Int,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
with Logging {
+ // For backwards-compatibility; this method is now deprecated as of 1.3.0.
+ override def attemptId: Long = taskAttemptId
+
// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index c794a7bc35..9a4adfbbb3 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -71,7 +71,8 @@ private[spark] class CoarseGrainedExecutorBackend(
val ser = env.closureSerializer.newInstance()
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
- executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask)
+ executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
+ taskDesc.name, taskDesc.serializedTask)
}
case KillTask(taskId, _, interruptThread) =>
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 0f99cd9f3b..b75c77b5b4 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -108,8 +108,13 @@ private[spark] class Executor(
startDriverHeartbeater()
def launchTask(
- context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
- val tr = new TaskRunner(context, taskId, taskName, serializedTask)
+ context: ExecutorBackend,
+ taskId: Long,
+ attemptNumber: Int,
+ taskName: String,
+ serializedTask: ByteBuffer) {
+ val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
+ serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}
@@ -134,7 +139,11 @@ private[spark] class Executor(
private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
class TaskRunner(
- execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
+ execBackend: ExecutorBackend,
+ val taskId: Long,
+ val attemptNumber: Int,
+ taskName: String,
+ serializedTask: ByteBuffer)
extends Runnable {
@volatile private var killed = false
@@ -180,7 +189,7 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
- val value = task.run(taskId.toInt)
+ val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
val taskFinish = System.currentTimeMillis()
// If the task has been killed, let's fail it.
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index 2e23ae0a4f..cfd672e1d8 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -28,6 +28,7 @@ import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData}
import org.apache.spark.util.{SignalLogger, Utils}
private[spark] class MesosExecutorBackend
@@ -77,11 +78,13 @@ private[spark] class MesosExecutorBackend
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
val taskId = taskInfo.getTaskId.getValue.toLong
+ val taskData = MesosTaskLaunchData.fromByteString(taskInfo.getData)
if (executor == null) {
logError("Received launchTask but executor was null")
} else {
SparkHadoopUtil.get.runAsSparkUser { () =>
- executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer)
+ executor.launchTask(this, taskId = taskId, attemptNumber = taskData.attemptNumber,
+ taskInfo.getName, taskData.serializedTask)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 7ba1182f0e..1c13e2c372 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -95,7 +95,8 @@ private[spark] object CheckpointRDD extends Logging {
val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
- val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
+ val tempOutputPath =
+ new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber)
if (fs.exists(tempOutputPath)) {
throw new IOException("Checkpoint failed: temporary path " +
@@ -119,7 +120,7 @@ private[spark] object CheckpointRDD extends Logging {
logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
- + ctx.attemptId + " and final output path does not exist")
+ + ctx.attemptNumber + " and final output path does not exist")
} else {
// Some other copy of this task must've finished before us and renamed it
logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 0001c2329c..37e0c13029 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -229,7 +229,7 @@ class HadoopRDD[K, V](
var reader: RecordReader[K, V] = null
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
+ context.stageId, theSplit.index, context.attemptNumber, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 38f8f36a4a..e43e506665 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -978,12 +978,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
val config = wrappedConf.value
- // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
- // around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
- attemptNumber)
+ context.attemptNumber)
val hadoopContext = newTaskAttemptContext(config, attemptId)
val format = outfmt.newInstance
format match {
@@ -1062,11 +1059,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
- writer.setup(context.stageId, context.partitionId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, taskAttemptId)
writer.open()
try {
var recordsWritten = 0L
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 61d09d73e1..8cb15918ba 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -635,8 +635,8 @@ class DAGScheduler(
try {
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
- val taskContext =
- new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true)
+ val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
+ attemptNumber = 0, runningLocally = true)
TaskContextHelper.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index d7dde4fe38..2367f7e2cf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -44,8 +44,16 @@ import org.apache.spark.util.Utils
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
- final def run(attemptId: Long): T = {
- context = new TaskContextImpl(stageId, partitionId, attemptId, runningLocally = false)
+ /**
+ * Called by Executor to run this task.
+ *
+ * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
+ * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
+ * @return the result of the task
+ */
+ final def run(taskAttemptId: Long, attemptNumber: Int): T = {
+ context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
+ taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
TaskContextHelper.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index 4c96b9e5fe..1c7c81c488 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -27,6 +27,7 @@ import org.apache.spark.util.SerializableBuffer
*/
private[spark] class TaskDescription(
val taskId: Long,
+ val attemptNumber: Int,
val executorId: String,
val name: String,
val index: Int, // Index within this task's TaskSet
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 4667850917..5c94c6bbcb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -487,7 +487,8 @@ private[spark] class TaskSetManager(
taskName, taskId, host, taskLocality, serializedTask.limit))
sched.dagScheduler.taskStarted(task, info)
- return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+ return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId,
+ taskName, index, serializedTask))
}
case _ =>
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 10e6886c16..75d8ddf375 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -22,7 +22,7 @@ import java.util.{ArrayList => JArrayList, List => JList}
import java.util.Collections
import scala.collection.JavaConversions._
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable.{HashMap, HashSet}
import org.apache.mesos.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
@@ -296,7 +296,7 @@ private[spark] class MesosSchedulerBackend(
.setExecutor(createExecutorInfo(slaveId))
.setName(task.name)
.addResources(cpuResource)
- .setData(ByteString.copyFrom(task.serializedTask))
+ .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString)
.build()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
new file mode 100644
index 0000000000..4416ce92ad
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import java.nio.ByteBuffer
+
+import org.apache.mesos.protobuf.ByteString
+
+/**
+ * Wrapper for serializing the data sent when launching Mesos tasks.
+ */
+private[spark] case class MesosTaskLaunchData(
+ serializedTask: ByteBuffer,
+ attemptNumber: Int) {
+
+ def toByteString: ByteString = {
+ val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit)
+ dataBuffer.putInt(attemptNumber)
+ dataBuffer.put(serializedTask)
+ ByteString.copyFrom(dataBuffer)
+ }
+}
+
+private[spark] object MesosTaskLaunchData {
+ def fromByteString(byteString: ByteString): MesosTaskLaunchData = {
+ val byteBuffer = byteString.asReadOnlyByteBuffer()
+ val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes
+ val serializedTask = byteBuffer.slice() // subsequence starting at the current position
+ MesosTaskLaunchData(serializedTask, attemptNumber)
+ }
+}
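For reference, a self-contained sketch of the wire format used by this wrapper (a 4-byte attemptNumber prefix followed by the serialized task bytes), written against plain java.nio and the Mesos-bundled protobuf ByteString; the helper names and the round-trip main method are illustrative only, and this sketch flips the write buffer before copying so the reader sees it from the beginning:

    import java.nio.ByteBuffer

    import org.apache.mesos.protobuf.ByteString

    object MesosLaunchDataFormatSketch {
      // Encode: 4-byte attemptNumber, then the serialized task bytes.
      def encode(serializedTask: ByteBuffer, attemptNumber: Int): ByteString = {
        val buf = ByteBuffer.allocate(4 + serializedTask.remaining)
        buf.putInt(attemptNumber)
        buf.put(serializedTask.duplicate()) // duplicate() so the caller's buffer position is untouched
        buf.flip()                          // make everything written above readable from position 0
        ByteString.copyFrom(buf)
      }

      // Decode: read the attemptNumber, then expose the remaining bytes as the task buffer.
      def decode(bytes: ByteString): (ByteBuffer, Int) = {
        val buf = bytes.asReadOnlyByteBuffer()
        val attemptNumber = buf.getInt // advances the position by 4 bytes
        (buf.slice(), attemptNumber)   // slice() = the remaining bytes, i.e. the serialized task
      }

      def main(args: Array[String]): Unit = {
        val (task, n) = decode(encode(ByteBuffer.wrap("task-binary".getBytes("UTF-8")), 3))
        println(s"attemptNumber=$n, taskBytes=${task.remaining}") // prints: attemptNumber=3, taskBytes=11
      }
    }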
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index b3bd3110ac..05b6fa5456 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -76,7 +76,8 @@ private[spark] class LocalActor(
val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
for (task <- scheduler.resourceOffers(offers).flatten) {
freeCores -= scheduler.CPUS_PER_TASK
- executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask)
+ executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber,
+ task.name, task.serializedTask)
}
}
}
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 5ce299d058..07b1e44d04 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -820,7 +820,7 @@ public class JavaAPISuite implements Serializable {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
- TaskContext context = new TaskContextImpl(0, 0, 0L, false, new TaskMetrics());
+ TaskContext context = new TaskContextImpl(0, 0, 0L, 0, false, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}
diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
index c0735f448d..d7d9dc7b50 100644
--- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala
@@ -66,7 +66,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
// in blockManager.put is a losing battle. You have been warned.
blockManager = sc.env.blockManager
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0)
val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
assert(computeValue.toList === List(1, 2, 3, 4))
@@ -81,7 +81,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContextImpl(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(5, 6, 7))
}
@@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
}
whenExecuting(blockManager) {
- val context = new TaskContextImpl(0, 0, 0, true)
+ val context = new TaskContextImpl(0, 0, 0, 0, true)
val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
assert(value.toList === List(1, 2, 3, 4))
}
@@ -102,7 +102,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
test("verify task metrics updated correctly") {
cacheManager = sc.env.cacheManager
- val context = new TaskContextImpl(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0)
cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
index 271a90c664..1a9a0e857e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala
@@ -174,7 +174,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
val hadoopPart1 = generateFakeHadoopPartition()
val pipedRdd = new PipedRDD(nums, "printenv " + varName)
- val tContext = new TaskContextImpl(0, 0, 0)
+ val tContext = new TaskContextImpl(0, 0, 0, 0)
val rddIter = pipedRdd.compute(hadoopPart1, tContext)
val arr = rddIter.toArray
assert(arr(0) == "/some/path")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 561a5e9cd9..057e226916 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -45,13 +45,13 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
val task = new ResultTask[String, String](
0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
intercept[RuntimeException] {
- task.run(0)
+ task.run(0, 0)
}
assert(TaskContextSuite.completed === true)
}
test("all TaskCompletionListeners should be called even if some fail") {
- val context = new TaskContextImpl(0, 0, 0)
+ val context = new TaskContextImpl(0, 0, 0, 0)
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("blah"))
context.addTaskCompletionListener(listener)
@@ -63,6 +63,33 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
verify(listener, times(1)).onTaskCompletion(any())
}
+
+ test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") {
+ sc = new SparkContext("local[1,2]", "test") // use maxRetries = 2 because we test failed tasks
+ // Check that attemptIds are 0 for all tasks' initial attempts
+ val attemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
+ Seq(TaskContext.get().attemptNumber).iterator
+ }.collect()
+ assert(attemptIds.toSet === Set(0))
+
+ // Test a job with failed tasks
+ val attemptIdsWithFailedTask = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
+ val attemptId = TaskContext.get().attemptNumber
+ if (iter.next() == 1 && attemptId == 0) {
+ throw new Exception("First execution of task failed")
+ }
+ Seq(attemptId).iterator
+ }.collect()
+ assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
+ }
+
+ test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") {
+ sc = new SparkContext("local", "test")
+ val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter =>
+ Seq(TaskContext.get().attemptId).iterator
+ }.collect()
+ assert(attemptIds.toSet === Set(0, 1, 2, 3))
+ }
}
private object TaskContextSuite {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
index e60e70afd3..48f5e40f50 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala
@@ -80,7 +80,7 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea
mesosOffers.get(2).getHostname,
2
))
- val taskDesc = new TaskDescription(1L, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0)))
+ val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0)))
EasyMock.expect(taskScheduler.resourceOffers(EasyMock.eq(expectedWorkerOffers))).andReturn(Seq(Seq(taskDesc)))
EasyMock.expect(taskScheduler.CPUS_PER_TASK).andReturn(2).anyTimes()
EasyMock.replay(taskScheduler)
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 1eaabb93ad..37b593b2c5 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -89,7 +89,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0),
+ new TaskContextImpl(0, 0, 0, 0),
transfer,
blockManager,
blocksByAddress,
@@ -154,7 +154,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
@@ -217,7 +217,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
- val taskContext = new TaskContextImpl(0, 0, 0)
+ val taskContext = new TaskContextImpl(0, 0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
index 787f4c2b5a..e85a436cdb 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
@@ -173,7 +173,7 @@ class UISeleniumSuite extends FunSuite with WebBrowser with Matchers {
// Simulate fetch failures:
val mappedData = data.map { x =>
val taskContext = TaskContext.get
- if (taskContext.attemptId() == 1) { // Cause this stage to fail on its first attempt.
+ if (taskContext.attemptNumber == 0) { // Cause this stage to fail on its first attempt.
val env = SparkEnv.get
val bmAddress = env.blockManager.blockManagerId
val shuffleId = shuffleHandle.shuffleId