author     Kay Ousterhout <kayousterhout@gmail.com>   2013-11-13 14:32:50 -0800
committer  Kay Ousterhout <kayousterhout@gmail.com>   2013-11-13 14:32:50 -0800
commit     68e5ad58b7e7e3e1b42852de8d0fdf9e9b9c1a14 (patch)
tree       837719ad9bc7bb11cfc149964eeb0e65e629a942 /core
parent     fb64828b0b573f3a77938592f168af7aa3a2b6c5 (diff)
Extracted TaskScheduler interface.
Also changed the default maximum number of task failures to be
0 when running in local mode.
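
The new TaskScheduler trait itself is not visible in the portion of the diff reproduced below, so the following is only a sketch of the extracted interface, inferred from the methods that ClusterScheduler now marks as override (and from postStartHook, whose concrete body is deleted from ClusterScheduler further down). Treat every signature here as an assumption rather than a copy of the commit:

  // Sketch only: inferred from the overrides visible in this diff, not copied from the commit.
  private[spark] trait TaskScheduler {

    // Lifecycle hooks called by SparkContext.
    def start(): Unit
    def stop(): Unit

    // Submit a set of tasks to run; called by the DAGScheduler for each stage.
    def submitTasks(taskSet: TaskSet): Unit

    // Cancel all tasks belonging to the given stage.
    def cancelTasks(stageId: Int): Unit

    // Wire up the DAGScheduler so that task events can be reported back to it.
    def setDAGScheduler(dagScheduler: DAGScheduler): Unit

    // Hint used to size default partition counts; ClusterScheduler delegates to its backend.
    def defaultParallelism(): Int

    // Invoked after the system has successfully been initialized; YARN uses this to
    // bootstrap resource allocation (the doc comment removed from ClusterScheduler below).
    def postStartHook() { }
  }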
Diffstat (limited to 'core')
14 files changed, 79 insertions, 73 deletions
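
The behavioral change rides along with the refactoring: the retry limit now lives on the scheduler, which computes it per deploy mode and passes it into each TaskSetManager, whose abort check is numFailures > maxTaskFailures. A minimal standalone illustration of that arithmetic (not code from the commit; println stands in for TaskSetManager.abort):

  // Illustration only, mirroring the logic added in the diff below.
  object MaxFailuresDemo {
    def main(args: Array[String]) {
      val isLocal = true
      val maxTaskFailures =
        if (isLocal) 0  // no point retrying when everything runs in one JVM
        else System.getProperty("spark.task.maxFailures", "4").toInt

      var numFailures = 0
      numFailures += 1                      // the task's first failure
      if (numFailures > maxTaskFailures) {  // 1 > 0 in local mode: abort immediately
        println("aborting task set")        // stand-in for TaskSetManager.abort(...)
      }
    }
  }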
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 1850436ff2..e8ff4da475 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -159,26 +159,26 @@ class SparkContext(
     master match {
       case "local" =>
-        val scheduler = new TaskScheduler(this)
+        val scheduler = new ClusterScheduler(this, isLocal = true)
         val backend = new LocalBackend(scheduler, 1)
         scheduler.initialize(backend)
         scheduler
 
       case LOCAL_N_REGEX(threads) =>
-        val scheduler = new TaskScheduler(this)
+        val scheduler = new ClusterScheduler(this, isLocal = true)
         val backend = new LocalBackend(scheduler, threads.toInt)
         scheduler.initialize(backend)
         scheduler
 
       case SPARK_REGEX(sparkUrl) =>
-        val scheduler = new TaskScheduler(this)
+        val scheduler = new ClusterScheduler(this)
         val masterUrls = sparkUrl.split(",").map("spark://" + _)
         val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
         scheduler.initialize(backend)
         scheduler
 
       case SIMR_REGEX(simrUrl) =>
-        val scheduler = new TaskScheduler(this)
+        val scheduler = new ClusterScheduler(this)
         val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
         scheduler.initialize(backend)
         scheduler
@@ -192,7 +192,7 @@ class SparkContext(
             memoryPerSlaveInt, SparkContext.executorMemoryRequested))
         }
-        val scheduler = new TaskScheduler(this)
+        val scheduler = new ClusterScheduler(this, isLocal = true)
         val localCluster = new LocalSparkCluster(
           numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
         val masterUrls = localCluster.start()
@@ -207,7 +207,7 @@ class SparkContext(
         val scheduler = try {
           val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
           val cons = clazz.getConstructor(classOf[SparkContext])
-          cons.newInstance(this).asInstanceOf[TaskScheduler]
+          cons.newInstance(this).asInstanceOf[ClusterScheduler]
         } catch {
           // TODO: Enumerate the exact reasons why it can fail
           // But irrespective of it, it means we cannot proceed !
@@ -221,7 +221,7 @@ class SparkContext(
 
       case MESOS_REGEX(mesosUrl) =>
         MesosNativeLibrary.load()
-        val scheduler = new TaskScheduler(this)
+        val scheduler = new ClusterScheduler(this)
         val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
         val backend = if (coarseGrained) {
           new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
index b4ec695ece..c7d1295215 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
@@ -30,17 +30,13 @@ import org.apache.spark.TaskState.TaskState
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 
 /**
- * Schedules tasks for a single SparkContext. Receives a set of tasks from the DAGScheduler for
- * each stage, and is responsible for sending tasks to executors, running them, retrying if there
- * are failures, and mitigating stragglers. Returns events to the DAGScheduler.
- *
- * Clients should first call initialize() and start(), then submit task sets through the
- * runTasks method.
- *
- * This class can work with multiple types of clusters by acting through a SchedulerBackend.
+ * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
+ * It can also work with a local setup by using a LocalBackend and setting isLocal to true.
  * It handles common logic, like determining a scheduling order across jobs, waking up to launch
  * speculative tasks, etc.
+ *
+ * Clients should first call initialize() and start(), then submit task sets through the
+ * runTasks method.
 *
 * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
 * threads, so it needs locks in public API methods to maintain its state. In addition, some
@@ -48,7 +44,9 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 * acquire a lock on us, so we need to make sure that we don't try to lock the backend while
 * we are holding a lock on ourselves.
 */
-private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = false) extends Logging {
+private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = false)
+  extends TaskScheduler with Logging {
+
   // How often to check for speculative tasks
   val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
 
@@ -59,6 +57,15 @@ private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = fals
   // on this class.
   val activeTaskSets = new HashMap[String, TaskSetManager]
 
+  val MAX_TASK_FAILURES = {
+    if (isLocal) {
+      // No sense in retrying if all tasks run locally!
+      0
+    } else {
+      System.getProperty("spark.task.maxFailures", "4").toInt
+    }
+  }
+
   val taskIdToTaskSetId = new HashMap[Long, String]
   val taskIdToExecutorId = new HashMap[Long, String]
   val taskSetTaskIds = new HashMap[String, HashSet[Long]]
@@ -95,7 +102,7 @@ private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = fals
   // This is a var so that we can reset it for testing purposes.
   private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this)
 
-  def setDAGScheduler(dagScheduler: DAGScheduler) {
+  override def setDAGScheduler(dagScheduler: DAGScheduler) {
     this.dagScheduler = dagScheduler
   }
 
@@ -116,7 +123,7 @@ private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = fals
 
   def newTaskId(): Long = nextTaskId.getAndIncrement()
 
-  def start() {
+  override def start() {
     backend.start()
 
     if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) {
@@ -138,11 +145,11 @@ private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = fals
     }
   }
 
-  def submitTasks(taskSet: TaskSet) {
+  override def submitTasks(taskSet: TaskSet) {
     val tasks = taskSet.tasks
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
-      val manager = new TaskSetManager(this, taskSet)
+      val manager = new TaskSetManager(this, taskSet, MAX_TASK_FAILURES)
       activeTaskSets(taskSet.id) = manager
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
       taskSetTaskIds(taskSet.id) = new HashSet[Long]()
@@ -165,7 +172,7 @@ private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = fals
     backend.reviveOffers()
   }
 
-  def cancelTasks(stageId: Int): Unit = synchronized {
+  override def cancelTasks(stageId: Int): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
     activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
       // There are two possible cases here:
@@ -351,7 +358,7 @@ private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = fals
     }
   }
 
-  def stop() {
+  override def stop() {
     if (backend != null) {
       backend.stop()
     }
@@ -364,7 +371,7 @@ private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = fals
     Thread.sleep(5000L)
   }
 
-  def defaultParallelism() = backend.defaultParallelism()
+  override def defaultParallelism() = backend.defaultParallelism()
 
   // Check for speculatable tasks in all our active jobs.
   def checkSpeculatableTasks() {
@@ -439,16 +446,10 @@ private[spark] class TaskScheduler(val sc: SparkContext, isLocal: Boolean = fals
 
   // By default, rack is unknown
   def getRackForHost(value: String): Option[String] = None
-
-  /**
-   * Invoked after the system has successfully been initialized. YARN uses this to bootstrap
-   * allocation of resources based on preferred locations, wait for slave registrations, etc.
-   */
-  def postStartHook() { }
 }
 
-private[spark] object TaskScheduler {
+private[spark] object ClusterScheduler {
   /**
    * Used to balance containers across hosts.
   *
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 5408fa7353..a77ff35323 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -28,7 +28,7 @@ import org.apache.spark.util.Utils
 /**
  * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
  */
-private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskScheduler)
+private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
   extends Logging {
   private val THREADS = System.getProperty("spark.resultGetter.threads", "4").toInt
   private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool(
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 90b6519027..8757d7fd2a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -40,19 +40,22 @@ import org.apache.spark.util.{SystemClock, Clock}
 *
 * THREADING: This class is designed to only be called from code with a lock on the
 * TaskScheduler (e.g. its event handlers). It should not be called from other threads.
+ *
+ * @param sched the ClusterScheduler associated with the TaskSetManager
+ * @param taskSet the TaskSet to manage scheduling for
+ * @param maxTaskFailures if any particular task fails more than this number of times, the entire
+ *                        task set will be aborted
 */
 private[spark] class TaskSetManager(
-    sched: TaskScheduler,
+    sched: ClusterScheduler,
     val taskSet: TaskSet,
+    val maxTaskFailures: Int,
     clock: Clock = SystemClock)
   extends Schedulable with Logging
 {
   // CPUs to request per task
   val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
 
-  // Maximum times a task is allowed to fail before failing the job
-  val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt
-
   // Quantile of tasks at which to start speculation
   val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
   val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
@@ -521,10 +524,10 @@ private[spark] class TaskSetManager(
       addPendingTask(index)
       if (state != TaskState.KILLED) {
         numFailures(index) += 1
-        if (numFailures(index) > MAX_TASK_FAILURES) {
+        if (numFailures(index) > maxTaskFailures) {
           logError("Task %s:%d failed more than %d times; aborting job".format(
-            taskSet.id, index, MAX_TASK_FAILURES))
-          abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
+            taskSet.id, index, maxTaskFailures))
+          abort("Task %s:%d failed more than %d times".format(taskSet.id, index, maxTaskFailures))
         }
       }
     } else {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index b8ac498527..f5548fc2da 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -29,7 +29,7 @@ import akka.util.Duration
 import akka.util.duration._
 
 import org.apache.spark.{SparkException, Logging, TaskState}
-import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskScheduler,
+import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, ClusterScheduler,
   WorkerOffer}
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
 import org.apache.spark.util.Utils
@@ -43,7 +43,7 @@ import org.apache.spark.util.Utils
 * (spark.deploy.*).
 */
 private[spark]
-class CoarseGrainedSchedulerBackend(scheduler: TaskScheduler, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
   extends SchedulerBackend with Logging
 {
   // Use an atomic variable to track total number of cores in the cluster for simplicity and speed
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index a589e7456f..40fdfcddb1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -21,10 +21,10 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{Path, FileSystem}
 
 import org.apache.spark.{Logging, SparkContext}
-import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.scheduler.ClusterScheduler
 
 private[spark] class SimrSchedulerBackend(
-    scheduler: TaskScheduler,
+    scheduler: ClusterScheduler,
     sc: SparkContext,
     driverFilePath: String)
   extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 15c600a1ec..acf15dbc40 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -22,11 +22,11 @@ import scala.collection.mutable.HashMap
 import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.deploy.client.{Client, ClientListener}
 import org.apache.spark.deploy.{Command, ApplicationDescription}
-import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskScheduler}
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, ClusterScheduler}
 import org.apache.spark.util.Utils
 
 private[spark] class SparkDeploySchedulerBackend(
-    scheduler: TaskScheduler,
+    scheduler: ClusterScheduler,
     sc: SparkContext,
     masters: Array[String],
     appName: String)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 310da0027e..226ea46cc7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -30,7 +30,7 @@ import org.apache.mesos._
 import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
 
 import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
-import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.scheduler.ClusterScheduler
 import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 
 /**
@@ -44,7 +44,7 @@ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 * remove this.
 */
 private[spark] class CoarseMesosSchedulerBackend(
-    scheduler: TaskScheduler,
+    scheduler: ClusterScheduler,
     sc: SparkContext,
     master: String,
     appName: String)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index c0e99df0b6..3acad1bb46 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -31,7 +31,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTas
 
 import org.apache.spark.{Logging, SparkException, SparkContext, TaskState}
 import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost,
-  TaskDescription, TaskScheduler, WorkerOffer}
+  TaskDescription, ClusterScheduler, WorkerOffer}
 import org.apache.spark.util.Utils
 
 /**
@@ -40,7 +40,7 @@ import org.apache.spark.util.Utils
 * from multiple apps can run on different cores) and in time (a core can switch ownership).
 */
 private[spark] class MesosSchedulerBackend(
-    scheduler: TaskScheduler,
+    scheduler: ClusterScheduler,
     sc: SparkContext,
     master: String,
     appName: String)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 96c3a03602..3e9d31cd5e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -24,16 +24,17 @@ import akka.actor.{Actor, ActorRef, Props}
 import org.apache.spark.{SparkContext, SparkEnv, TaskState}
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.executor.{Executor, ExecutorBackend}
-import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, WorkerOffer}
+import org.apache.spark.scheduler.{SchedulerBackend, ClusterScheduler, WorkerOffer}
 
 /**
- * LocalBackend sits behind a TaskScheduler and handles launching tasks on a single Executor
- * (created by the LocalBackend) running locally.
+ * LocalBackend is used when running a local version of Spark where the executor, backend, and
+ * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks
+ * on a single Executor (created by the LocalBackend) running locally.
 *
 * THREADING: Because methods can be called both from the Executor and the TaskScheduler, and
 * because the Executor class is not thread safe, all methods are synchronized.
 */
-private[spark] class LocalBackend(scheduler: TaskScheduler, private val totalCores: Int)
+private[spark] class LocalBackend(scheduler: ClusterScheduler, private val totalCores: Int)
   extends SchedulerBackend with ExecutorBackend {
 
   private var freeCores = totalCores
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
index bfbffdf261..96adcf7198 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
@@ -29,9 +29,9 @@ class FakeTaskSetManager(
     initPriority: Int,
     initStageId: Int,
     initNumTasks: Int,
-    taskScheduler: TaskScheduler,
+    taskScheduler: ClusterScheduler,
     taskSet: TaskSet)
-  extends TaskSetManager(taskScheduler, taskSet) {
+  extends TaskSetManager(taskScheduler, taskSet, 1) {
 
   parent = null
   weight = 1
@@ -102,9 +102,9 @@ class FakeTaskSetManager(
   }
 }
 
-class TaskSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
+class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
 
-  def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: TaskScheduler, taskSet: TaskSet): FakeTaskSetManager = {
+  def createDummyTaskSetManager(priority: Int, stage: Int, numTasks: Int, cs: ClusterScheduler, taskSet: TaskSet): FakeTaskSetManager = {
     new FakeTaskSetManager(priority, stage, numTasks, cs , taskSet)
   }
 
@@ -131,7 +131,7 @@ class TaskSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("FIFO Scheduler Test") {
     sc = new SparkContext("local", "TaskSchedulerSuite")
-    val taskScheduler = new TaskScheduler(sc)
+    val taskScheduler = new ClusterScheduler(sc)
     var tasks = ArrayBuffer[Task[_]]()
     val task = new FakeTask(0)
     tasks += task
@@ -158,7 +158,7 @@ class TaskSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("Fair Scheduler Test") {
     sc = new SparkContext("local", "TaskSchedulerSuite")
-    val taskScheduler = new TaskScheduler(sc)
+    val taskScheduler = new ClusterScheduler(sc)
     var tasks = ArrayBuffer[Task[_]]()
     val task = new FakeTask(0)
     tasks += task
@@ -215,7 +215,7 @@ class TaskSchedulerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("Nested Pool Test") {
     sc = new SparkContext("local", "TaskSchedulerSuite")
-    val taskScheduler = new TaskScheduler(sc)
+    val taskScheduler = new ClusterScheduler(sc)
     var tasks = ArrayBuffer[Task[_]]()
     val task = new FakeTask(0)
     tasks += task
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 5b5a2178f3..24689a7093 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -37,7 +37,7 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
 * TaskScheduler that records the task sets that the DAGScheduler requested executed.
 */
 class TaskSetRecordingTaskScheduler(sc: SparkContext,
-    mapOutputTrackerMaster: MapOutputTrackerMaster) extends TaskScheduler(sc) {
+    mapOutputTrackerMaster: MapOutputTrackerMaster) extends ClusterScheduler(sc) {
   /** Set of TaskSets the DAGScheduler has requested executed. */
   val taskSets = scala.collection.mutable.Buffer[TaskSet]()
   override def start() = {}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index 30e6bc5721..2ac2d7a36a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.storage.TaskResultBlockId
 /**
 * Used to test the case where a BlockManager evicts the task result (or dies) before the
 * TaskResult is retrieved.
 */
-class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskScheduler)
+class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
   extends TaskResultGetter(sparkEnv, scheduler) {
   var removedResult = false
@@ -91,8 +91,8 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
   test("task retried if result missing from block manager") {
     // If this test hangs, it's probably because no resource offers were made after the task
     // failed.
-    val scheduler: TaskScheduler = sc.taskScheduler match {
-      case clusterScheduler: TaskScheduler =>
+    val scheduler: ClusterScheduler = sc.taskScheduler match {
+      case clusterScheduler: ClusterScheduler =>
         clusterScheduler
       case _ =>
         assert(false, "Expect local cluster to use TaskScheduler")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index fe3ea7b594..592bb11364 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -58,7 +58,7 @@ class FakeDAGScheduler(taskScheduler: FakeTaskScheduler) extends DAGScheduler(ta
 * to work, and these are required for locality in TaskSetManager.
 */
 class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
-  extends TaskScheduler(sc)
+  extends ClusterScheduler(sc)
 {
   val startedTasks = new ArrayBuffer[Long]
   val endedTasks = new mutable.HashMap[Long, TaskEndReason]
@@ -82,12 +82,13 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
   import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL}
 
   val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+  val MAX_TASK_FAILURES = 4
 
   test("TaskSet with no preferences") {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(1)
-    val manager = new TaskSetManager(sched, taskSet)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
 
     // Offer a host with no CPUs
     assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None)
@@ -113,7 +114,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(3)
-    val manager = new TaskSetManager(sched, taskSet)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
 
     // First three offers should all find tasks
     for (i <- 0 until 3) {
@@ -150,7 +151,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
       Seq()   // Last task has no locality prefs
     )
     val clock = new FakeClock
-    val manager = new TaskSetManager(sched, taskSet, clock)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
 
     // First offer host1, exec1: first task should be chosen
     assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -196,7 +197,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
       Seq(TaskLocation("host2"))
     )
     val clock = new FakeClock
-    val manager = new TaskSetManager(sched, taskSet, clock)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
 
     // First offer host1: first task should be chosen
     assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -233,7 +234,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
       Seq(TaskLocation("host3"))
     )
     val clock = new FakeClock
-    val manager = new TaskSetManager(sched, taskSet, clock)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
 
     // First offer host1: first task should be chosen
     assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -261,7 +262,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(1)
     val clock = new FakeClock
-    val manager = new TaskSetManager(sched, taskSet, clock)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
 
     assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
 
@@ -278,17 +279,17 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(1)
     val clock = new FakeClock
-    val manager = new TaskSetManager(sched, taskSet, clock)
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
 
     // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
     // after the last failure.
-    (0 until manager.MAX_TASK_FAILURES).foreach { index =>
+    (0 until MAX_TASK_FAILURES).foreach { index =>
       val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY)
       assert(offerResult != None,
         "Expect resource offer on iteration %s to return a task".format(index))
       assert(offerResult.get.index === 0)
       manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost))
-      if (index < manager.MAX_TASK_FAILURES) {
+      if (index < MAX_TASK_FAILURES) {
        assert(!sched.taskSetsFailed.contains(taskSet.id))
      } else {
        assert(sched.taskSetsFailed.contains(taskSet.id))
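
One payoff of programming against the extracted interface shows up in the DAGSchedulerSuite hunk above: a test can stand in its own scheduler and record what the DAGScheduler submits. A condensed sketch of that pattern follows; only the taskSets buffer and the start() override are visible in the truncated hunk, so the submitTasks body here is an assumption:

  // Condensed sketch of DAGSchedulerSuite's TaskSetRecordingTaskScheduler; the
  // submitTasks override is inferred (assumption), the rest mirrors the hunk above.
  class TaskSetRecordingTaskScheduler(sc: SparkContext) extends ClusterScheduler(sc) {
    /** Set of TaskSets the DAGScheduler has requested executed. */
    val taskSets = scala.collection.mutable.Buffer[TaskSet]()

    override def start() = {}
    override def submitTasks(taskSet: TaskSet) {
      // Record the task set instead of launching it.
      taskSets += taskSet
    }
  }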