Diffstat (limited to 'core/src/test/scala/org')
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala    | 50
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala                |  2
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala                  |  2
-rw-r--r--  core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala |  6
-rw-r--r--  core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala           |  5
5 files changed, 62 insertions, 3 deletions
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 9b7b945bf3..1d7c8f4a61 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -22,6 +22,8 @@ import java.util.Random
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import org.mockito.Mockito.{mock, verify}
+
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.util.{AccumulatorV2, ManualClock}
@@ -789,6 +791,54 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
assert(TaskLocation("executor_host1_3") === ExecutorCacheTaskLocation("host1", "3"))
}
+ test("Kill other task attempts when one attempt belonging to the same task succeeds") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+ val taskSet = FakeTask.createTaskSet(4)
+ // Set the speculation multiplier to be 0 so speculative tasks are launched immediately
+ sc.conf.set("spark.speculation.multiplier", "0.0")
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
+ val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
+ task.metrics.internalAccums
+ }
+ // Offer resources for 4 tasks to start
+ for ((k, v) <- List(
+ "exec1" -> "host1",
+ "exec1" -> "host1",
+ "exec2" -> "host2",
+ "exec2" -> "host2")) {
+ val taskOption = manager.resourceOffer(k, v, NO_PREF)
+ assert(taskOption.isDefined)
+ val task = taskOption.get
+ assert(task.executorId === k)
+ }
+ assert(sched.startedTasks.toSet === Set(0, 1, 2, 3))
+ // Complete the 3 tasks and leave 1 task in running
+ for (id <- Set(0, 1, 2)) {
+ manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
+ assert(sched.endedTasks(id) === Success)
+ }
+
+ assert(manager.checkSpeculatableTasks(0))
+ // Offer resource to start the speculative attempt for the running task
+ val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF)
+ assert(taskOption5.isDefined)
+ val task5 = taskOption5.get
+ assert(task5.index === 3)
+ assert(task5.taskId === 4)
+ assert(task5.executorId === "exec1")
+ assert(task5.attemptNumber === 1)
+ sched.backend = mock(classOf[SchedulerBackend])
+ // Complete the speculative attempt for the running task
+ manager.handleSuccessfulTask(4, createTaskResult(3, accumUpdatesByTask(3)))
+ // Verify that it kills other running attempt
+ verify(sched.backend).killTask(3, "exec2", true)
+ // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be
+ // killed, so the FakeTaskScheduler is only told about the successful completion
+ // of the speculated task.
+ assert(sched.endedTasks(3) === Success)
+ }
+
private def createTaskResult(
id: Int,
accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {
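Note: the new test swaps the FakeTaskScheduler's backend for a Mockito mock so it can assert that completing the speculative attempt triggers killTask for the still-running original attempt. Below is a minimal, self-contained sketch of that mock/verify pattern; the Backend trait is a hypothetical stand-in for Spark's SchedulerBackend, not the real interface.

```scala
import org.mockito.Mockito.{mock, verify}

object MockVerifySketch {
  // Hypothetical stand-in for org.apache.spark.scheduler.SchedulerBackend.
  trait Backend {
    def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit
  }

  def main(args: Array[String]): Unit = {
    val backend = mock(classOf[Backend])
    // In the test above, TaskSetManager.handleSuccessfulTask is what ends up
    // calling killTask; here we call it directly to keep the sketch short.
    backend.killTask(3L, "exec2", interruptThread = true)
    // verify fails if killTask was never invoked with exactly these arguments.
    verify(backend).killTask(3L, "exec2", interruptThread = true)
  }
}
```

Because a Mockito mock only records invocations, the mocked killTask does nothing, which is why the test comments note that the second copy of the task is never actually killed and the scheduler only sees the successful completion.
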
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
index b83ffa3282..6d726d3d59 100644
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -83,7 +83,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
- taskInfo.markSuccessful()
+ taskInfo.markFinished(TaskState.FINISHED)
val taskMetrics = TaskMetrics.empty
taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
jobListener.onTaskEnd(
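The one-line change above reflects TaskInfo moving from a success-only markSuccessful() to a markFinished(state) call that records how the task ended. The following is a rough, hypothetical sketch of that idea with a simplified TaskInfo-like class; it is illustrative only, not Spark's implementation.

```scala
object TaskStateSketch {
  object TaskState extends Enumeration {
    val RUNNING, FINISHED, FAILED, KILLED = Value
  }

  // Simplified stand-in: the terminal state is passed in explicitly instead of
  // being implied by which mark* method was called.
  class TaskInfoLike {
    var finishTime: Long = 0L
    var failed: Boolean = false
    var killed: Boolean = false

    def markFinished(state: TaskState.Value, time: Long = System.currentTimeMillis): Unit = {
      finishTime = time
      failed = state == TaskState.FAILED
      killed = state == TaskState.KILLED
    }
  }

  def main(args: Array[String]): Unit = {
    val info = new TaskInfoLike
    info.markFinished(TaskState.FINISHED)
    assert(!info.failed && !info.killed && info.finishTime > 0)
  }
}
```
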
diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
index 58beaf103c..6335d905c0 100644
--- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
@@ -110,7 +110,7 @@ class UIUtilsSuite extends SparkFunSuite {
}
test("SPARK-11906: Progress bar should not overflow because of speculative tasks") {
- val generated = makeProgressBar(2, 3, 0, 0, 4).head.child.filter(_.label == "div")
+ val generated = makeProgressBar(2, 3, 0, 0, 0, 4).head.child.filter(_.label == "div")
val expected = Seq(
<div class="bar bar-completed" style="width: 75.0%"></div>,
<div class="bar bar-running" style="width: 25.0%"></div>
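The extra argument in the makeProgressBar call reflects the function growing a killed-task count, while the assertion itself exercises the SPARK-11906 behaviour: with speculative attempts, the running bar must be clamped so completed plus running never exceeds 100%. A hedged sketch of that clamping arithmetic follows; the function is illustrative, not Spark's UIUtils code.

```scala
object ProgressBarSketch {
  // With completed = 3 and started = 2 out of total = 4 (speculation can push
  // started past the remaining slots), the running bar is clamped to 25%.
  def barWidths(started: Int, completed: Int, total: Int): (Double, Double) = {
    val completedPct = completed.toDouble / total * 100
    // Clamp the running bar so the two bars never sum past 100%.
    val runningPct = math.min(started.toDouble / total * 100, 100.0 - completedPct)
    (completedPct, runningPct)
  }

  def main(args: Array[String]): Unit = {
    val (completedPct, runningPct) = barWidths(started = 2, completed = 3, total = 4)
    assert(completedPct == 75.0) // matches the "bar-completed" width in the test
    assert(runningPct == 25.0)   // matches the "bar-running" width in the test
  }
}
```
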
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 1fa9b28edf..edab727fc4 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -243,7 +243,6 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
new FetchFailed(null, 0, 0, 0, "ignored"),
ExceptionFailure("Exception", "description", null, null, None),
TaskResultLost,
- TaskKilled,
ExecutorLostFailure("0", true, Some("Induced failure")),
UnknownReason)
var failCount = 0
@@ -255,6 +254,11 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
assert(listener.stageIdToData((task.stageId, 0)).numFailedTasks === failCount)
}
+ // Make sure killed tasks are accounted for correctly.
+ listener.onTaskEnd(
+ SparkListenerTaskEnd(task.stageId, 0, taskType, TaskKilled, taskInfo, metrics))
+ assert(listener.stageIdToData((task.stageId, 0)).numKilledTasks === 1)
+
// Make sure we count success as success.
listener.onTaskEnd(
SparkListenerTaskEnd(task.stageId, 1, taskType, Success, taskInfo, metrics))
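The listener test now distinguishes killed tasks from failures: a TaskKilled end reason increments numKilledTasks rather than numFailedTasks. The sketch below shows that bookkeeping in miniature; the case objects and counters are simplified stand-ins, not the real JobProgressListener or Spark's TaskEndReason hierarchy.

```scala
object TaskEndTallySketch {
  sealed trait TaskEndReason
  case object Success extends TaskEndReason
  case object TaskKilled extends TaskEndReason
  case class ExceptionFailure(description: String) extends TaskEndReason

  final class StageUIDataLike {
    var numCompleteTasks = 0
    var numFailedTasks = 0
    var numKilledTasks = 0

    def onTaskEnd(reason: TaskEndReason): Unit = reason match {
      case Success    => numCompleteTasks += 1
      case TaskKilled => numKilledTasks += 1  // counted separately, not as a failure
      case _          => numFailedTasks += 1  // every other end reason is a failure
    }
  }

  def main(args: Array[String]): Unit = {
    val data = new StageUIDataLike
    data.onTaskEnd(ExceptionFailure("boom"))
    data.onTaskEnd(TaskKilled)
    data.onTaskEnd(Success)
    assert(data.numFailedTasks == 1 && data.numKilledTasks == 1 && data.numCompleteTasks == 1)
  }
}
```
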
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 6fda7378e6..0a8bbba6c5 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -966,6 +966,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Getting Result Time": 0,
| "Finish Time": 0,
| "Failed": false,
+ | "Killed": false,
| "Accumulables": [
| {
| "ID": 1,
@@ -1012,6 +1013,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Getting Result Time": 0,
| "Finish Time": 0,
| "Failed": false,
+ | "Killed": false,
| "Accumulables": [
| {
| "ID": 1,
@@ -1064,6 +1066,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Getting Result Time": 0,
| "Finish Time": 0,
| "Failed": false,
+ | "Killed": false,
| "Accumulables": [
| {
| "ID": 1,
@@ -1161,6 +1164,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Getting Result Time": 0,
| "Finish Time": 0,
| "Failed": false,
+ | "Killed": false,
| "Accumulables": [
| {
| "ID": 1,
@@ -1258,6 +1262,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
| "Getting Result Time": 0,
| "Finish Time": 0,
| "Failed": false,
+ | "Killed": false,
| "Accumulables": [
| {
| "ID": 1,