aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKay Ousterhout <kayousterhout@gmail.com>2014-02-06 16:10:48 -0800
committerPatrick Wendell <pwendell@gmail.com>2014-02-06 16:10:48 -0800
commit18ad59e2c6b7bd009e8ba5ebf8fcf99630863029 (patch)
treee8eea9263dc23ab4b4508e0330425cccab0333ef
parent446403b63763157831ddbf6209044efc3cc7bf7c (diff)
downloadspark-18ad59e2c6b7bd009e8ba5ebf8fcf99630863029.tar.gz
spark-18ad59e2c6b7bd009e8ba5ebf8fcf99630863029.tar.bz2
spark-18ad59e2c6b7bd009e8ba5ebf8fcf99630863029.zip
Merge pull request #321 from kayousterhout/ui_kill_fix. Closes #321.
Inform DAG scheduler about all started/finished tasks. Previously, the DAG scheduler was not always informed when tasks started and finished. The simplest example here is for speculated tasks: the DAGScheduler was only told about the first attempt of a task, meaning that SparkListeners were also not told about multiple task attempts, so users can't see what's going on with speculation in the UI. The DAGScheduler also wasn't always told about finished tasks, so in the UI, some tasks will never be shown as finished (this occurs, for example, if a task set gets killed). The other problem is that the fairness accounting was wrong -- the number of running tasks in a pool was decreased when a task set was considered done, even if all of its tasks hadn't yet finished. Author: Kay Ousterhout <kayousterhout@gmail.com> == Merge branch commits == commit c8d547d0f7a17f5a193bef05f5872b9f475675c5 Author: Kay Ousterhout <kayousterhout@gmail.com> Date: Wed Jan 15 16:47:33 2014 -0800 Addressed Reynold's review comments. Always use a TaskEndReason (remove the option), and explicitly signal when we don't know the reason. Also, always tell DAGScheduler (and associated listeners) about started tasks, even when they're speculated. commit 3fee1e2e3c06b975ff7f95d595448f38cce97a04 Author: Kay Ousterhout <kayousterhout@gmail.com> Date: Wed Jan 8 22:58:13 2014 -0800 Fixed broken test and improved logging commit ff12fcaa2567c5d02b75a1d5db35687225bcd46f Author: Kay Ousterhout <kayousterhout@gmail.com> Date: Sun Dec 29 21:08:20 2013 -0800 Inform DAG scheduler about all finished tasks. Previously, the DAG scheduler was not always informed when tasks finished. For example, when a task set was aborted, the DAG scheduler was never told when the tasks in that task set finished. The DAG scheduler was also never told about the completion of speculated tasks. This led to confusion with SparkListeners because information about the completion of those tasks was never passed on to the listeners (so in the UI, for example, some tasks will never be shown as finished). The other problem is that the fairness accounting was wrong -- the number of running tasks in a pool was decreased when a task set was considered done, even if all of its tasks hadn't yet finished.
-rw-r--r--core/src/main/scala/org/apache/spark/TaskEndReason.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala46
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala193
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala12
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala41
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala4
9 files changed, 183 insertions, 144 deletions
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index faf6dcd618..3fd6f5eb47 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -53,3 +53,16 @@ private[spark] case class ExceptionFailure(
private[spark] case object TaskResultLost extends TaskEndReason
private[spark] case object TaskKilled extends TaskEndReason
+
+/**
+ * The task failed because the executor that it was running on was lost. This may happen because
+ * the task crashed the JVM.
+ */
+private[spark] case object ExecutorLostFailure extends TaskEndReason
+
+/**
+ * We don't know why the task ended -- for example, because of a ClassNotFound exception when
+ * deserializing the task result.
+ */
+private[spark] case object UnknownReason extends TaskEndReason
+
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 237cbf4c0c..821241508e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -954,8 +954,8 @@ class DAGScheduler(
// Do nothing here; the TaskScheduler handles these failures and resubmits the task.
case other =>
- // Unrecognized failure - abort all jobs depending on this stage
- abortStage(stageIdToStage(task.stageId), task + " failed: " + other)
+ // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
+ // will abort the job.
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index e9f2198a00..c4d1ad5733 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -21,6 +21,12 @@ import scala.collection._
import org.apache.spark.executor.TaskMetrics
+/**
+ * Stores information about a stage to pass from the scheduler to SparkListeners.
+ *
+ * taskInfos stores the metrics for all tasks that have completed, including redundant, speculated
+ * tasks.
+ */
class StageInfo(
stage: Stage,
val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 35e9544718..bdec08e968 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -57,7 +57,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
* between when the task ended and when we tried to fetch the result, or if the
* block manager had to flush the result. */
scheduler.handleFailedTask(
- taskSetManager, tid, TaskState.FINISHED, Some(TaskResultLost))
+ taskSetManager, tid, TaskState.FINISHED, TaskResultLost)
return
}
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
@@ -80,13 +80,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
serializedData: ByteBuffer) {
- var reason: Option[TaskEndReason] = None
+ var reason : TaskEndReason = UnknownReason
getTaskResultExecutor.execute(new Runnable {
override def run() {
try {
if (serializedData != null && serializedData.limit() > 0) {
- reason = Some(serializer.get().deserialize[TaskEndReason](
- serializedData, getClass.getClassLoader))
+ reason = serializer.get().deserialize[TaskEndReason](
+ serializedData, getClass.getClassLoader)
}
} catch {
case cnd: ClassNotFoundException =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 83ba584015..5b525155e9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -67,7 +67,6 @@ private[spark] class TaskSchedulerImpl(
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
- val taskSetTaskIds = new HashMap[String, HashSet[Long]]
@volatile private var hasReceivedTask = false
@volatile private var hasLaunchedTask = false
@@ -142,7 +141,6 @@ private[spark] class TaskSchedulerImpl(
val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
- taskSetTaskIds(taskSet.id) = new HashSet[Long]()
if (!isLocal && !hasReceivedTask) {
starvationTimer.scheduleAtFixedRate(new TimerTask() {
@@ -171,31 +169,25 @@ private[spark] class TaskSchedulerImpl(
// the stage.
// 2. The task set manager has been created but no tasks has been scheduled. In this case,
// simply abort the stage.
- val taskIds = taskSetTaskIds(tsm.taskSet.id)
- if (taskIds.size > 0) {
- taskIds.foreach { tid =>
- val execId = taskIdToExecutorId(tid)
- backend.killTask(tid, execId)
- }
+ tsm.runningTasksSet.foreach { tid =>
+ val execId = taskIdToExecutorId(tid)
+ backend.killTask(tid, execId)
}
+ tsm.abort("Stage %s cancelled".format(stageId))
logInfo("Stage %d was cancelled".format(stageId))
- tsm.removeAllRunningTasks()
- taskSetFinished(tsm)
}
}
+ /**
+ * Called to indicate that all task attempts (including speculated tasks) associated with the
+ * given TaskSetManager have completed, so state associated with the TaskSetManager should be
+ * cleaned up.
+ */
def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
- // Check to see if the given task set has been removed. This is possible in the case of
- // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has
- // more than one running tasks).
- if (activeTaskSets.contains(manager.taskSet.id)) {
- activeTaskSets -= manager.taskSet.id
- manager.parent.removeSchedulable(manager)
- logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
- taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
- taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
- taskSetTaskIds.remove(manager.taskSet.id)
- }
+ activeTaskSets -= manager.taskSet.id
+ manager.parent.removeSchedulable(manager)
+ logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
+ .format(manager.taskSet.id, manager.parent.name))
}
/**
@@ -237,7 +229,6 @@ private[spark] class TaskSchedulerImpl(
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetId(tid) = taskSet.taskSet.id
- taskSetTaskIds(taskSet.taskSet.id) += tid
taskIdToExecutorId(tid) = execId
activeExecutorIds += execId
executorsByHost(host) += execId
@@ -270,9 +261,6 @@ private[spark] class TaskSchedulerImpl(
case Some(taskSetId) =>
if (TaskState.isFinished(state)) {
taskIdToTaskSetId.remove(tid)
- if (taskSetTaskIds.contains(taskSetId)) {
- taskSetTaskIds(taskSetId) -= tid
- }
taskIdToExecutorId.remove(tid)
}
activeTaskSets.get(taskSetId).foreach { taskSet =>
@@ -285,7 +273,9 @@ private[spark] class TaskSchedulerImpl(
}
}
case None =>
- logInfo("Ignoring update with state %s from TID %s because its task set is gone"
+ logError(
+ ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
+ "likely the result of receiving duplicate task finished status updates)")
.format(state, tid))
}
} catch {
@@ -314,9 +304,9 @@ private[spark] class TaskSchedulerImpl(
taskSetManager: TaskSetManager,
tid: Long,
taskState: TaskState,
- reason: Option[TaskEndReason]) = synchronized {
+ reason: TaskEndReason) = synchronized {
taskSetManager.handleFailedTask(tid, taskState, reason)
- if (taskState != TaskState.KILLED) {
+ if (!taskSetManager.isZombie && taskState != TaskState.KILLED) {
// Need to revive offers again now that the task set manager state has been updated to
// reflect failed tasks that need to be re-run.
backend.reviveOffers()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 777f31dc5e..3f0ee7a6d4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -26,9 +26,10 @@ import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
-import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
- Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
+import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted,
+ SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
+import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.{Clock, SystemClock}
@@ -82,8 +83,16 @@ private[spark] class TaskSetManager(
var name = "TaskSet_"+taskSet.stageId.toString
var parent: Pool = null
- var runningTasks = 0
- private val runningTasksSet = new HashSet[Long]
+ val runningTasksSet = new HashSet[Long]
+ override def runningTasks = runningTasksSet.size
+
+ // True once no more tasks should be launched for this task set manager. TaskSetManagers enter
+ // the zombie state once at least one attempt of each task has completed successfully, or if the
+ // task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie
+ // state until all tasks have finished running; we keep TaskSetManagers that are in the zombie
+ // state in order to continue to track and account for the running tasks.
+ // TODO: We should kill any running task attempts when the task set manager becomes a zombie.
+ var isZombie = false
// Set of pending tasks for each executor. These collections are actually
// treated as stacks, in which new tasks are added to the end of the
@@ -345,7 +354,7 @@ private[spark] class TaskSetManager(
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
{
- if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
+ if (!isZombie && availableCpus >= CPUS_PER_TASK) {
val curTime = clock.getTime()
var allowedLocality = getAllowedLocalityLevel(curTime)
@@ -380,8 +389,7 @@ private[spark] class TaskSetManager(
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
- if (taskAttempts(index).size == 1)
- taskStarted(task,info)
+ sched.dagScheduler.taskStarted(task, info)
return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
}
case _ =>
@@ -390,6 +398,12 @@ private[spark] class TaskSetManager(
None
}
+ private def maybeFinishTaskSet() {
+ if (isZombie && runningTasks == 0) {
+ sched.taskSetFinished(this)
+ }
+ }
+
/**
* Get the level we can launch tasks according to delay scheduling, based on current wait time.
*/
@@ -418,10 +432,6 @@ private[spark] class TaskSetManager(
index
}
- private def taskStarted(task: Task[_], info: TaskInfo) {
- sched.dagScheduler.taskStarted(task, info)
- }
-
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
@@ -436,123 +446,116 @@ private[spark] class TaskSetManager(
val index = info.index
info.markSuccessful()
removeRunningTask(tid)
+ sched.dagScheduler.taskEnded(
+ tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
if (!successful(index)) {
tasksSuccessful += 1
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
tid, info.duration, info.host, tasksSuccessful, numTasks))
- sched.dagScheduler.taskEnded(
- tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
-
// Mark successful and stop if all the tasks have succeeded.
successful(index) = true
if (tasksSuccessful == numTasks) {
- sched.taskSetFinished(this)
+ isZombie = true
}
} else {
logInfo("Ignorning task-finished event for TID " + tid + " because task " +
index + " has already completed successfully")
}
+ maybeFinishTaskSet()
}
/**
* Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
* DAG Scheduler.
*/
- def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
+ def handleFailedTask(tid: Long, state: TaskState, reason: TaskEndReason) {
val info = taskInfos(tid)
if (info.failed) {
return
}
removeRunningTask(tid)
- val index = info.index
info.markFailed()
- var failureReason = "unknown"
- if (!successful(index)) {
+ val index = info.index
+ copiesRunning(index) -= 1
+ if (!isZombie) {
logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
- copiesRunning(index) -= 1
- // Check if the problem is a map output fetch failure. In that case, this
- // task will never succeed on any node, so tell the scheduler about it.
- reason.foreach {
- case fetchFailed: FetchFailed =>
- logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ }
+ var taskMetrics : TaskMetrics = null
+ var failureReason = "unknown"
+ reason match {
+ case fetchFailed: FetchFailed =>
+ logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ if (!successful(index)) {
successful(index) = true
tasksSuccessful += 1
- sched.taskSetFinished(this)
- removeAllRunningTasks()
- return
-
- case TaskKilled =>
- logWarning("Task %d was killed.".format(tid))
- sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
+ }
+ isZombie = true
+
+ case TaskKilled =>
+ logWarning("Task %d was killed.".format(tid))
+
+ case ef: ExceptionFailure =>
+ taskMetrics = ef.metrics.getOrElse(null)
+ if (ef.className == classOf[NotSerializableException].getName()) {
+ // If the task result wasn't serializable, there's no point in trying to re-execute it.
+ logError("Task %s:%s had a not serializable result: %s; not retrying".format(
+ taskSet.id, index, ef.description))
+ abort("Task %s:%s had a not serializable result: %s".format(
+ taskSet.id, index, ef.description))
return
-
- case ef: ExceptionFailure =>
- sched.dagScheduler.taskEnded(
- tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
- if (ef.className == classOf[NotSerializableException].getName()) {
- // If the task result wasn't rerializable, there's no point in trying to re-execute it.
- logError("Task %s:%s had a not serializable result: %s; not retrying".format(
- taskSet.id, index, ef.description))
- abort("Task %s:%s had a not serializable result: %s".format(
- taskSet.id, index, ef.description))
- return
- }
- val key = ef.description
- failureReason = "Exception failure: %s".format(ef.description)
- val now = clock.getTime()
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
+ }
+ val key = ef.description
+ failureReason = "Exception failure: %s".format(ef.description)
+ val now = clock.getTime()
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
}
- }
- if (printFull) {
- val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logWarning("Loss was due to %s\n%s\n%s".format(
- ef.className, ef.description, locs.mkString("\n")))
} else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ recentExceptions(key) = (0, now)
+ (true, 0)
}
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logWarning("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
- case TaskResultLost =>
- failureReason = "Lost result for TID %s on host %s".format(tid, info.host)
- logWarning(failureReason)
- sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+ case TaskResultLost =>
+ failureReason = "Lost result for TID %s on host %s".format(tid, info.host)
+ logWarning(failureReason)
- case _ => {}
- }
- // On non-fetch failures, re-enqueue the task as pending for a max number of retries
- addPendingTask(index)
- if (state != TaskState.KILLED) {
- numFailures(index) += 1
- if (numFailures(index) >= maxTaskFailures) {
- logError("Task %s:%d failed %d times; aborting job".format(
- taskSet.id, index, maxTaskFailures))
- abort("Task %s:%d failed %d times (most recent failure: %s)".format(
- taskSet.id, index, maxTaskFailures, failureReason))
- }
+ case _ => {}
+ }
+ sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics)
+ addPendingTask(index)
+ if (!isZombie && state != TaskState.KILLED) {
+ numFailures(index) += 1
+ if (numFailures(index) >= maxTaskFailures) {
+ logError("Task %s:%d failed %d times; aborting job".format(
+ taskSet.id, index, maxTaskFailures))
+ abort("Task %s:%d failed %d times (most recent failure: %s)".format(
+ taskSet.id, index, maxTaskFailures, failureReason))
+ return
}
- } else {
- logInfo("Ignoring task-lost event for TID " + tid +
- " because task " + index + " is already finished")
}
+ maybeFinishTaskSet()
}
def abort(message: String) {
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.dagScheduler.taskSetFailed(taskSet, message)
- removeAllRunningTasks()
- sched.taskSetFinished(this)
+ isZombie = true
+ maybeFinishTaskSet()
}
/** If the given task ID is not in the set of running tasks, adds it.
@@ -563,7 +566,6 @@ private[spark] class TaskSetManager(
if (runningTasksSet.add(tid) && parent != null) {
parent.increaseRunningTasks(1)
}
- runningTasks = runningTasksSet.size
}
/** If the given task ID is in the set of running tasks, removes it. */
@@ -571,16 +573,6 @@ private[spark] class TaskSetManager(
if (runningTasksSet.remove(tid) && parent != null) {
parent.decreaseRunningTasks(1)
}
- runningTasks = runningTasksSet.size
- }
-
- private[scheduler] def removeAllRunningTasks() {
- val numRunningTasks = runningTasksSet.size
- runningTasksSet.clear()
- if (parent != null) {
- parent.decreaseRunningTasks(numRunningTasks)
- }
- runningTasks = 0
}
override def getSchedulableByName(name: String): Schedulable = {
@@ -629,7 +621,7 @@ private[spark] class TaskSetManager(
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- handleFailedTask(tid, TaskState.FAILED, None)
+ handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure)
}
}
@@ -641,8 +633,9 @@ private[spark] class TaskSetManager(
* we don't scan the whole task set. It might also help to make this sorted by launch time.
*/
override def checkSpeculatableTasks(): Boolean = {
- // Can't speculate if we only have one task, or if all tasks have finished.
- if (numTasks == 1 || tasksSuccessful == numTasks) {
+ // Can't speculate if we only have one task, and no need to speculate if the task set is a
+ // zombie.
+ if (isZombie || numTasks == 1) {
return false
}
var foundTasks = false
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
index 235d31709a..98ea4cb561 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
@@ -36,22 +36,24 @@ class FakeTaskSetManager(
parent = null
weight = 1
minShare = 2
- runningTasks = 0
priority = initPriority
stageId = initStageId
name = "TaskSet_"+stageId
override val numTasks = initNumTasks
tasksSuccessful = 0
+ var numRunningTasks = 0
+ override def runningTasks = numRunningTasks
+
def increaseRunningTasks(taskNum: Int) {
- runningTasks += taskNum
+ numRunningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
def decreaseRunningTasks(taskNum: Int) {
- runningTasks -= taskNum
+ numRunningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
}
@@ -77,7 +79,7 @@ class FakeTaskSetManager(
maxLocality: TaskLocality.TaskLocality)
: Option[TaskDescription] =
{
- if (tasksSuccessful + runningTasks < numTasks) {
+ if (tasksSuccessful + numRunningTasks < numTasks) {
increaseRunningTasks(1)
Some(new TaskDescription(0, execId, "task 0:0", 0, null))
} else {
@@ -98,7 +100,7 @@ class FakeTaskSetManager(
}
def abort() {
- decreaseRunningTasks(runningTasks)
+ decreaseRunningTasks(numRunningTasks)
parent.removeSchedulable(this)
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 1a16e438c4..368c5154ea 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -168,6 +168,39 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
assert(listener.endedTasks.contains(TASK_INDEX))
}
+ test("onTaskEnd() should be called for all started tasks, even after job has been killed") {
+ val WAIT_TIMEOUT_MILLIS = 10000
+ val listener = new SaveTaskEvents
+ sc.addSparkListener(listener)
+
+ val numTasks = 10
+ val f = sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.countAsync()
+ // Wait until one task has started (because we want to make sure that any tasks that are started
+ // have corresponding end events sent to the listener).
+ var finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS
+ listener.synchronized {
+ var remainingWait = finishTime - System.currentTimeMillis
+ while (listener.startedTasks.isEmpty && remainingWait > 0) {
+ listener.wait(remainingWait)
+ remainingWait = finishTime - System.currentTimeMillis
+ }
+ assert(!listener.startedTasks.isEmpty)
+ }
+
+ f.cancel()
+
+ // Ensure that onTaskEnd is called for all started tasks.
+ finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS
+ listener.synchronized {
+ var remainingWait = finishTime - System.currentTimeMillis
+ while (listener.endedTasks.size < listener.startedTasks.size && remainingWait > 0) {
+ listener.wait(finishTime - System.currentTimeMillis)
+ remainingWait = finishTime - System.currentTimeMillis
+ }
+ assert(listener.endedTasks.size === listener.startedTasks.size)
+ }
+ }
+
def checkNonZeroAvg(m: Traversable[Long], msg: String) {
assert(m.sum / m.size.toDouble > 0.0, msg)
}
@@ -184,12 +217,14 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
val startedGettingResultTasks = new HashSet[Int]()
val endedTasks = new HashSet[Int]()
- override def onTaskStart(taskStart: SparkListenerTaskStart) {
+ override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
startedTasks += taskStart.taskInfo.index
+ notify()
}
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- endedTasks += taskEnd.taskInfo.index
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
+ endedTasks += taskEnd.taskInfo.index
+ notify()
}
override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ecac2f79a2..de321c45b5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -269,7 +269,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
// Tell it the task has finished but the result was lost.
- manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost))
+ manager.handleFailedTask(0, TaskState.FINISHED, TaskResultLost)
assert(sched.endedTasks(0) === TaskResultLost)
// Re-offer the host -- now we should get task 0 again.
@@ -290,7 +290,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(offerResult.isDefined,
"Expect resource offer on iteration %s to return a task".format(index))
assert(offerResult.get.index === 0)
- manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost))
+ manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
if (index < MAX_TASK_FAILURES) {
assert(!sched.taskSetsFailed.contains(taskSet.id))
} else {