diff options
Diffstat (limited to 'core/src/test/scala')
5 files changed, 62 insertions, 3 deletions
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 9b7b945bf3..1d7c8f4a61 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -22,6 +22,8 @@ import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.mockito.Mockito.{mock, verify} + import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.util.{AccumulatorV2, ManualClock} @@ -789,6 +791,54 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(TaskLocation("executor_host1_3") === ExecutorCacheTaskLocation("host1", "3")) } + test("Kill other task attempts when one attempt belonging to the same task succeeds") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => + task.metrics.internalAccums + } + // Offer resources for 4 tasks to start + for ((k, v) <- List( + "exec1" -> "host1", + "exec1" -> "host1", + "exec2" -> "host2", + "exec2" -> "host2")) { + val taskOption = manager.resourceOffer(k, v, NO_PREF) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === k) + } + assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + // Complete the 3 tasks and leave 1 task in running + for (id <- Set(0, 1, 2)) { + manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) + assert(sched.endedTasks(id) === Success) + } + + assert(manager.checkSpeculatableTasks(0)) + // Offer resource to start the speculative attempt for the running task + val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption5.isDefined) + val task5 = taskOption5.get + assert(task5.index === 3) + assert(task5.taskId === 4) + assert(task5.executorId === "exec1") + assert(task5.attemptNumber === 1) + sched.backend = mock(classOf[SchedulerBackend]) + // Complete the speculative attempt for the running task + manager.handleSuccessfulTask(4, createTaskResult(3, accumUpdatesByTask(3))) + // Verify that it kills other running attempt + verify(sched.backend).killTask(3, "exec2", true) + // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be + // killed, so the FakeTaskScheduler is only told about the successful completion + // of the speculated task. + assert(sched.endedTasks(3) === Success) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala index b83ffa3282..6d726d3d59 100644 --- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala @@ -83,7 +83,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext { val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false) jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo)) jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo)) - taskInfo.markSuccessful() + taskInfo.markFinished(TaskState.FINISHED) val taskMetrics = TaskMetrics.empty taskMetrics.incPeakExecutionMemory(peakExecutionMemory) jobListener.onTaskEnd( diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 58beaf103c..6335d905c0 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -110,7 +110,7 @@ class UIUtilsSuite extends SparkFunSuite { } test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { - val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label == "div") + val generated = makeProgressBar(2, 3, 0, 0, 0, 4).head.child.filter(_.label == "div") val expected = Seq( <div class="bar bar-completed" style="width: 75.0%"></div>, <div class="bar bar-running" style="width: 25.0%"></div> diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 1fa9b28edf..edab727fc4 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -243,7 +243,6 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with new FetchFailed(null, 0, 0, 0, "ignored"), ExceptionFailure("Exception", "description", null, null, None), TaskResultLost, - TaskKilled, ExecutorLostFailure("0", true, Some("Induced failure")), UnknownReason) var failCount = 0 @@ -255,6 +254,11 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount) } + // Make sure killed tasks are accounted for correctly. + listener.onTaskEnd( + SparkListenerTaskEnd(task.stageId, 0, taskType, TaskKilled, taskInfo, metrics)) + assert(listener.stageIdToData((task.stageId, 0)).numKilledTasks === 1) + // Make sure we count success as success. listener.onTaskEnd( SparkListenerTaskEnd(task.stageId, 1, taskType, Success, taskInfo, metrics)) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 6fda7378e6..0a8bbba6c5 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -966,6 +966,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Getting Result Time": 0, | "Finish Time": 0, | "Failed": false, + | "Killed": false, | "Accumulables": [ | { | "ID": 1, @@ -1012,6 +1013,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Getting Result Time": 0, | "Finish Time": 0, | "Failed": false, + | "Killed": false, | "Accumulables": [ | { | "ID": 1, @@ -1064,6 +1066,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Getting Result Time": 0, | "Finish Time": 0, | "Failed": false, + | "Killed": false, | "Accumulables": [ | { | "ID": 1, @@ -1161,6 +1164,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Getting Result Time": 0, | "Finish Time": 0, | "Failed": false, + | "Killed": false, | "Accumulables": [ | { | "ID": 1, @@ -1258,6 +1262,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Getting Result Time": 0, | "Finish Time": 0, | "Failed": false, + | "Killed": false, | "Accumulables": [ | { | "ID": 1, |