From 8ac3d1e2636ec71ab9a14bed68f138e3a365603e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 18 Aug 2013 19:36:34 -0700 Subject: Added unit tests for ClusterTaskSetManager, and fix a bug found with resetting locality level after a non-local launch --- .../main/scala/spark/scheduler/TaskLocation.scala | 4 +- .../main/scala/spark/scheduler/TaskResult.scala | 4 +- .../scheduler/cluster/ClusterTaskSetManager.scala | 27 +- .../spark/scheduler/cluster/TaskDescription.scala | 3 + .../spark/scheduler/cluster/TaskSetManager.scala | 9 + .../scheduler/local/LocalTaskSetManager.scala | 2 +- core/src/main/scala/spark/util/Clock.scala | 29 +++ .../scheduler/cluster/ClusterSchedulerSuite.scala | 21 +- .../cluster/ClusterTaskSetManagerSuite.scala | 273 +++++++++++++++++++++ .../scala/spark/scheduler/cluster/FakeTask.scala | 26 ++ core/src/test/scala/spark/util/FakeClock.scala | 26 ++ 11 files changed, 396 insertions(+), 28 deletions(-) create mode 100644 core/src/main/scala/spark/util/Clock.scala create mode 100644 core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala create mode 100644 core/src/test/scala/spark/scheduler/cluster/FakeTask.scala create mode 100644 core/src/test/scala/spark/util/FakeClock.scala diff --git a/core/src/main/scala/spark/scheduler/TaskLocation.scala b/core/src/main/scala/spark/scheduler/TaskLocation.scala index 0e97c61188..fea117e956 100644 --- a/core/src/main/scala/spark/scheduler/TaskLocation.scala +++ b/core/src/main/scala/spark/scheduler/TaskLocation.scala @@ -23,7 +23,9 @@ package spark.scheduler * of preference will be executors on the same host if this is not possible. */ private[spark] -class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable +class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable { + override def toString: String = "TaskLocation(" + host + ", " + executorId + ")" +} private[spark] object TaskLocation { def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId)) diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala index 89793e0e82..fc4856756b 100644 --- a/core/src/main/scala/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/spark/scheduler/TaskResult.scala @@ -28,7 +28,9 @@ import java.nio.ByteBuffer // TODO: Use of distributed cache to return result is a hack to get around // what seems to be a bug with messages over 60KB in libprocess; fix it private[spark] -class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics) extends Externalizable { +class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics) + extends Externalizable +{ def this() = this(null.asInstanceOf[T], null, null) override def writeExternal(out: ObjectOutput) { diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala index 1d57732f5d..a4d6880abb 100644 --- a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -34,6 +34,7 @@ import scala.Some import spark.FetchFailed import spark.ExceptionFailure import spark.TaskResultTooBigFailure +import spark.util.{SystemClock, Clock} /** @@ -46,9 +47,13 @@ import spark.TaskResultTooBigFailure * THREADING: This class is designed to only be called from code with a lock on the * ClusterScheduler (e.g. its event handlers). It should not be called from other threads. */ -private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet) - extends TaskSetManager with Logging { - +private[spark] class ClusterTaskSetManager( + sched: ClusterScheduler, + val taskSet: TaskSet, + clock: Clock = SystemClock) + extends TaskSetManager + with Logging +{ // CPUs to request per task val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt @@ -142,7 +147,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. // We then move down if we manage to launch a "more local" task. var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = System.currentTimeMillis() // Time we last launched a task at this level + var lastLaunchTime = clock.getTime() // Time we last launched a task at this level /** * Add a task to all the pending-task lists that it should be on. If readding is set, we are @@ -340,7 +345,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: : Option[TaskDescription] = { if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) { - val curTime = System.currentTimeMillis() + val curTime = clock.getTime() var allowedLocality = getAllowedLocalityLevel(curTime) if (allowedLocality > maxLocality) { @@ -361,22 +366,22 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: taskInfos(taskId) = info taskAttempts(index) = info :: taskAttempts(index) // Update our locality level for delay scheduling - currentLocalityIndex = getLocalityIndex(allowedLocality) + currentLocalityIndex = getLocalityIndex(taskLocality) lastLaunchTime = curTime // Serialize and return the task - val startTime = System.currentTimeMillis() + val startTime = clock.getTime() // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here // we assume the task can be serialized without exceptions. val serializedTask = Task.serializeWithDependencies( task, sched.sc.addedFiles, sched.sc.addedJars, ser) - val timeTaken = System.currentTimeMillis() - startTime + val timeTaken = clock.getTime() - startTime increaseRunningTasks(1) logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) if (taskAttempts(index).size == 1) taskStarted(task,info) - return Some(new TaskDescription(taskId, execId, taskName, serializedTask)) + return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) } case _ => } @@ -505,7 +510,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: case ef: ExceptionFailure => sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) val key = ef.description - val now = System.currentTimeMillis() + val now = clock.getTime() val (printFull, dupCount) = { if (recentExceptions.contains(key)) { val (dupCount, printTime) = recentExceptions(key) @@ -643,7 +648,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) if (tasksFinished >= minFinishedForSpeculation) { - val time = System.currentTimeMillis() + val time = clock.getTime() val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray Arrays.sort(durations) val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1)) diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala index 761fdf6919..187553233f 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala @@ -24,6 +24,7 @@ private[spark] class TaskDescription( val taskId: Long, val executorId: String, val name: String, + val index: Int, // Index within this task's TaskSet _serializedTask: ByteBuffer) extends Serializable { @@ -31,4 +32,6 @@ private[spark] class TaskDescription( private val buffer = new SerializableBuffer(_serializedTask) def serializedTask: ByteBuffer = buffer.value + + override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) } diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala index 5ab6ab9aad..0248830b7a 100644 --- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala @@ -22,6 +22,15 @@ import java.nio.ByteBuffer import spark.TaskState.TaskState import spark.scheduler.TaskSet +/** + * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of + * each task and is responsible for retries on failure and locality. The main interfaces to it + * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and + * statusUpdate, which tells it that one of its tasks changed state (e.g. finished). + * + * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler + * (e.g. its event handlers). It should not be called from other threads. + */ private[spark] trait TaskSetManager extends Schedulable { def schedulableQueue = null diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala index 3ef636ff07..e237f289e3 100644 --- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala +++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala @@ -125,7 +125,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas copiesRunning(index) += 1 increaseRunningTasks(1) taskStarted(task, info) - return Some(new TaskDescription(taskId, null, taskName, bytes)) + return Some(new TaskDescription(taskId, null, taskName, index, bytes)) case None => {} } } diff --git a/core/src/main/scala/spark/util/Clock.scala b/core/src/main/scala/spark/util/Clock.scala new file mode 100644 index 0000000000..aa71a5b442 --- /dev/null +++ b/core/src/main/scala/spark/util/Clock.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.util + +/** + * An interface to represent clocks, so that they can be mocked out in unit tests. + */ +private[spark] trait Clock { + def getTime(): Long +} + +private[spark] object SystemClock extends Clock { + def getTime(): Long = System.currentTimeMillis() +} diff --git a/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala index aeeed14786..abfdabf5fe 100644 --- a/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala @@ -27,7 +27,7 @@ import scala.collection.mutable.ArrayBuffer import java.util.Properties -class DummyTaskSetManager( +class FakeTaskSetManager( initPriority: Int, initStageId: Int, initNumTasks: Int, @@ -81,7 +81,7 @@ class DummyTaskSetManager( { if (tasksFinished + runningTasks < numTasks) { increaseRunningTasks(1) - return Some(new TaskDescription(0, execId, "task 0:0", null)) + return Some(new TaskDescription(0, execId, "task 0:0", 0, null)) } return None } @@ -104,17 +104,10 @@ class DummyTaskSetManager( } } -class DummyTask(stageId: Int) extends Task[Int](stageId) -{ - def run(attemptId: Long): Int = { - return 0 - } -} - class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging { - def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): DummyTaskSetManager = { - new DummyTaskSetManager(priority, stage, numTasks, cs , taskSet) + def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = { + new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet) } def resourceOffer(rootPool: Pool): Int = { @@ -141,7 +134,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging sc = new SparkContext("local", "ClusterSchedulerSuite") val clusterScheduler = new ClusterScheduler(sc) var tasks = ArrayBuffer[Task[_]]() - val task = new DummyTask(0) + val task = new FakeTask(0) tasks += task val taskSet = new TaskSet(tasks.toArray,0,0,0,null) @@ -168,7 +161,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging sc = new SparkContext("local", "ClusterSchedulerSuite") val clusterScheduler = new ClusterScheduler(sc) var tasks = ArrayBuffer[Task[_]]() - val task = new DummyTask(0) + val task = new FakeTask(0) tasks += task val taskSet = new TaskSet(tasks.toArray,0,0,0,null) @@ -225,7 +218,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging sc = new SparkContext("local", "ClusterSchedulerSuite") val clusterScheduler = new ClusterScheduler(sc) var tasks = ArrayBuffer[Task[_]]() - val task = new DummyTask(0) + val task = new FakeTask(0) tasks += task val taskSet = new TaskSet(tasks.toArray,0,0,0,null) diff --git a/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala new file mode 100644 index 0000000000..5a0b949ef5 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.scheduler.cluster + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable + +import org.scalatest.FunSuite + +import spark._ +import spark.scheduler._ +import spark.executor.TaskMetrics +import java.nio.ByteBuffer +import spark.util.FakeClock + +/** + * A mock ClusterScheduler implementation that just remembers information about tasks started and + * feedback received from the TaskSetManagers. Note that it's important to initialize this with + * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost + * to work, and these are required for locality in ClusterTaskSetManager. + */ +class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */) + extends ClusterScheduler(sc) +{ + val startedTasks = new ArrayBuffer[Long] + val endedTasks = new mutable.HashMap[Long, TaskEndReason] + val finishedManagers = new ArrayBuffer[TaskSetManager] + + val executors = new mutable.HashMap[String, String] ++ liveExecutors + + listener = new TaskSchedulerListener { + def taskStarted(task: Task[_], taskInfo: TaskInfo) { + startedTasks += taskInfo.index + } + + def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: mutable.Map[Long, Any], + taskInfo: TaskInfo, + taskMetrics: TaskMetrics) + { + endedTasks(taskInfo.index) = reason + } + + def executorGained(execId: String, host: String) {} + + def executorLost(execId: String) {} + + def taskSetFailed(taskSet: TaskSet, reason: String) {} + } + + def removeExecutor(execId: String): Unit = executors -= execId + + override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager + + override def isExecutorAlive(execId: String): Boolean = executors.contains(execId) + + override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host) +} + +class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { + import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} + + val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong + + test("TaskSet with no preferences") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) + val taskSet = createTaskSet(1) + val manager = new ClusterTaskSetManager(sched, taskSet) + + // Offer a host with no CPUs + assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None) + + // Offer a host with process-local as the constraint; this should work because the TaskSet + // above won't have any locality preferences + val taskOption = manager.resourceOffer("exec1", "host1", 2, TaskLocality.PROCESS_LOCAL) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === "exec1") + assert(sched.startedTasks.contains(0)) + + // Re-offer the host -- now we should get no more tasks + assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None) + + // Tell it the task has finished + manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0)) + assert(sched.endedTasks(0) === Success) + assert(sched.finishedManagers.contains(manager)) + } + + test("multiple offers with no preferences") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) + val taskSet = createTaskSet(3) + val manager = new ClusterTaskSetManager(sched, taskSet) + + // First three offers should all find tasks + for (i <- 0 until 3) { + val taskOption = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) + assert(taskOption.isDefined) + val task = taskOption.get + assert(task.executorId === "exec1") + } + assert(sched.startedTasks.toSet === Set(0, 1, 2)) + + // Re-offer the host -- now we should get no more tasks + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + // Finish the first two tasks + manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0)) + manager.statusUpdate(1, TaskState.FINISHED, createTaskResult(1)) + assert(sched.endedTasks(0) === Success) + assert(sched.endedTasks(1) === Success) + assert(!sched.finishedManagers.contains(manager)) + + // Finish the last task + manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2)) + assert(sched.endedTasks(2) === Success) + assert(sched.finishedManagers.contains(manager)) + } + + test("basic delay scheduling") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = createTaskSet(4, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host2", "exec2")), + Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")), + Seq() // Last task has no locality prefs + ) + val clock = new FakeClock + val manager = new ClusterTaskSetManager(sched, taskSet, clock) + + // First offer host1, exec1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1, exec1 again: the last task, which has no prefs, should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 3) + + // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None) + + // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2 + assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL).get.index == 2) + + // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL) === None) + + // Offer host1, exec1 again, at ANY level: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1, exec1 again, at ANY level: task 1 should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + } + + test("delay scheduling with fallback") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, + ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3")) + val taskSet = createTaskSet(5, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host3")), + Seq(TaskLocation("host2")) + ) + val clock = new FakeClock + val manager = new ClusterTaskSetManager(sched, taskSet, clock) + + // First offer host1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1 again: nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // Offer host1 again: second task (on host2) should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Offer host1 again: third task (on host2) should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) + + // Offer host2: fifth task (also on host2) should get chosen + assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 4) + + // Now that we've launched a local task, we should no longer launch the task for host3 + assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) + + clock.advance(LOCALITY_WAIT) + + // After another delay, we can go ahead and launch that task non-locally + assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 3) + } + + test("delay scheduling with failed hosts") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = createTaskSet(3, + Seq(TaskLocation("host1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host3")) + ) + val clock = new FakeClock + val manager = new ClusterTaskSetManager(sched, taskSet, clock) + + // First offer host1: first task should be chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) + + // Offer host1 again: third task should be chosen immediately because host3 is not up + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2) + + // After this, nothing should get chosen + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + + // Now mark host2 as dead + sched.removeExecutor("exec2") + manager.executorLost("exec2", "host2") + + // Task 1 should immediately be launched on host1 because its original host is gone + assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1) + + // Now that all tasks have launched, nothing new should be launched anywhere else + assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None) + assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None) + } + + /** + * Utility method to create a TaskSet, potentially setting a particular sequence of preferred + * locations for each task (given as varargs) if this sequence is not empty. + */ + def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil) + } + new TaskSet(tasks, 0, 0, 0, null) + } + + def createTaskResult(id: Int): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize(new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics))) + } +} diff --git a/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala new file mode 100644 index 0000000000..de9e66be20 --- /dev/null +++ b/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.scheduler.cluster + +import spark.scheduler.{TaskLocation, Task} + +class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) { + override def run(attemptId: Long): Int = 0 + + override def preferredLocations: Seq[TaskLocation] = prefLocs +} diff --git a/core/src/test/scala/spark/util/FakeClock.scala b/core/src/test/scala/spark/util/FakeClock.scala new file mode 100644 index 0000000000..236706317e --- /dev/null +++ b/core/src/test/scala/spark/util/FakeClock.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.util + +class FakeClock extends Clock { + private var time = 0L + + def advance(millis: Long): Unit = time += millis + + def getTime(): Long = time +} -- cgit v1.2.3