author     Imran Rashid <irashid@cloudera.com>        2015-07-20 10:28:32 -0700
committer  Kay Ousterhout <kayousterhout@gmail.com>   2015-07-20 10:28:32 -0700
commit     80e2568b25780a7094199239da8ad6cfb6efc9f7
tree       317142655edc12ed5edc07820f72282f52dc05de
parent     c6fe9b4a179eecce69a813501dd0f4a22ff5dd5b
[SPARK-8103][core] DAGScheduler should not submit multiple concurrent attempts for a stage
https://issues.apache.org/jira/browse/SPARK-8103

cc kayousterhout (thanks for the extra test case)

Author: Imran Rashid <irashid@cloudera.com>
Author: Kay Ousterhout <kayousterhout@gmail.com>
Author: Imran Rashid <squito@users.noreply.github.com>

Closes #6750 from squito/SPARK-8103 and squashes the following commits:

fb3acfc [Imran Rashid] fix log msg
e01b7aa [Imran Rashid] fix some comments, style
584acd4 [Imran Rashid] simplify going from taskId to taskSetMgr
e43ac25 [Imran Rashid] Merge branch 'master' into SPARK-8103
6bc23af [Imran Rashid] update log msg
4470fa1 [Imran Rashid] rename
c04707e [Imran Rashid] style
88b61cc [Imran Rashid] add tests to make sure that TaskSchedulerImpl schedules correctly with zombie attempts
d7f1ef2 [Imran Rashid] get rid of activeTaskSets
a21c8b5 [Imran Rashid] Merge branch 'master' into SPARK-8103
906d626 [Imran Rashid] fix merge
109900e [Imran Rashid] Merge branch 'master' into SPARK-8103
c0d4d90 [Imran Rashid] Revert "Index active task sets by stage Id rather than by task set id"
f025154 [Imran Rashid] Merge pull request #2 from kayousterhout/imran_SPARK-8103
baf46e1 [Kay Ousterhout] Index active task sets by stage Id rather than by task set id
19685bb [Imran Rashid] switch to using latestInfo.attemptId, and add comments
a5f7c8c [Imran Rashid] remove comment for reviewers
227b40d [Imran Rashid] style
517b6e5 [Imran Rashid] get rid of SparkIllegalStateException
b2faef5 [Imran Rashid] faster check for conflicting task sets
6542b42 [Imran Rashid] remove extra stageAttemptId
ada7726 [Imran Rashid] reviewer feedback
d8eb202 [Imran Rashid] Merge branch 'master' into SPARK-8103
46bc26a [Imran Rashid] more cleanup of debug garbage
cb245da [Imran Rashid] finally found the issue ... clean up debug stuff
8c29707 [Imran Rashid] Merge branch 'master' into SPARK-8103
89a59b6 [Imran Rashid] more printlns ...
9601b47 [Imran Rashid] more debug printlns
ecb4e7d [Imran Rashid] debugging printlns
b6bc248 [Imran Rashid] style
55f4a94 [Imran Rashid] get rid of more random test case since kays tests are clearer
7021d28 [Imran Rashid] update test since listenerBus.waitUntilEmpty now throws an exception instead of returning a boolean
883fe49 [Kay Ousterhout] Unit tests for concurrent stages issue
6e14683 [Imran Rashid] unit test just to make sure we fail fast on concurrent attempts
06a0af6 [Imran Rashid] ignore for jenkins
c443def [Imran Rashid] better fix and simpler test case
28d70aa [Imran Rashid] wip on getting a better test case ...
a9bf31f [Imran Rashid] wip
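In short, the fix indexes active task sets by stage id and attempt id, and refuses to accept a second non-zombie task set for a stage that already has a live one. A minimal standalone sketch of that bookkeeping (SketchTaskSet and SketchScheduler are hypothetical names, not Spark classes):

    import scala.collection.mutable

    case class SketchTaskSet(stageId: Int, stageAttemptId: Int, var isZombie: Boolean = false)

    class SketchScheduler {
      private val byStageAndAttempt =
        mutable.HashMap.empty[Int, mutable.HashMap[Int, SketchTaskSet]]

      def submit(ts: SketchTaskSet): Unit = {
        val forStage = byStageAndAttempt.getOrElseUpdate(ts.stageId, mutable.HashMap.empty)
        forStage(ts.stageAttemptId) = ts
        // Any *other* attempt for this stage that is not a zombie is a conflict.
        val conflicting = forStage.values.exists(other => (other ne ts) && !other.isZombie)
        if (conflicting) {
          throw new IllegalStateException(
            s"more than one active taskSet for stage ${ts.stageId}")
        }
      }
    }

The real check in TaskSchedulerImpl.submitTasks (in the diff below) compares TaskSet values rather than references, but the zombie-aware conflict test is the same.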
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 78
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala | 3
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/Task.scala | 5
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 99
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala | 4
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 5
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 141
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala | 8
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala | 2
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala | 4
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala | 113
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala | 2
13 files changed, 383 insertions(+), 86 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index dd55cd8054..71a219a4f3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -857,7 +857,6 @@ class DAGScheduler(
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()
-
// First figure out the indexes of partition ids to compute.
val partitionsToCompute: Seq[Int] = {
stage match {
@@ -918,7 +917,7 @@ class DAGScheduler(
partitionsToCompute.map { id =>
val locs = getPreferredLocs(stage.rdd, id)
val part = stage.rdd.partitions(id)
- new ShuffleMapTask(stage.id, taskBinary, part, locs)
+ new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
}
case stage: ResultStage =>
@@ -927,7 +926,7 @@ class DAGScheduler(
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
- new ResultTask(stage.id, taskBinary, part, locs, id)
+ new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
}
}
} catch {
@@ -1069,10 +1068,11 @@ class DAGScheduler(
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
- logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+ logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
} else {
shuffleStage.addOutputLoc(smt.partitionId, status)
}
+
if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
markStageAsFinished(shuffleStage)
logInfo("looking for newly runnable stages")
@@ -1132,38 +1132,48 @@ class DAGScheduler(
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
- // It is likely that we receive multiple FetchFailed for a single stage (because we have
- // multiple tasks running concurrently on different executors). In that case, it is possible
- // the fetch failure has already been handled by the scheduler.
- if (runningStages.contains(failedStage)) {
- logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
- s"due to a fetch failure from $mapStage (${mapStage.name})")
- markStageAsFinished(failedStage, Some(failureMessage))
- }
+ if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
+ logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+ s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
+ s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
+ } else {
- if (disallowStageRetryForTest) {
- abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
- } else if (failedStages.isEmpty) {
- // Don't schedule an event to resubmit failed stages if failed isn't empty, because
- // in that case the event will already have been scheduled.
- // TODO: Cancel running tasks in the stage
- logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
- s"$failedStage (${failedStage.name}) due to fetch failure")
- messageScheduler.schedule(new Runnable {
- override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
- }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
- }
- failedStages += failedStage
- failedStages += mapStage
- // Mark the map whose fetch failed as broken in the map stage
- if (mapId != -1) {
- mapStage.removeOutputLoc(mapId, bmAddress)
- mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
- }
+ // It is likely that we receive multiple FetchFailed for a single stage (because we have
+ // multiple tasks running concurrently on different executors). In that case, it is
+ // possible the fetch failure has already been handled by the scheduler.
+ if (runningStages.contains(failedStage)) {
+ logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+ s"due to a fetch failure from $mapStage (${mapStage.name})")
+ markStageAsFinished(failedStage, Some(failureMessage))
+ } else {
+ logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
+ s"longer running")
+ }
+
+ if (disallowStageRetryForTest) {
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ } else if (failedStages.isEmpty) {
+ // Don't schedule an event to resubmit failed stages if failedStages isn't empty, because
+ // in that case the event will already have been scheduled.
+ // TODO: Cancel running tasks in the stage
+ logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+ s"$failedStage (${failedStage.name}) due to fetch failure")
+ messageScheduler.schedule(new Runnable {
+ override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+ }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+ }
+ failedStages += failedStage
+ failedStages += mapStage
+ // Mark the map whose fetch failed as broken in the map stage
+ if (mapId != -1) {
+ mapStage.removeOutputLoc(mapId, bmAddress)
+ mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+ }
- // TODO: mark the executor as failed only if there were lots of fetch failures on it
- if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+ // TODO: mark the executor as failed only if there were lots of fetch failures on it
+ if (bmAddress != null) {
+ handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+ }
}
case commitDenied: TaskCommitDenied =>
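The key behavioral change above: a FetchFailed from a superseded stage attempt is now logged and dropped instead of re-marking the stage as failed and scheduling another resubmit. A hedged sketch of just that guard, with TaskLike and StageLike as hypothetical stand-ins for Spark's Task and Stage:

    case class TaskLike(stageId: Int, stageAttemptId: Int)
    case class StageLike(latestAttemptId: Int)

    // Returns true only when the failure belongs to the stage's current attempt;
    // failures from older attempts are logged and ignored.
    def shouldHandleFetchFailure(failedStage: StageLike, task: TaskLike): Boolean =
      if (failedStage.latestAttemptId != task.stageAttemptId) {
        println(s"Ignoring fetch failure from attempt ${task.stageAttemptId}; " +
          s"attempt ${failedStage.latestAttemptId} is the current one")
        false
      } else {
        true
      }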
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index c9a1241139..9c2606e278 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
*/
private[spark] class ResultTask[T, U](
stageId: Int,
+ stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
val outputId: Int)
- extends Task[U](stageId, partition.index) with Serializable {
+ extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index bd3dd23dfe..14c8c00961 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
*/
private[spark] class ShuffleMapTask(
stageId: Int,
+ stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation])
- extends Task[MapStatus](stageId, partition.index) with Logging {
+ extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
- this(0, null, new Partition { override def index: Int = 0 }, null)
+ this(0, 0, null, new Partition { override def index: Int = 0 }, null)
}
@transient private val preferredLocs: Seq[TaskLocation] = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 6a86f9d4b8..76a19aeac4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
* @param stageId id of the stage this task belongs to
* @param partitionId index of the partition in the RDD
*/
-private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+private[spark] abstract class Task[T](
+ val stageId: Int,
+ val stageAttemptId: Int,
+ var partitionId: Int) extends Serializable {
/**
* The key of the Map is the accumulator id and the value of the Map is the latest accumulator
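To make the guard in DAGScheduler possible, every task now carries the attempt id of the stage that created it. A simplified sketch of that constructor plumbing (toy classes mirroring only the parameter threading, not the real Task API):

    abstract class SketchTask[T](
        val stageId: Int,
        val stageAttemptId: Int,  // new: identifies which stage attempt created this task
        val partitionId: Int) extends Serializable {
      def runTask(): T
    }

    class SketchResultTask(stageId: Int, stageAttemptId: Int, partitionId: Int, val outputId: Int)
      extends SketchTask[Int](stageId, stageAttemptId, partitionId) {
      override def runTask(): Int = partitionId  // placeholder body
    }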
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index ed3dde0fc3..1705e7f962 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(
// TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
- val activeTaskSets = new HashMap[String, TaskSetManager]
+ private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
- val taskIdToTaskSetId = new HashMap[Long, String]
+ private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
val taskIdToExecutorId = new HashMap[Long, String]
@volatile private var hasReceivedTask = false
@@ -162,7 +162,17 @@ private[spark] class TaskSchedulerImpl(
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
- activeTaskSets(taskSet.id) = manager
+ val stage = taskSet.stageId
+ val stageTaskSets =
+ taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+ stageTaskSets(taskSet.stageAttemptId) = manager
+ val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
+ ts.taskSet != taskSet && !ts.isZombie
+ }
+ if (conflictingTaskSet) {
+ throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
+ s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
+ }
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
if (!isLocal && !hasReceivedTask) {
@@ -192,19 +202,21 @@ private[spark] class TaskSchedulerImpl(
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
- activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
- // There are two possible cases here:
- // 1. The task set manager has been created and some tasks have been scheduled.
- // In this case, send a kill signal to the executors to kill the task and then abort
- // the stage.
- // 2. The task set manager has been created but no tasks has been scheduled. In this case,
- // simply abort the stage.
- tsm.runningTasksSet.foreach { tid =>
- val execId = taskIdToExecutorId(tid)
- backend.killTask(tid, execId, interruptThread)
+ taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
+ attempts.foreach { case (_, tsm) =>
+ // There are two possible cases here:
+ // 1. The task set manager has been created and some tasks have been scheduled.
+ // In this case, send a kill signal to the executors to kill the task and then abort
+ // the stage.
+ // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+ // simply abort the stage.
+ tsm.runningTasksSet.foreach { tid =>
+ val execId = taskIdToExecutorId(tid)
+ backend.killTask(tid, execId, interruptThread)
+ }
+ tsm.abort("Stage %s cancelled".format(stageId))
+ logInfo("Stage %d was cancelled".format(stageId))
}
- tsm.abort("Stage %s cancelled".format(stageId))
- logInfo("Stage %d was cancelled".format(stageId))
}
}
@@ -214,7 +226,12 @@ private[spark] class TaskSchedulerImpl(
* cleaned up.
*/
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
- activeTaskSets -= manager.taskSet.id
+ taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+ taskSetsForStage -= manager.taskSet.stageAttemptId
+ if (taskSetsForStage.isEmpty) {
+ taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
+ }
+ }
manager.parent.removeSchedulable(manager)
logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
.format(manager.taskSet.id, manager.parent.name))
@@ -235,7 +252,7 @@ private[spark] class TaskSchedulerImpl(
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
- taskIdToTaskSetId(tid) = taskSet.taskSet.id
+ taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
@@ -319,26 +336,24 @@ private[spark] class TaskSchedulerImpl(
failedExecutor = Some(execId)
}
}
- taskIdToTaskSetId.get(tid) match {
- case Some(taskSetId) =>
+ taskIdToTaskSetManager.get(tid) match {
+ case Some(taskSet) =>
if (TaskState.isFinished(state)) {
- taskIdToTaskSetId.remove(tid)
+ taskIdToTaskSetManager.remove(tid)
taskIdToExecutorId.remove(tid)
}
- activeTaskSets.get(taskSetId).foreach { taskSet =>
- if (state == TaskState.FINISHED) {
- taskSet.removeRunningTask(tid)
- taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
- } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
- taskSet.removeRunningTask(tid)
- taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
- }
+ if (state == TaskState.FINISHED) {
+ taskSet.removeRunningTask(tid)
+ taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+ } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+ taskSet.removeRunningTask(tid)
+ taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
}
case None =>
logError(
("Ignoring update with state %s for TID %s because its task set is gone (this is " +
- "likely the result of receiving duplicate task finished status updates)")
- .format(state, tid))
+ "likely the result of receiving duplicate task finished status updates)")
+ .format(state, tid))
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
@@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl(
val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
taskMetrics.flatMap { case (id, metrics) =>
- taskIdToTaskSetId.get(id)
- .flatMap(activeTaskSets.get)
- .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+ taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+ (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+ }
}
}
dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -397,9 +412,12 @@ private[spark] class TaskSchedulerImpl(
def error(message: String) {
synchronized {
- if (activeTaskSets.nonEmpty) {
+ if (taskSetsByStageIdAndAttempt.nonEmpty) {
// Have each task set throw a SparkException with the error
- for ((taskSetId, manager) <- activeTaskSets) {
+ for {
+ attempts <- taskSetsByStageIdAndAttempt.values
+ manager <- attempts.values
+ } {
try {
manager.abort(message)
} catch {
@@ -520,6 +538,17 @@ private[spark] class TaskSchedulerImpl(
override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
+ private[scheduler] def taskSetManagerForAttempt(
+ stageId: Int,
+ stageAttemptId: Int): Option[TaskSetManager] = {
+ for {
+ attempts <- taskSetsByStageIdAndAttempt.get(stageId)
+ manager <- attempts.get(stageAttemptId)
+ } yield {
+ manager
+ }
+ }
+
}
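Both taskSetManagerForAttempt and taskSetFinished walk a two-level map, and the cleanup must drop a stage's inner map once its last attempt finishes, or the outer map leaks an entry per completed stage. A standalone sketch of that bookkeeping (AttemptRegistry is a hypothetical name; M stands in for TaskSetManager):

    import scala.collection.mutable

    class AttemptRegistry[M] {
      private val byStageAndAttempt = mutable.HashMap.empty[Int, mutable.HashMap[Int, M]]

      def register(stageId: Int, attemptId: Int, mgr: M): Unit = {
        val attempts = byStageAndAttempt.getOrElseUpdate(stageId, mutable.HashMap.empty)
        attempts(attemptId) = mgr
      }

      // Mirrors taskSetManagerForAttempt: a for-comprehension over both map levels.
      def forAttempt(stageId: Int, attemptId: Int): Option[M] =
        for {
          attempts <- byStageAndAttempt.get(stageId)
          mgr <- attempts.get(attemptId)
        } yield mgr

      // Mirrors taskSetFinished: drop the stage entry once its last attempt is gone.
      def finished(stageId: Int, attemptId: Int): Unit =
        byStageAndAttempt.get(stageId).foreach { attempts =>
          attempts -= attemptId
          if (attempts.isEmpty) byStageAndAttempt -= stageId
        }
    }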
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
index c3ad325156..be8526ba9b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -26,10 +26,10 @@ import java.util.Properties
private[spark] class TaskSet(
val tasks: Array[Task[_]],
val stageId: Int,
- val attempt: Int,
+ val stageAttemptId: Int,
val priority: Int,
val properties: Properties) {
- val id: String = stageId + "." + attempt
+ val id: String = stageId + "." + stageAttemptId
override def toString: String = "TaskSet " + id
}
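The rename from attempt to stageAttemptId leaves the human-readable task set id unchanged; it is still the stage id and attempt joined by a dot. Sketch of the derivation with plain values:

    val stageId = 3
    val stageAttemptId = 1
    val id: String = stageId + "." + stageAttemptId  // "3.1"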
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 0e3215d6e9..f14c603ac6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -191,15 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
for (task <- tasks.flatten) {
val serializedTask = ser.serialize(task)
if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
- val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
- scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
+ scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
"spark.akka.frameSize or using broadcast variables for large values."
msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
AkkaUtils.reservedSizeBytes)
- taskSet.abort(msg)
+ taskSetMgr.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 4f2b0fa162..86728cb2b6 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -101,9 +101,15 @@ class DAGSchedulerSuite
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
val sparkListener = new SparkListener() {
+ val submittedStageInfos = new HashSet[StageInfo]
val successfulStages = new HashSet[Int]
val failedStages = new ArrayBuffer[Int]
val stageByOrderOfExecution = new ArrayBuffer[Int]
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
+ submittedStageInfos += stageSubmitted.stageInfo
+ }
+
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
val stageInfo = stageCompleted.stageInfo
stageByOrderOfExecution += stageInfo.stageId
@@ -150,6 +156,7 @@ class DAGSchedulerSuite
// Enable local execution for this test
val conf = new SparkConf().set("spark.localExecution.enabled", "true")
sc = new SparkContext("local", "DAGSchedulerSuite", conf)
+ sparkListener.submittedStageInfos.clear()
sparkListener.successfulStages.clear()
sparkListener.failedStages.clear()
failure = null
@@ -547,6 +554,140 @@ class DAGSchedulerSuite
assert(sparkListener.failedStages.size == 1)
}
+ /**
+ * This tests the case where another FetchFailed comes in while the map stage is getting
+ * re-run.
+ */
+ test("late fetch failures don't cause multiple concurrent attempts for the same map stage") {
+ val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+ submit(reduceRdd, Array(0, 1))
+
+ val mapStageId = 0
+ def countSubmittedMapStageAttempts(): Int = {
+ sparkListener.submittedStageInfos.count(_.stageId == mapStageId)
+ }
+
+ // The map stage should have been submitted.
+ assert(countSubmittedMapStageAttempts() === 1)
+
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+ // The MapOutputTracker should know about both map output locations.
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
+ Array("hostA", "hostB"))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 1).map(_._1.host) ===
+ Array("hostA", "hostB"))
+
+ // The first result task fails, with a fetch failure for the output from the first mapper.
+ runEvent(CompletionEvent(
+ taskSets(1).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ null,
+ Map[Long, Any](),
+ createFakeTaskInfo(),
+ null))
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(sparkListener.failedStages.contains(1))
+
+ // Trigger resubmission of the failed map stage.
+ runEvent(ResubmitFailedStages)
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+
+ // Another attempt for the map stage should have been submitted, resulting in 2 total attempts.
+ assert(countSubmittedMapStageAttempts() === 2)
+
+ // The second ResultTask fails, with a fetch failure for the output from the second mapper.
+ runEvent(CompletionEvent(
+ taskSets(1).tasks(1),
+ FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+ null,
+ Map[Long, Any](),
+ createFakeTaskInfo(),
+ null))
+
+ // Another ResubmitFailedStages event should not result in another attempt for the map
+ // stage being run concurrently.
+ // NOTE: the actual ResubmitFailedStages may get called at any time during this, but it
+ // shouldn't affect anything -- our calling it just makes *SURE* it gets called between the
+ // desired event and our check.
+ runEvent(ResubmitFailedStages)
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(countSubmittedMapStageAttempts() === 2)
+
+ }
+
+ /**
+ * This tests the case where a late FetchFailed comes in after the map stage has finished getting
+ * retried and a new reduce stage starts running.
+ */
+ test("extremely late fetch failures don't cause multiple concurrent attempts for " +
+ "the same stage") {
+ val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+ val shuffleId = shuffleDep.shuffleId
+ val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+ submit(reduceRdd, Array(0, 1))
+
+ def countSubmittedReduceStageAttempts(): Int = {
+ sparkListener.submittedStageInfos.count(_.stageId == 1)
+ }
+ def countSubmittedMapStageAttempts(): Int = {
+ sparkListener.submittedStageInfos.count(_.stageId == 0)
+ }
+
+ // The map stage should have been submitted.
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(countSubmittedMapStageAttempts() === 1)
+
+ // Complete the map stage.
+ complete(taskSets(0), Seq(
+ (Success, makeMapStatus("hostA", 2)),
+ (Success, makeMapStatus("hostB", 2))))
+
+ // The reduce stage should have been submitted.
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(countSubmittedReduceStageAttempts() === 1)
+
+ // The first result task fails, with a fetch failure for the output from the first mapper.
+ runEvent(CompletionEvent(
+ taskSets(1).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+ null,
+ Map[Long, Any](),
+ createFakeTaskInfo(),
+ null))
+
+ // Trigger resubmission of the failed map stage and finish the re-started map task.
+ runEvent(ResubmitFailedStages)
+ complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+
+ // Because the map stage finished, another attempt for the reduce stage should have been
+ // submitted, resulting in 2 total attempts for each of the map and reduce stages.
+ sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
+ assert(countSubmittedMapStageAttempts() === 2)
+ assert(countSubmittedReduceStageAttempts() === 2)
+
+ // A late FetchFailed arrives from the second task in the original reduce stage.
+ runEvent(CompletionEvent(
+ taskSets(1).tasks(1),
+ FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"),
+ null,
+ Map[Long, Any](),
+ createFakeTaskInfo(),
+ null))
+
+ // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because
+ // the FetchFailed should have been ignored
+ runEvent(ResubmitFailedStages)
+
+ // The FetchFailed from the original reduce stage should be ignored.
+ assert(countSubmittedMapStageAttempts() === 2)
+ }
+
test("ignore late map task completions") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 0a7cb69416..b3ca150195 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import org.apache.spark.TaskContext
-class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) {
override def runTask(context: TaskContext): Int = 0
override def preferredLocations: Seq[TaskLocation] = prefLocs
@@ -31,12 +31,16 @@ object FakeTask {
* locations for each task (given as varargs) if this sequence is not empty.
*/
def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+ createTaskSet(numTasks, 0, prefLocs: _*)
+ }
+
+ def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
if (prefLocs.size != 0 && prefLocs.size != numTasks) {
throw new IllegalArgumentException("Wrong number of task locations")
}
val tasks = Array.tabulate[Task[_]](numTasks) { i =>
new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
}
- new TaskSet(tasks, 0, 0, 0, null)
+ new TaskSet(tasks, 0, stageAttemptId, 0, null)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
index 9b92f8de56..383855caef 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala
@@ -25,7 +25,7 @@ import org.apache.spark.TaskContext
* A Task implementation that fails to serialize.
*/
private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int)
- extends Task[Array[Byte]](stageId, 0) {
+ extends Task[Array[Byte]](stageId, 0, 0) {
override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte]
override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 7c1adc1aef..b9b0eccb0d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -41,8 +41,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
- val task = new ResultTask[String, String](
- 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
+ val task = new ResultTask[String, String](0, 0,
+ sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
intercept[RuntimeException] {
task.run(0, 0)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index a6d5232feb..c2edd4c317 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -33,7 +33,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
val taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
- val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+ new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
@@ -67,7 +67,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
val taskScheduler = new TaskSchedulerImpl(sc)
taskScheduler.initialize(new FakeSchedulerBackend)
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
- val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+ new DAGScheduler(sc, taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
override def executorAdded(execId: String, host: String) {}
}
@@ -128,4 +128,113 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with L
assert(taskDescriptions.map(_.executorId) === Seq("executor0"))
}
+ test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") {
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ taskScheduler.initialize(new FakeSchedulerBackend)
+ // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+ val dagScheduler = new DAGScheduler(sc, taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+ override def executorAdded(execId: String, host: String) {}
+ }
+ taskScheduler.setDAGScheduler(dagScheduler)
+ val attempt1 = FakeTask.createTaskSet(1, 0)
+ val attempt2 = FakeTask.createTaskSet(1, 1)
+ taskScheduler.submitTasks(attempt1)
+ intercept[IllegalStateException] { taskScheduler.submitTasks(attempt2) }
+
+ // OK to submit multiple if previous attempts are all zombie
+ taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
+ .get.isZombie = true
+ taskScheduler.submitTasks(attempt2)
+ val attempt3 = FakeTask.createTaskSet(1, 2)
+ intercept[IllegalStateException] { taskScheduler.submitTasks(attempt3) }
+ taskScheduler.taskSetManagerForAttempt(attempt2.stageId, attempt2.stageAttemptId)
+ .get.isZombie = true
+ taskScheduler.submitTasks(attempt3)
+ }
+
+ test("don't schedule more tasks after a taskset is zombie") {
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ taskScheduler.initialize(new FakeSchedulerBackend)
+ // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+ new DAGScheduler(sc, taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+ override def executorAdded(execId: String, host: String) {}
+ }
+
+ val numFreeCores = 1
+ val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
+ val attempt1 = FakeTask.createTaskSet(10)
+
+ // submit attempt 1, offer some resources, some tasks get scheduled
+ taskScheduler.submitTasks(attempt1)
+ val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(1 === taskDescriptions.length)
+
+ // now mark attempt 1 as a zombie
+ taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId)
+ .get.isZombie = true
+
+ // don't schedule anything on another resource offer
+ val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(0 === taskDescriptions2.length)
+
+ // if we schedule another attempt for the same stage, it should get scheduled
+ val attempt2 = FakeTask.createTaskSet(10, 1)
+
+ // submit attempt 2, offer some resources, some tasks get scheduled
+ taskScheduler.submitTasks(attempt2)
+ val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(1 === taskDescriptions3.length)
+ val mgr = taskScheduler.taskIdToTaskSetManager.get(taskDescriptions3(0).taskId).get
+ assert(mgr.taskSet.stageAttemptId === 1)
+ }
+
+ test("if a zombie attempt finishes, continue scheduling tasks for non-zombie attempts") {
+ sc = new SparkContext("local", "TaskSchedulerImplSuite")
+ val taskScheduler = new TaskSchedulerImpl(sc)
+ taskScheduler.initialize(new FakeSchedulerBackend)
+ // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks.
+ new DAGScheduler(sc, taskScheduler) {
+ override def taskStarted(task: Task[_], taskInfo: TaskInfo) {}
+ override def executorAdded(execId: String, host: String) {}
+ }
+
+ val numFreeCores = 10
+ val workerOffers = Seq(new WorkerOffer("executor0", "host0", numFreeCores))
+ val attempt1 = FakeTask.createTaskSet(10)
+
+ // submit attempt 1, offer some resources, some tasks get scheduled
+ taskScheduler.submitTasks(attempt1)
+ val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(10 === taskDescriptions.length)
+
+ // now mark attempt 1 as a zombie
+ val mgr1 = taskScheduler.taskSetManagerForAttempt(attempt1.stageId, attempt1.stageAttemptId).get
+ mgr1.isZombie = true
+
+ // don't schedule anything on another resource offer
+ val taskDescriptions2 = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(0 === taskDescriptions2.length)
+
+ // submit attempt 2
+ val attempt2 = FakeTask.createTaskSet(10, 1)
+ taskScheduler.submitTasks(attempt2)
+
+ // attempt 1 finished (this can happen even if it was marked zombie earlier -- all tasks were
+ // already submitted, and then they finish)
+ taskScheduler.taskSetFinished(mgr1)
+
+ // now with another resource offer, we should still schedule all the tasks in attempt2
+ val taskDescriptions3 = taskScheduler.resourceOffers(workerOffers).flatten
+ assert(10 === taskDescriptions3.length)
+
+ taskDescriptions3.foreach { task =>
+ val mgr = taskScheduler.taskIdToTaskSetManager.get(task.taskId).get
+ assert(mgr.taskSet.stageAttemptId === 1)
+ }
+ }
+
}
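These tests lean on the existing zombie semantics of TaskSetManager: a zombie attempt stays registered so its in-flight tasks can still complete (and taskSetFinished can still be called for it), but it is never offered new work. A toy model of that scheduling rule (SketchManager is a hypothetical name; task ids stand in for task descriptions):

    class SketchManager(val stageAttemptId: Int, numTasks: Int) {
      var isZombie: Boolean = false
      private var nextTask = 0

      // A zombie attempt never accepts new work; a live one hands out task ids
      // until it runs out of pending tasks.
      def resourceOffer(): Option[Int] =
        if (isZombie || nextTask >= numTasks) None
        else { nextTask += 1; Some(nextTask - 1) }
    }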
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index cdae0d83d0..3abb99c4b2 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -136,7 +136,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
/**
* A Task implementation that results in a large serialized task.
*/
-class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) {
+class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) {
val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024)
val random = new Random(0)
random.nextBytes(randomBuffer)