author     Matei Zaharia <matei@eecs.berkeley.edu>  2013-08-18 19:36:34 -0700
committer  Matei Zaharia <matei@eecs.berkeley.edu>  2013-08-18 19:51:07 -0700
commit     8ac3d1e2636ec71ab9a14bed68f138e3a365603e (patch)
tree       ee252fb07303247a929780ea4b33ee1565b6b7b2 /core
parent     4004cf775d9397efbb5945768aaf05ba682c715c (diff)
download   spark-8ac3d1e2636ec71ab9a14bed68f138e3a365603e.tar.gz
           spark-8ac3d1e2636ec71ab9a14bed68f138e3a365603e.tar.bz2
           spark-8ac3d1e2636ec71ab9a14bed68f138e3a365603e.zip
Added unit tests for ClusterTaskSetManager, and fixed a bug found with
resetting the locality level after a non-local launch
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskLocation.scala                          4
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskResult.scala                            4
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala        27
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala               3
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala                9
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala             2
-rw-r--r--  core/src/main/scala/spark/util/Clock.scala                                     29
-rw-r--r--  core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala        21
-rw-r--r--  core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala  273
-rw-r--r--  core/src/test/scala/spark/scheduler/cluster/FakeTask.scala                     26
-rw-r--r--  core/src/test/scala/spark/util/FakeClock.scala                                 26
11 files changed, 396 insertions, 28 deletions
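
Note: the bug fix itself is the one-line change in ClusterTaskSetManager.resourceOffer below: after launching a task, the delay-scheduling level is now reset from the locality of the task actually launched (taskLocality) rather than from the level that happened to be allowed at the time (allowedLocality). A minimal, self-contained sketch of that bookkeeping follows; DelaySchedulingSketch and recordLaunch are illustrative names, not part of Spark.

object DelaySchedulingSketch {
  object Locality extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
  }

  // Locality levels, ordered from most to least local.
  val levels: IndexedSeq[Locality.Value] = Locality.values.toIndexedSeq
  var currentLocalityIndex = 0   // level we are currently waiting at
  var lastLaunchTime = 0L

  // Called after a task is launched with locality `taskLocality` at time `now`.
  // The fix: index by the launched task's locality, not by the allowed level,
  // so a non-local launch correctly moves us to the less-local level.
  def recordLaunch(taskLocality: Locality.Value, now: Long): Unit = {
    currentLocalityIndex = levels.indexOf(taskLocality)
    lastLaunchTime = now
  }

  def main(args: Array[String]): Unit = {
    recordLaunch(Locality.ANY, now = 1000L)
    assert(currentLocalityIndex == levels.indexOf(Locality.ANY))
  }
}
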
diff --git a/core/src/main/scala/spark/scheduler/TaskLocation.scala b/core/src/main/scala/spark/scheduler/TaskLocation.scala
index 0e97c61188..fea117e956 100644
--- a/core/src/main/scala/spark/scheduler/TaskLocation.scala
+++ b/core/src/main/scala/spark/scheduler/TaskLocation.scala
@@ -23,7 +23,9 @@ package spark.scheduler
* of preference will be executors on the same host if this is not possible.
*/
private[spark]
-class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable
+class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable {
+ override def toString: String = "TaskLocation(" + host + ", " + executorId + ")"
+}
private[spark] object TaskLocation {
def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId))
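
The new toString above mainly makes launch-preference logging and test failures readable. A small usage sketch of the two TaskLocation factory forms that the test suite added below relies on (illustrative only, not part of the patch; placed in the spark.scheduler package because TaskLocation is private[spark]):

package spark.scheduler

object TaskLocationSketch {
  def main(args: Array[String]): Unit = {
    val onExecutor = TaskLocation("host1", "exec1") // prefer a specific executor on host1
    val onHost     = TaskLocation("host1")          // any executor on host1 will do
    println(onExecutor) // TaskLocation(host1, Some(exec1))
    println(onHost)     // TaskLocation(host1, None)
  }
}
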
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
index 89793e0e82..fc4856756b 100644
--- a/core/src/main/scala/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -28,7 +28,9 @@ import java.nio.ByteBuffer
// TODO: Use of distributed cache to return result is a hack to get around
// what seems to be a bug with messages over 60KB in libprocess; fix it
private[spark]
-class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics) extends Externalizable {
+class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
+ extends Externalizable
+{
def this() = this(null.asInstanceOf[T], null, null)
override def writeExternal(out: ObjectOutput) {
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 1d57732f5d..a4d6880abb 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -34,6 +34,7 @@ import scala.Some
import spark.FetchFailed
import spark.ExceptionFailure
import spark.TaskResultTooBigFailure
+import spark.util.{SystemClock, Clock}
/**
@@ -46,9 +47,13 @@ import spark.TaskResultTooBigFailure
* THREADING: This class is designed to only be called from code with a lock on the
* ClusterScheduler (e.g. its event handlers). It should not be called from other threads.
*/
-private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet: TaskSet)
- extends TaskSetManager with Logging {
-
+private[spark] class ClusterTaskSetManager(
+ sched: ClusterScheduler,
+ val taskSet: TaskSet,
+ clock: Clock = SystemClock)
+ extends TaskSetManager
+ with Logging
+{
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
@@ -142,7 +147,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
// last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
// We then move down if we manage to launch a "more local" task.
var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
- var lastLaunchTime = System.currentTimeMillis() // Time we last launched a task at this level
+ var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
/**
* Add a task to all the pending-task lists that it should be on. If readding is set, we are
@@ -340,7 +345,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
: Option[TaskDescription] =
{
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
- val curTime = System.currentTimeMillis()
+ val curTime = clock.getTime()
var allowedLocality = getAllowedLocalityLevel(curTime)
if (allowedLocality > maxLocality) {
@@ -361,22 +366,22 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
// Update our locality level for delay scheduling
- currentLocalityIndex = getLocalityIndex(allowedLocality)
+ currentLocalityIndex = getLocalityIndex(taskLocality)
lastLaunchTime = curTime
// Serialize and return the task
- val startTime = System.currentTimeMillis()
+ val startTime = clock.getTime()
// We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
// we assume the task can be serialized without exceptions.
val serializedTask = Task.serializeWithDependencies(
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- val timeTaken = System.currentTimeMillis() - startTime
+ val timeTaken = clock.getTime() - startTime
increaseRunningTasks(1)
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
if (taskAttempts(index).size == 1)
taskStarted(task,info)
- return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
+ return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
}
case _ =>
}
@@ -505,7 +510,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
case ef: ExceptionFailure =>
sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
val key = ef.description
- val now = System.currentTimeMillis()
+ val now = clock.getTime()
val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
@@ -643,7 +648,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
if (tasksFinished >= minFinishedForSpeculation) {
- val time = System.currentTimeMillis()
+ val time = clock.getTime()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
index 761fdf6919..187553233f 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskDescription.scala
@@ -24,6 +24,7 @@ private[spark] class TaskDescription(
val taskId: Long,
val executorId: String,
val name: String,
+ val index: Int, // Index within this task's TaskSet
_serializedTask: ByteBuffer)
extends Serializable {
@@ -31,4 +32,6 @@ private[spark] class TaskDescription(
private val buffer = new SerializableBuffer(_serializedTask)
def serializedTask: ByteBuffer = buffer.value
+
+ override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index)
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 5ab6ab9aad..0248830b7a 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -22,6 +22,15 @@ import java.nio.ByteBuffer
import spark.TaskState.TaskState
import spark.scheduler.TaskSet
+/**
+ * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of
+ * each task and is responsible for retries on failure and locality. The main interfaces to it
+ * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and
+ * statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ *
+ * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler
+ * (e.g. its event handlers). It should not be called from other threads.
+ */
private[spark] trait TaskSetManager extends Schedulable {
def schedulableQueue = null
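
The doc comment added above describes the two-call protocol between the scheduler and a TaskSetManager. The following stand-alone sketch shows that control flow with deliberately simplified types (Offer, LaunchedTask, OneTaskManager are illustrative, not Spark's API): offers come in, the manager either hands back a task or declines, and statusUpdate later reports completion.

object TaskSetManagerFlowSketch {
  case class Offer(execId: String, host: String, cores: Int)
  case class LaunchedTask(taskId: Long, execId: String)

  trait Manager {
    def resourceOffer(o: Offer): Option[LaunchedTask]
    def statusUpdate(taskId: Long, finished: Boolean): Unit
  }

  // A manager with a single task: launches it on the first viable offer, then declines.
  class OneTaskManager extends Manager {
    private var launched = false
    private var done = false
    def resourceOffer(o: Offer): Option[LaunchedTask] =
      if (!launched && o.cores >= 1) { launched = true; Some(LaunchedTask(0L, o.execId)) }
      else None
    def statusUpdate(taskId: Long, finished: Boolean): Unit = if (finished) done = true
    def isDone: Boolean = done
  }

  def main(args: Array[String]): Unit = {
    val m = new OneTaskManager
    assert(m.resourceOffer(Offer("exec1", "host1", 1)).isDefined) // first offer launches the task
    assert(m.resourceOffer(Offer("exec1", "host1", 1)).isEmpty)   // nothing left to launch
    m.statusUpdate(0L, finished = true)
    assert(m.isDone)
  }
}
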
diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
index 3ef636ff07..e237f289e3 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
@@ -125,7 +125,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
copiesRunning(index) += 1
increaseRunningTasks(1)
taskStarted(task, info)
- return Some(new TaskDescription(taskId, null, taskName, bytes))
+ return Some(new TaskDescription(taskId, null, taskName, index, bytes))
case None => {}
}
}
diff --git a/core/src/main/scala/spark/util/Clock.scala b/core/src/main/scala/spark/util/Clock.scala
new file mode 100644
index 0000000000..aa71a5b442
--- /dev/null
+++ b/core/src/main/scala/spark/util/Clock.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.util
+
+/**
+ * An interface to represent clocks, so that they can be mocked out in unit tests.
+ */
+private[spark] trait Clock {
+ def getTime(): Long
+}
+
+private[spark] object SystemClock extends Clock {
+ def getTime(): Long = System.currentTimeMillis()
+}
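
The point of this trait is dependency injection of time: production code takes a Clock that defaults to SystemClock (as ClusterTaskSetManager now does), and tests substitute a controllable implementation. A self-contained sketch of the pattern, with Clock and SystemClock redefined locally only to keep the snippet standalone, and a made-up Throttle class purely for illustration:

trait Clock { def getTime(): Long }
object SystemClock extends Clock { def getTime(): Long = System.currentTimeMillis() }

// Hypothetical consumer: reports ready at most once per `intervalMs`, measured on the injected clock.
class Throttle(intervalMs: Long, clock: Clock = SystemClock) {
  private var last = clock.getTime()
  def ready(): Boolean = {
    val now = clock.getTime()
    if (now - last >= intervalMs) { last = now; true } else false
  }
}

object ClockSketch {
  def main(args: Array[String]): Unit = {
    val t = new Throttle(1000L) // uses the real system clock by default
    println(t.ready())          // false immediately after construction
  }
}
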
diff --git a/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala
index aeeed14786..abfdabf5fe 100644
--- a/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/cluster/ClusterSchedulerSuite.scala
@@ -27,7 +27,7 @@ import scala.collection.mutable.ArrayBuffer
import java.util.Properties
-class DummyTaskSetManager(
+class FakeTaskSetManager(
initPriority: Int,
initStageId: Int,
initNumTasks: Int,
@@ -81,7 +81,7 @@ class DummyTaskSetManager(
{
if (tasksFinished + runningTasks < numTasks) {
increaseRunningTasks(1)
- return Some(new TaskDescription(0, execId, "task 0:0", null))
+ return Some(new TaskDescription(0, execId, "task 0:0", 0, null))
}
return None
}
@@ -104,17 +104,10 @@ class DummyTaskSetManager(
}
}
-class DummyTask(stageId: Int) extends Task[Int](stageId)
-{
- def run(attemptId: Long): Int = {
- return 0
- }
-}
-
class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
- def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): DummyTaskSetManager = {
- new DummyTaskSetManager(priority, stage, numTasks, cs , taskSet)
+ def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = {
+ new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet)
}
def resourceOffer(rootPool: Pool): Int = {
@@ -141,7 +134,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
sc = new SparkContext("local", "ClusterSchedulerSuite")
val clusterScheduler = new ClusterScheduler(sc)
var tasks = ArrayBuffer[Task[_]]()
- val task = new DummyTask(0)
+ val task = new FakeTask(0)
tasks += task
val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
@@ -168,7 +161,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
sc = new SparkContext("local", "ClusterSchedulerSuite")
val clusterScheduler = new ClusterScheduler(sc)
var tasks = ArrayBuffer[Task[_]]()
- val task = new DummyTask(0)
+ val task = new FakeTask(0)
tasks += task
val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
@@ -225,7 +218,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
sc = new SparkContext("local", "ClusterSchedulerSuite")
val clusterScheduler = new ClusterScheduler(sc)
var tasks = ArrayBuffer[Task[_]]()
- val task = new DummyTask(0)
+ val task = new FakeTask(0)
tasks += task
val taskSet = new TaskSet(tasks.toArray,0,0,0,null)
diff --git a/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
new file mode 100644
index 0000000000..5a0b949ef5
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.scheduler.cluster
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import spark._
+import spark.scheduler._
+import spark.executor.TaskMetrics
+import java.nio.ByteBuffer
+import spark.util.FakeClock
+
+/**
+ * A mock ClusterScheduler implementation that just remembers information about tasks started and
+ * feedback received from the TaskSetManagers. Note that it's important to initialize this with
+ * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost
+ * to work, and these are required for locality in ClusterTaskSetManager.
+ */
+class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
+ extends ClusterScheduler(sc)
+{
+ val startedTasks = new ArrayBuffer[Long]
+ val endedTasks = new mutable.HashMap[Long, TaskEndReason]
+ val finishedManagers = new ArrayBuffer[TaskSetManager]
+
+ val executors = new mutable.HashMap[String, String] ++ liveExecutors
+
+ listener = new TaskSchedulerListener {
+ def taskStarted(task: Task[_], taskInfo: TaskInfo) {
+ startedTasks += taskInfo.index
+ }
+
+ def taskEnded(
+ task: Task[_],
+ reason: TaskEndReason,
+ result: Any,
+ accumUpdates: mutable.Map[Long, Any],
+ taskInfo: TaskInfo,
+ taskMetrics: TaskMetrics)
+ {
+ endedTasks(taskInfo.index) = reason
+ }
+
+ def executorGained(execId: String, host: String) {}
+
+ def executorLost(execId: String) {}
+
+ def taskSetFailed(taskSet: TaskSet, reason: String) {}
+ }
+
+ def removeExecutor(execId: String): Unit = executors -= execId
+
+ override def taskSetFinished(manager: TaskSetManager): Unit = finishedManagers += manager
+
+ override def isExecutorAlive(execId: String): Boolean = executors.contains(execId)
+
+ override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host)
+}
+
+class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
+ import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL}
+
+ val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+
+ test("TaskSet with no preferences") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val taskSet = createTaskSet(1)
+ val manager = new ClusterTaskSetManager(sched, taskSet)
+
+ // Offer a host with no CPUs
+ assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None)
+
+ // Offer a host with process-local as the constraint; this should work because the TaskSet
+ // above won't have any locality preferences
+ val taskOption = manager.resourceOffer("exec1", "host1", 2, TaskLocality.PROCESS_LOCAL)
+ assert(taskOption.isDefined)
+ val task = taskOption.get
+ assert(task.executorId === "exec1")
+ assert(sched.startedTasks.contains(0))
+
+ // Re-offer the host -- now we should get no more tasks
+ assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None)
+
+ // Tell it the task has finished
+ manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
+ assert(sched.endedTasks(0) === Success)
+ assert(sched.finishedManagers.contains(manager))
+ }
+
+ test("multiple offers with no preferences") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val taskSet = createTaskSet(3)
+ val manager = new ClusterTaskSetManager(sched, taskSet)
+
+ // First three offers should all find tasks
+ for (i <- 0 until 3) {
+ val taskOption = manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL)
+ assert(taskOption.isDefined)
+ val task = taskOption.get
+ assert(task.executorId === "exec1")
+ }
+ assert(sched.startedTasks.toSet === Set(0, 1, 2))
+
+ // Re-offer the host -- now we should get no more tasks
+ assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
+
+ // Finish the first two tasks
+ manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
+ manager.statusUpdate(1, TaskState.FINISHED, createTaskResult(1))
+ assert(sched.endedTasks(0) === Success)
+ assert(sched.endedTasks(1) === Success)
+ assert(!sched.finishedManagers.contains(manager))
+
+ // Finish the last task
+ manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2))
+ assert(sched.endedTasks(2) === Success)
+ assert(sched.finishedManagers.contains(manager))
+ }
+
+ test("basic delay scheduling") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+ val taskSet = createTaskSet(4,
+ Seq(TaskLocation("host1", "exec1")),
+ Seq(TaskLocation("host2", "exec2")),
+ Seq(TaskLocation("host1"), TaskLocation("host2", "exec2")),
+ Seq() // Last task has no locality prefs
+ )
+ val clock = new FakeClock
+ val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+ // First offer host1, exec1: first task should be chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+ // Offer host1, exec1 again: the last task, which has no prefs, should be chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 3)
+
+ // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
+
+ clock.advance(LOCALITY_WAIT)
+
+ // Offer host1, exec1 again, at PROCESS_LOCAL level: nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
+
+ // Offer host1, exec1 again, at NODE_LOCAL level: we should choose task 2
+ assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL).get.index == 2)
+
+ // Offer host1, exec1 again, at NODE_LOCAL level: nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, NODE_LOCAL) === None)
+
+ // Offer host1, exec1 again, at ANY level: nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+
+ clock.advance(LOCALITY_WAIT)
+
+ // Offer host1, exec1 again, at ANY level: task 1 should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1)
+
+ // Offer host1, exec1 again, at ANY level: nothing should be chosen as we've launched all tasks
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+ }
+
+ test("delay scheduling with fallback") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc,
+ ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
+ val taskSet = createTaskSet(5,
+ Seq(TaskLocation("host1")),
+ Seq(TaskLocation("host2")),
+ Seq(TaskLocation("host2")),
+ Seq(TaskLocation("host3")),
+ Seq(TaskLocation("host2"))
+ )
+ val clock = new FakeClock
+ val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+ // First offer host1: first task should be chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+ // Offer host1 again: nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+
+ clock.advance(LOCALITY_WAIT)
+
+ // Offer host1 again: second task (on host2) should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1)
+
+ // Offer host1 again: third task (on host2) should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2)
+
+ // Offer host2: fifth task (also on host2) should get chosen
+ assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 4)
+
+ // Now that we've launched a local task, we should no longer launch the task for host3
+ assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None)
+
+ clock.advance(LOCALITY_WAIT)
+
+ // After another delay, we can go ahead and launch that task non-locally
+ assert(manager.resourceOffer("exec2", "host2", 1, ANY).get.index === 3)
+ }
+
+ test("delay scheduling with failed hosts") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+ val taskSet = createTaskSet(3,
+ Seq(TaskLocation("host1")),
+ Seq(TaskLocation("host2")),
+ Seq(TaskLocation("host3"))
+ )
+ val clock = new FakeClock
+ val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+ // First offer host1: first task should be chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
+
+ // Offer host1 again: third task should be chosen immediately because host3 is not up
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 2)
+
+ // After this, nothing should get chosen
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+
+ // Now mark host2 as dead
+ sched.removeExecutor("exec2")
+ manager.executorLost("exec2", "host2")
+
+ // Task 1 should immediately be launched on host1 because its original host is gone
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 1)
+
+ // Now that all tasks have launched, nothing new should be launched anywhere else
+ assert(manager.resourceOffer("exec1", "host1", 1, ANY) === None)
+ assert(manager.resourceOffer("exec2", "host2", 1, ANY) === None)
+ }
+
+ /**
+ * Utility method to create a TaskSet, potentially setting a particular sequence of preferred
+ * locations for each task (given as varargs) if this sequence is not empty.
+ */
+ def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
+ if (prefLocs.size != 0 && prefLocs.size != numTasks) {
+ throw new IllegalArgumentException("Wrong number of task locations")
+ }
+ val tasks = Array.tabulate[Task[_]](numTasks) { i =>
+ new FakeTask(i, if (prefLocs.size != 0) prefLocs(i) else Nil)
+ }
+ new TaskSet(tasks, 0, 0, 0, null)
+ }
+
+ def createTaskResult(id: Int): ByteBuffer = {
+ ByteBuffer.wrap(Utils.serialize(new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics)))
+ }
+}
diff --git a/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala
new file mode 100644
index 0000000000..de9e66be20
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/cluster/FakeTask.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.scheduler.cluster
+
+import spark.scheduler.{TaskLocation, Task}
+
+class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId) {
+ override def run(attemptId: Long): Int = 0
+
+ override def preferredLocations: Seq[TaskLocation] = prefLocs
+}
diff --git a/core/src/test/scala/spark/util/FakeClock.scala b/core/src/test/scala/spark/util/FakeClock.scala
new file mode 100644
index 0000000000..236706317e
--- /dev/null
+++ b/core/src/test/scala/spark/util/FakeClock.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.util
+
+class FakeClock extends Clock {
+ private var time = 0L
+
+ def advance(millis: Long): Unit = time += millis
+
+ def getTime(): Long = time
+}
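
A brief sketch of how the tests above combine this class with the clock parameter added to ClusterTaskSetManager: advancing the fake clock by spark.locality.wait simulates the delay-scheduling timeout without sleeping, so locality fallback can be tested deterministically. The snippet depends only on code introduced in this patch.

import spark.util.FakeClock

object FakeClockUsageSketch {
  def main(args: Array[String]): Unit = {
    val localityWait = System.getProperty("spark.locality.wait", "3000").toLong
    val clock = new FakeClock
    val before = clock.getTime()
    clock.advance(localityWait)                       // "wait out" the locality delay instantly
    assert(clock.getTime() - before == localityWait)  // virtual time moved forward deterministically
  }
}
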