aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala)49
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala)2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala)6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala)7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala55
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala690
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala (renamed from core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala)2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala703
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala74
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala219
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala191
-rw-r--r--core/src/test/scala/org/apache/spark/FailureSuite.scala20
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala (renamed from core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala)52
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala97
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala (renamed from core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala)3
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala19
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala (renamed from core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala)11
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala (renamed from core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala)46
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala227
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala10
26 files changed, 939 insertions, 1589 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index d884095671..10db2fa7e7 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -56,10 +56,9 @@ import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
- SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
+ SparkDeploySchedulerBackend, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import org.apache.spark.scheduler.local.LocalScheduler
-import org.apache.spark.scheduler.StageInfo
+import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
@@ -157,8 +156,6 @@ class SparkContext(
private[spark] var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
- // Regular expression for local[N, maxRetries], used in tests with failing tasks
- val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
@@ -170,13 +167,16 @@ class SparkContext(
master match {
case "local" =>
- new LocalScheduler(1, 0, this)
+ val scheduler = new ClusterScheduler(this, isLocal = true)
+ val backend = new LocalBackend(scheduler, 1)
+ scheduler.initialize(backend)
+ scheduler
case LOCAL_N_REGEX(threads) =>
- new LocalScheduler(threads.toInt, 0, this)
-
- case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
- new LocalScheduler(threads.toInt, maxFailures.toInt, this)
+ val scheduler = new ClusterScheduler(this, isLocal = true)
+ val backend = new LocalBackend(scheduler, threads.toInt)
+ scheduler.initialize(backend)
+ scheduler
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
@@ -200,7 +200,7 @@ class SparkContext(
memoryPerSlaveInt, SparkContext.executorMemoryRequested))
}
- val scheduler = new ClusterScheduler(this)
+ val scheduler = new ClusterScheduler(this, isLocal = true)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val masterUrls = localCluster.start()
@@ -610,9 +610,7 @@ class SparkContext(
}
addedFiles(key) = System.currentTimeMillis
- // Fetch the file locally in case a job is executed locally.
- // Jobs that run through LocalScheduler will already fetch the required dependencies,
- // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
+ // Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
index 53a589615d..c5d7ca0481 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicLong
@@ -29,16 +29,16 @@ import akka.util.duration._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler._
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
/**
- * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
- * initialize() and start(), then submit task sets through the runTasks method.
- *
- * This class can work with multiple types of clusters by acting through a SchedulerBackend.
+ * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
+ * It can also work with a local setup by using a LocalBackend and setting isLocal to true.
* It handles common logic, like determining a scheduling order across jobs, waking up to launch
* speculative tasks, etc.
+ *
+ * Clients should first call initialize() and start(), then submit task sets through the
+ * runTasks method.
*
* THREADING: SchedulerBackends and task-submitting clients can call this class from multiple
* threads, so it needs locks in public API methods to maintain its state. In addition, some
@@ -46,19 +46,27 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
* acquire a lock on us, so we need to make sure that we don't try to lock the backend while
* we are holding a lock on ourselves.
*/
-private[spark] class ClusterScheduler(val sc: SparkContext)
- extends TaskScheduler
- with Logging
-{
+private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = false)
+ extends TaskScheduler with Logging {
+
// How often to check for speculative tasks
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
- // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized
+ // TaskSetManagers are not thread safe, so any access to one should be synchronized
// on this class.
- val activeTaskSets = new HashMap[String, ClusterTaskSetManager]
+ val activeTaskSets = new HashMap[String, TaskSetManager]
+
+ val MAX_TASK_FAILURES = {
+ if (isLocal) {
+ // No sense in retrying if all tasks run locally!
+ 0
+ } else {
+ System.getProperty("spark.task.maxFailures", "4").toInt
+ }
+ }
val taskIdToTaskSetId = new HashMap[Long, String]
val taskIdToExecutorId = new HashMap[Long, String]
@@ -120,7 +128,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
override def start() {
backend.start()
- if (System.getProperty("spark.speculation", "false").toBoolean) {
+ if (!isLocal && System.getProperty("spark.speculation", "false").toBoolean) {
logInfo("Starting speculative execution thread")
sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
@@ -134,12 +142,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
- val manager = new ClusterTaskSetManager(this, taskSet)
+ val manager = new TaskSetManager(this, taskSet, MAX_TASK_FAILURES)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
- if (!hasReceivedTask) {
+ if (!isLocal && !hasReceivedTask) {
starvationTimer.scheduleAtFixedRate(new TimerTask() {
override def run() {
if (!hasLaunchedTask) {
@@ -299,19 +307,19 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
- def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
+ def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) {
taskSetManager.handleTaskGettingResult(tid)
}
def handleSuccessfulTask(
- taskSetManager: ClusterTaskSetManager,
+ taskSetManager: TaskSetManager,
tid: Long,
taskResult: DirectTaskResult[_]) = synchronized {
taskSetManager.handleSuccessfulTask(tid, taskResult)
}
def handleFailedTask(
- taskSetManager: ClusterTaskSetManager,
+ taskSetManager: TaskSetManager,
tid: Long,
taskState: TaskState,
reason: Option[TaskEndReason]) = synchronized {
@@ -337,7 +345,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// No task sets are active but we still got an error. Just exit since this
// must mean the error is during registration.
// It might be good to do something smarter here in the future.
- logError("Exiting due to error from cluster scheduler: " + message)
+ logError("Exiting due to error from task scheduler: " + message)
System.exit(1)
}
}
@@ -358,7 +366,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
override def defaultParallelism() = backend.defaultParallelism()
-
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
@@ -435,7 +442,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
-object ClusterScheduler {
+private[spark] object ClusterScheduler {
/**
* Used to balance containers across hosts.
*
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
index 5077b2b48b..2bc43a9186 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorLossReason.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.apache.spark.executor.ExecutorExitCode
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index 5367218faa..1f0839a0e1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.apache.spark.SparkContext
/**
- * A backend interface for cluster scheduling systems that allows plugging in different ones under
- * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
+ * A backend interface for scheduling systems that allows plugging in different ones under
+ * TaskScheduler. We assume a Mesos-like model where the application gets resource offers as
* machines become available and can launch tasks on them.
*/
private[spark] trait SchedulerBackend {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 2064d97b49..a77ff35323 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -15,14 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import java.nio.ByteBuffer
import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit}
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.Utils
@@ -42,7 +41,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
}
def enqueueSuccessfulTask(
- taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
getTaskResultExecutor.execute(new Runnable {
override def run() {
try {
@@ -78,7 +77,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
})
}
- def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState,
+ def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
serializedData: ByteBuffer) {
var reason: Option[TaskEndReason] = None
getTaskResultExecutor.execute(new Runnable {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
deleted file mode 100644
index 10e0478108..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler
-
-import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-
-/**
- * Low-level task scheduler interface, implemented by both ClusterScheduler and LocalScheduler.
- * Each TaskScheduler schedulers task for a single SparkContext.
- * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage,
- * and are responsible for sending the tasks to the cluster, running them, retrying if there
- * are failures, and mitigating stragglers. They return events to the DAGScheduler.
- */
-private[spark] trait TaskScheduler {
-
- def rootPool: Pool
-
- def schedulingMode: SchedulingMode
-
- def start(): Unit
-
- // Invoked after system has successfully initialized (typically in spark context).
- // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc.
- def postStartHook() { }
-
- // Disconnect from the cluster.
- def stop(): Unit
-
- // Submit a sequence of tasks to run.
- def submitTasks(taskSet: TaskSet): Unit
-
- // Cancel a stage.
- def cancelTasks(stageId: Int)
-
- // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
- def setDAGScheduler(dagScheduler: DAGScheduler): Unit
-
- // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
- def defaultParallelism(): Int
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 90f6bcefac..8757d7fd2a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -17,32 +17,692 @@
package org.apache.spark.scheduler
-import java.nio.ByteBuffer
+import java.util.Arrays
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.math.max
+import scala.math.min
+
+import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
+ Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{SystemClock, Clock}
+
/**
- * Tracks and schedules the tasks within a single TaskSet. This class keeps track of the status of
- * each task and is responsible for retries on failure and locality. The main interfaces to it
- * are resourceOffer, which asks the TaskSet whether it wants to run a task on one node, and
- * statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ * Schedules the tasks within a single TaskSet in the TaskScheduler. This class keeps track of
+ * each task, retries tasks if they fail (up to a limited number of times), and
+ * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
+ * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
+ * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
+ *
+ * THREADING: This class is designed to only be called from code with a lock on the
+ * TaskScheduler (e.g. its event handlers). It should not be called from other threads.
*
- * THREADING: This class is designed to only be called from code with a lock on the TaskScheduler
- * (e.g. its event handlers). It should not be called from other threads.
+ * @param sched the ClusterScheduler associated with the TaskSetManager
+ * @param taskSet the TaskSet to manage scheduling for
+ * @param maxTaskFailures if any particular task fails more than this number of times, the entire
+ * task set will be aborted
*/
-private[spark] trait TaskSetManager extends Schedulable {
- def schedulableQueue = null
-
- def schedulingMode = SchedulingMode.NONE
-
- def taskSet: TaskSet
+private[spark] class TaskSetManager(
+ sched: ClusterScheduler,
+ val taskSet: TaskSet,
+ val maxTaskFailures: Int,
+ clock: Clock = SystemClock)
+ extends Schedulable with Logging
+{
+ // CPUs to request per task
+ val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
+
+ // Quantile of tasks at which to start speculation
+ val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
+ val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
+
+ // Serializer for closures and tasks.
+ val env = SparkEnv.get
+ val ser = env.closureSerializer.newInstance()
+
+ val tasks = taskSet.tasks
+ val numTasks = tasks.length
+ val copiesRunning = new Array[Int](numTasks)
+ val successful = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
+ var tasksSuccessful = 0
+
+ var weight = 1
+ var minShare = 0
+ var priority = taskSet.priority
+ var stageId = taskSet.stageId
+ var name = "TaskSet_"+taskSet.stageId.toString
+ var parent: Pool = null
+
+ var runningTasks = 0
+ private val runningTasksSet = new HashSet[Long]
+
+ // Set of pending tasks for each executor. These collections are actually
+ // treated as stacks, in which new tasks are added to the end of the
+ // ArrayBuffer and removed from the end. This makes it faster to detect
+ // tasks that repeatedly fail because whenever a task failed, it is put
+ // back at the head of the stack. They are also only cleaned up lazily;
+ // when a task is launched, it remains in all the pending lists except
+ // the one that it was launched from, but gets removed from them later.
+ private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
+ // but at host level.
+ private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set of pending tasks for each rack -- similar to the above.
+ private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
+
+ // Set containing pending tasks with no locality preferences.
+ val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+
+ // Set containing all pending tasks (also used as a stack, as above).
+ val allPendingTasks = new ArrayBuffer[Int]
+
+ // Tasks that can be speculated. Since these will be a small fraction of total
+ // tasks, we'll just hold them in a HashSet.
+ val speculatableTasks = new HashSet[Int]
+
+ // Task index, start and finish time for each task attempt (indexed by task ID)
+ val taskInfos = new HashMap[Long, TaskInfo]
+
+ // Did the TaskSet fail?
+ var failed = false
+ var causeOfFailure = ""
+
+ // How frequently to reprint duplicate exceptions in full, in milliseconds
+ val EXCEPTION_PRINT_INTERVAL =
+ System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
+
+ // Map of recent exceptions (identified by string representation and top stack frame) to
+ // duplicate count (how many times the same exception has appeared) and time the full exception
+ // was printed. This should ideally be an LRU map that can drop old exceptions automatically.
+ val recentExceptions = HashMap[String, (Int, Long)]()
+
+ // Figure out the current map output tracker epoch and set it on all tasks
+ val epoch = sched.mapOutputTracker.getEpoch
+ logDebug("Epoch for " + taskSet + ": " + epoch)
+ for (t <- tasks) {
+ t.epoch = epoch
+ }
+
+ // Add all our tasks to the pending lists. We do this in reverse order
+ // of task index so that tasks with low indices get launched first.
+ for (i <- (0 until numTasks).reverse) {
+ addPendingTask(i)
+ }
+
+ // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
+ val myLocalityLevels = computeValidLocalityLevels()
+ val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
+
+ // Delay scheduling variables: we keep track of our current locality level and the time we
+ // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
+ // We then move down if we manage to launch a "more local" task.
+ var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
+ var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
+
+ override def schedulableQueue = null
+
+ override def schedulingMode = SchedulingMode.NONE
+
+ /**
+ * Add a task to all the pending-task lists that it should be on. If readding is set, we are
+ * re-adding the task so only include it in each list if it's not already there.
+ */
+ private def addPendingTask(index: Int, readding: Boolean = false) {
+ // Utility method that adds `index` to a list only if readding=false or it's not already there
+ def addTo(list: ArrayBuffer[Int]) {
+ if (!readding || !list.contains(index)) {
+ list += index
+ }
+ }
+
+ var hadAliveLocations = false
+ for (loc <- tasks(index).preferredLocations) {
+ for (execId <- loc.executorId) {
+ if (sched.isExecutorAlive(execId)) {
+ addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+ hadAliveLocations = true
+ }
+ }
+ if (sched.hasExecutorsAliveOnHost(loc.host)) {
+ addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+ for (rack <- sched.getRackForHost(loc.host)) {
+ addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
+ }
+ hadAliveLocations = true
+ }
+ }
+
+ if (!hadAliveLocations) {
+ // Even though the task might've had preferred locations, all of those hosts or executors
+ // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
+ addTo(pendingTasksWithNoPrefs)
+ }
+
+ if (!readding) {
+ allPendingTasks += index // No point scanning this whole list to find the old task there
+ }
+ }
+
+ /**
+ * Return the pending tasks list for a given executor ID, or an empty list if
+ * there is no map entry for that host
+ */
+ private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
+ pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
+ }
+
+ /**
+ * Return the pending tasks list for a given host, or an empty list if
+ * there is no map entry for that host
+ */
+ private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
+ pendingTasksForHost.getOrElse(host, ArrayBuffer())
+ }
+
+ /**
+ * Return the pending rack-local task list for a given rack, or an empty list if
+ * there is no map entry for that rack
+ */
+ private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
+ pendingTasksForRack.getOrElse(rack, ArrayBuffer())
+ }
+
+ /**
+ * Dequeue a pending task from the given list and return its index.
+ * Return None if the list is empty.
+ * This method also cleans up any tasks in the list that have already
+ * been launched, since we want that to happen lazily.
+ */
+ private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
+ while (!list.isEmpty) {
+ val index = list.last
+ list.trimEnd(1)
+ if (copiesRunning(index) == 0 && !successful(index)) {
+ return Some(index)
+ }
+ }
+ return None
+ }
+
+ /** Check whether a task is currently running an attempt on a given host */
+ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
+ !taskAttempts(taskIndex).exists(_.host == host)
+ }
+
+ /**
+ * Return a speculative task for a given executor if any are available. The task should not have
+ * an attempt running on this host, in case the host is slow. In addition, the task should meet
+ * the given locality constraint.
+ */
+ private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+ : Option[(Int, TaskLocality.Value)] =
+ {
+ speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
+
+ if (!speculatableTasks.isEmpty) {
+ // Check for process-local or preference-less tasks; note that tasks can be process-local
+ // on multiple nodes when we replicate cached blocks, as in Spark Streaming
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val prefs = tasks(index).preferredLocations
+ val executors = prefs.flatMap(_.executorId)
+ if (prefs.size == 0 || executors.contains(execId)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+ }
+
+ // Check for node-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val locations = tasks(index).preferredLocations.map(_.host)
+ if (locations.contains(host)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.NODE_LOCAL))
+ }
+ }
+ }
+
+ // Check for rack-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ for (rack <- sched.getRackForHost(host)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
+ if (racks.contains(rack)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.RACK_LOCAL))
+ }
+ }
+ }
+ }
+ // Check for non-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.ANY))
+ }
+ }
+ }
+
+ return None
+ }
+
+ /**
+ * Dequeue a pending task for a given node and return its index and locality level.
+ * Only search for tasks matching the given locality constraint.
+ */
+ private def findTask(execId: String, host: String, locality: TaskLocality.Value)
+ : Option[(Int, TaskLocality.Value)] =
+ {
+ for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+ for (index <- findTaskFromList(getPendingTasksForHost(host))) {
+ return Some((index, TaskLocality.NODE_LOCAL))
+ }
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
+ for {
+ rack <- sched.getRackForHost(host)
+ index <- findTaskFromList(getPendingTasksForRack(rack))
+ } {
+ return Some((index, TaskLocality.RACK_LOCAL))
+ }
+ }
+
+ // Look for no-pref tasks after rack-local tasks since they can run anywhere.
+ for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
+
+ if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
+ for (index <- findTaskFromList(allPendingTasks)) {
+ return Some((index, TaskLocality.ANY))
+ }
+ }
+
+ // Finally, if all else has failed, find a speculative task
+ return findSpeculativeTask(execId, host, locality)
+ }
+
+ /**
+ * Respond to an offer of a single executor from the scheduler by finding a task
+ */
def resourceOffer(
execId: String,
host: String,
availableCpus: Int,
maxLocality: TaskLocality.TaskLocality)
- : Option[TaskDescription]
+ : Option[TaskDescription] =
+ {
+ if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
+ val curTime = clock.getTime()
+
+ var allowedLocality = getAllowedLocalityLevel(curTime)
+ if (allowedLocality > maxLocality) {
+ allowedLocality = maxLocality // We're not allowed to search for farther-away tasks
+ }
+
+ findTask(execId, host, allowedLocality) match {
+ case Some((index, taskLocality)) => {
+ // Found a task; do some bookkeeping and return a task description
+ val task = tasks(index)
+ val taskId = sched.newTaskId()
+ // Figure out whether this should count as a preferred launch
+ logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
+ taskSet.id, index, taskId, execId, host, taskLocality))
+ // Do various bookkeeping
+ copiesRunning(index) += 1
+ val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
+ taskInfos(taskId) = info
+ taskAttempts(index) = info :: taskAttempts(index)
+ // Update our locality level for delay scheduling
+ currentLocalityIndex = getLocalityIndex(taskLocality)
+ lastLaunchTime = curTime
+ // Serialize and return the task
+ val startTime = clock.getTime()
+ // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
+ // we assume the task can be serialized without exceptions.
+ val serializedTask = Task.serializeWithDependencies(
+ task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ val timeTaken = clock.getTime() - startTime
+ addRunningTask(taskId)
+ logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
+ taskSet.id, index, serializedTask.limit, timeTaken))
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ if (taskAttempts(index).size == 1)
+ taskStarted(task,info)
+ return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+ }
+ case _ =>
+ }
+ }
+ return None
+ }
+
+ /**
+ * Get the level we can launch tasks according to delay scheduling, based on current wait time.
+ */
+ private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
+ while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
+ currentLocalityIndex < myLocalityLevels.length - 1)
+ {
+ // Jump to the next locality level, and remove our waiting time for the current one since
+ // we don't want to count it again on the next one
+ lastLaunchTime += localityWaits(currentLocalityIndex)
+ currentLocalityIndex += 1
+ }
+ myLocalityLevels(currentLocalityIndex)
+ }
+
+ /**
+ * Find the index in myLocalityLevels for a given locality. This is also designed to work with
+ * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
+ * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
+ */
+ def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
+ var index = 0
+ while (locality > myLocalityLevels(index)) {
+ index += 1
+ }
+ index
+ }
+
+ private def taskStarted(task: Task[_], info: TaskInfo) {
+ sched.dagScheduler.taskStarted(task, info)
+ }
+
+ def handleTaskGettingResult(tid: Long) = {
+ val info = taskInfos(tid)
+ info.markGettingResult()
+ sched.dagScheduler.taskGettingResult(tasks(info.index), info)
+ }
+
+ /**
+ * Marks the task as successful and notifies the DAGScheduler that a task has ended.
+ */
+ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
+ val info = taskInfos(tid)
+ val index = info.index
+ info.markSuccessful()
+ removeRunningTask(tid)
+ if (!successful(index)) {
+ logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
+ tid, info.duration, info.host, tasksSuccessful, numTasks))
+ sched.dagScheduler.taskEnded(
+ tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+
+ // Mark successful and stop if all the tasks have succeeded.
+ tasksSuccessful += 1
+ successful(index) = true
+ if (tasksSuccessful == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ } else {
+ logInfo("Ignorning task-finished event for TID " + tid + " because task " +
+ index + " has already completed successfully")
+ }
+ }
+
+ /**
+ * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
+ * DAG Scheduler.
+ */
+ def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ return
+ }
+ removeRunningTask(tid)
+ val index = info.index
+ info.markFailed()
+ if (!successful(index)) {
+ logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ copiesRunning(index) -= 1
+ // Check if the problem is a map output fetch failure. In that case, this
+ // task will never succeed on any node, so tell the scheduler about it.
+ reason.foreach {
+ case fetchFailed: FetchFailed =>
+ logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ successful(index) = true
+ tasksSuccessful += 1
+ sched.taskSetFinished(this)
+ removeAllRunningTasks()
+ return
+
+ case TaskKilled =>
+ logWarning("Task %d was killed.".format(tid))
+ sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
+ return
+
+ case ef: ExceptionFailure =>
+ sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
+ val key = ef.description
+ val now = clock.getTime()
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
+ }
+ } else {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ }
+ }
+ if (printFull) {
+ val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logWarning("Loss was due to %s\n%s\n%s".format(
+ ef.className, ef.description, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
+ }
+
+ case TaskResultLost =>
+ logWarning("Lost result for TID %s on host %s".format(tid, info.host))
+ sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+
+ case _ => {}
+ }
+ // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+ addPendingTask(index)
+ if (state != TaskState.KILLED) {
+ numFailures(index) += 1
+ if (numFailures(index) > maxTaskFailures) {
+ logError("Task %s:%d failed more than %d times; aborting job".format(
+ taskSet.id, index, maxTaskFailures))
+ abort("Task %s:%d failed more than %d times".format(taskSet.id, index, maxTaskFailures))
+ }
+ }
+ } else {
+ logInfo("Ignoring task-lost event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ def error(message: String) {
+ // Save the error message
+ abort("Error: " + message)
+ }
+
+ def abort(message: String) {
+ failed = true
+ causeOfFailure = message
+ // TODO: Kill running tasks if we were not terminated due to a Mesos error
+ sched.dagScheduler.taskSetFailed(taskSet, message)
+ removeAllRunningTasks()
+ sched.taskSetFinished(this)
+ }
+
+ /** If the given task ID is not in the set of running tasks, adds it.
+ *
+ * Used to keep track of the number of running tasks, for enforcing scheduling policies.
+ */
+ def addRunningTask(tid: Long) {
+ if (runningTasksSet.add(tid) && parent != null) {
+ parent.increaseRunningTasks(1)
+ }
+ runningTasks = runningTasksSet.size
+ }
+
+ /** If the given task ID is in the set of running tasks, removes it. */
+ def removeRunningTask(tid: Long) {
+ if (runningTasksSet.remove(tid) && parent != null) {
+ parent.decreaseRunningTasks(1)
+ }
+ runningTasks = runningTasksSet.size
+ }
+
+ private def removeAllRunningTasks() {
+ val numRunningTasks = runningTasksSet.size
+ runningTasksSet.clear()
+ if (parent != null) {
+ parent.decreaseRunningTasks(numRunningTasks)
+ }
+ runningTasks = 0
+ }
+
+ override def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ override def addSchedulable(schedulable: Schedulable) {}
+
+ override def removeSchedulable(schedulable: Schedulable) {}
+
+ override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
+ sortedTaskSetQueue += this
+ return sortedTaskSetQueue
+ }
+
+ /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */
+ override def executorLost(execId: String, host: String) {
+ logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
+
+ // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
+ // task that used to have locations on only this host might now go to the no-prefs list. Note
+ // that it's okay if we add a task to the same queue twice (if it had multiple preferred
+ // locations), because findTaskFromList will skip already-running tasks.
+ for (index <- getPendingTasksForExecutor(execId)) {
+ addPendingTask(index, readding=true)
+ }
+ for (index <- getPendingTasksForHost(host)) {
+ addPendingTask(index, readding=true)
+ }
+
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
+ if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ for ((tid, info) <- taskInfos if info.executorId == execId) {
+ val index = taskInfos(tid).index
+ if (successful(index)) {
+ successful(index) = false
+ copiesRunning(index) -= 1
+ tasksSuccessful -= 1
+ addPendingTask(index)
+ // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
+ // stage finishes when a total of tasks.size tasks finish.
+ sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
+ }
+ }
+ }
+ // Also re-enqueue any tasks that were running on the node
+ for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
+ handleFailedTask(tid, TaskState.KILLED, None)
+ }
+ }
+
+ /**
+ * Check for tasks to be speculated and return true if there are any. This is called periodically
+ * by the TaskScheduler.
+ *
+ * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
+ * we don't scan the whole task set. It might also help to make this sorted by launch time.
+ */
+ override def checkSpeculatableTasks(): Boolean = {
+ // Can't speculate if we only have one task, or if all tasks have finished.
+ if (numTasks == 1 || tasksSuccessful == numTasks) {
+ return false
+ }
+ var foundTasks = false
+ val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
+ logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
+ if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
+ val time = clock.getTime()
+ val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
+ Arrays.sort(durations)
+ val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
+ val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
+ // TODO: Threshold should also look at standard deviation of task durations and have a lower
+ // bound based on that.
+ logDebug("Task length threshold for speculation: " + threshold)
+ for ((tid, info) <- taskInfos) {
+ val index = info.index
+ if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+ !speculatableTasks.contains(index)) {
+ logInfo(
+ "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
+ taskSet.id, index, info.host, threshold))
+ speculatableTasks += index
+ foundTasks = true
+ }
+ }
+ }
+ return foundTasks
+ }
+
+ override def hasPendingTasks(): Boolean = {
+ numTasks > 0 && tasksSuccessful < numTasks
+ }
+
+ private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
+ val defaultWait = System.getProperty("spark.locality.wait", "3000")
+ level match {
+ case TaskLocality.PROCESS_LOCAL =>
+ System.getProperty("spark.locality.wait.process", defaultWait).toLong
+ case TaskLocality.NODE_LOCAL =>
+ System.getProperty("spark.locality.wait.node", defaultWait).toLong
+ case TaskLocality.RACK_LOCAL =>
+ System.getProperty("spark.locality.wait.rack", defaultWait).toLong
+ case TaskLocality.ANY =>
+ 0L
+ }
+ }
- def error(message: String)
+ /**
+ * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
+ * added to queues using addPendingTask.
+ */
+ private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
+ import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
+ val levels = new ArrayBuffer[TaskLocality.TaskLocality]
+ if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
+ levels += PROCESS_LOCAL
+ }
+ if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
+ levels += NODE_LOCAL
+ }
+ if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
+ levels += RACK_LOCAL
+ }
+ levels += ANY
+ logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
+ levels.toArray
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
index 938f62883a..ba6bab3f91 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
/**
* Represents free resources available on an executor.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
deleted file mode 100644
index ee47aaffca..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ /dev/null
@@ -1,703 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.cluster
-
-import java.util.Arrays
-
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
-
-import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv,
- Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState}
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler._
-import org.apache.spark.util.{SystemClock, Clock}
-
-
-/**
- * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
- * the status of each task, retries tasks if they fail (up to a limited number of times), and
- * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
- * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
- * and statusUpdate, which tells it that one of its tasks changed state (e.g. finished).
- *
- * THREADING: This class is designed to only be called from code with a lock on the
- * ClusterScheduler (e.g. its event handlers). It should not be called from other threads.
- */
-private[spark] class ClusterTaskSetManager(
- sched: ClusterScheduler,
- val taskSet: TaskSet,
- clock: Clock = SystemClock)
- extends TaskSetManager
- with Logging
-{
- // CPUs to request per task
- val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toInt
-
- // Maximum times a task is allowed to fail before failing the job
- val MAX_TASK_FAILURES = System.getProperty("spark.task.maxFailures", "4").toInt
-
- // Quantile of tasks at which to start speculation
- val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
- val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
-
- // Serializer for closures and tasks.
- val env = SparkEnv.get
- val ser = env.closureSerializer.newInstance()
-
- val tasks = taskSet.tasks
- val numTasks = tasks.length
- val copiesRunning = new Array[Int](numTasks)
- val successful = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
- val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
- var tasksSuccessful = 0
-
- var weight = 1
- var minShare = 0
- var priority = taskSet.priority
- var stageId = taskSet.stageId
- var name = "TaskSet_"+taskSet.stageId.toString
- var parent: Pool = null
-
- var runningTasks = 0
- private val runningTasksSet = new HashSet[Long]
-
- // Set of pending tasks for each executor. These collections are actually
- // treated as stacks, in which new tasks are added to the end of the
- // ArrayBuffer and removed from the end. This makes it faster to detect
- // tasks that repeatedly fail because whenever a task failed, it is put
- // back at the head of the stack. They are also only cleaned up lazily;
- // when a task is launched, it remains in all the pending lists except
- // the one that it was launched from, but gets removed from them later.
- private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
-
- // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
- // but at host level.
- private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
-
- // Set of pending tasks for each rack -- similar to the above.
- private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
-
- // Set containing pending tasks with no locality preferences.
- val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
-
- // Set containing all pending tasks (also used as a stack, as above).
- val allPendingTasks = new ArrayBuffer[Int]
-
- // Tasks that can be speculated. Since these will be a small fraction of total
- // tasks, we'll just hold them in a HashSet.
- val speculatableTasks = new HashSet[Int]
-
- // Task index, start and finish time for each task attempt (indexed by task ID)
- val taskInfos = new HashMap[Long, TaskInfo]
-
- // Did the TaskSet fail?
- var failed = false
- var causeOfFailure = ""
-
- // How frequently to reprint duplicate exceptions in full, in milliseconds
- val EXCEPTION_PRINT_INTERVAL =
- System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
-
- // Map of recent exceptions (identified by string representation and top stack frame) to
- // duplicate count (how many times the same exception has appeared) and time the full exception
- // was printed. This should ideally be an LRU map that can drop old exceptions automatically.
- val recentExceptions = HashMap[String, (Int, Long)]()
-
- // Figure out the current map output tracker epoch and set it on all tasks
- val epoch = sched.mapOutputTracker.getEpoch
- logDebug("Epoch for " + taskSet + ": " + epoch)
- for (t <- tasks) {
- t.epoch = epoch
- }
-
- // Add all our tasks to the pending lists. We do this in reverse order
- // of task index so that tasks with low indices get launched first.
- for (i <- (0 until numTasks).reverse) {
- addPendingTask(i)
- }
-
- // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
- val myLocalityLevels = computeValidLocalityLevels()
- val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
-
- // Delay scheduling variables: we keep track of our current locality level and the time we
- // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
- // We then move down if we manage to launch a "more local" task.
- var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
- var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
-
- /**
- * Add a task to all the pending-task lists that it should be on. If readding is set, we are
- * re-adding the task so only include it in each list if it's not already there.
- */
- private def addPendingTask(index: Int, readding: Boolean = false) {
- // Utility method that adds `index` to a list only if readding=false or it's not already there
- def addTo(list: ArrayBuffer[Int]) {
- if (!readding || !list.contains(index)) {
- list += index
- }
- }
-
- var hadAliveLocations = false
- for (loc <- tasks(index).preferredLocations) {
- for (execId <- loc.executorId) {
- if (sched.isExecutorAlive(execId)) {
- addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
- hadAliveLocations = true
- }
- }
- if (sched.hasExecutorsAliveOnHost(loc.host)) {
- addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
- for (rack <- sched.getRackForHost(loc.host)) {
- addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
- }
- hadAliveLocations = true
- }
- }
-
- if (!hadAliveLocations) {
- // Even though the task might've had preferred locations, all of those hosts or executors
- // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
- addTo(pendingTasksWithNoPrefs)
- }
-
- if (!readding) {
- allPendingTasks += index // No point scanning this whole list to find the old task there
- }
- }
-
- /**
- * Return the pending tasks list for a given executor ID, or an empty list if
- * there is no map entry for that host
- */
- private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
- pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
- }
-
- /**
- * Return the pending tasks list for a given host, or an empty list if
- * there is no map entry for that host
- */
- private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
- pendingTasksForHost.getOrElse(host, ArrayBuffer())
- }
-
- /**
- * Return the pending rack-local task list for a given rack, or an empty list if
- * there is no map entry for that rack
- */
- private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
- pendingTasksForRack.getOrElse(rack, ArrayBuffer())
- }
-
- /**
- * Dequeue a pending task from the given list and return its index.
- * Return None if the list is empty.
- * This method also cleans up any tasks in the list that have already
- * been launched, since we want that to happen lazily.
- */
- private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
- while (!list.isEmpty) {
- val index = list.last
- list.trimEnd(1)
- if (copiesRunning(index) == 0 && !successful(index)) {
- return Some(index)
- }
- }
- return None
- }
-
- /** Check whether a task is currently running an attempt on a given host */
- private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
- !taskAttempts(taskIndex).exists(_.host == host)
- }
-
- /**
- * Return a speculative task for a given executor if any are available. The task should not have
- * an attempt running on this host, in case the host is slow. In addition, the task should meet
- * the given locality constraint.
- */
- private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
- : Option[(Int, TaskLocality.Value)] =
- {
- speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
-
- if (!speculatableTasks.isEmpty) {
- // Check for process-local or preference-less tasks; note that tasks can be process-local
- // on multiple nodes when we replicate cached blocks, as in Spark Streaming
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- val prefs = tasks(index).preferredLocations
- val executors = prefs.flatMap(_.executorId)
- if (prefs.size == 0 || executors.contains(execId)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.PROCESS_LOCAL))
- }
- }
-
- // Check for node-local tasks
- if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- val locations = tasks(index).preferredLocations.map(_.host)
- if (locations.contains(host)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.NODE_LOCAL))
- }
- }
- }
-
- // Check for rack-local tasks
- if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
- for (rack <- sched.getRackForHost(host)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
- if (racks.contains(rack)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.RACK_LOCAL))
- }
- }
- }
- }
-
- // Check for non-local tasks
- if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
- speculatableTasks -= index
- return Some((index, TaskLocality.ANY))
- }
- }
- }
-
- return None
- }
-
- /**
- * Dequeue a pending task for a given node and return its index and locality level.
- * Only search for tasks matching the given locality constraint.
- */
- private def findTask(execId: String, host: String, locality: TaskLocality.Value)
- : Option[(Int, TaskLocality.Value)] =
- {
- for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
- return Some((index, TaskLocality.PROCESS_LOCAL))
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
- for (index <- findTaskFromList(getPendingTasksForHost(host))) {
- return Some((index, TaskLocality.NODE_LOCAL))
- }
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
- for {
- rack <- sched.getRackForHost(host)
- index <- findTaskFromList(getPendingTasksForRack(rack))
- } {
- return Some((index, TaskLocality.RACK_LOCAL))
- }
- }
-
- // Look for no-pref tasks after rack-local tasks since they can run anywhere.
- for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
- return Some((index, TaskLocality.PROCESS_LOCAL))
- }
-
- if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- for (index <- findTaskFromList(allPendingTasks)) {
- return Some((index, TaskLocality.ANY))
- }
- }
-
- // Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(execId, host, locality)
- }
-
- /**
- * Respond to an offer of a single executor from the scheduler by finding a task
- */
- override def resourceOffer(
- execId: String,
- host: String,
- availableCpus: Int,
- maxLocality: TaskLocality.TaskLocality)
- : Option[TaskDescription] =
- {
- if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
- val curTime = clock.getTime()
-
- var allowedLocality = getAllowedLocalityLevel(curTime)
- if (allowedLocality > maxLocality) {
- allowedLocality = maxLocality // We're not allowed to search for farther-away tasks
- }
-
- findTask(execId, host, allowedLocality) match {
- case Some((index, taskLocality)) => {
- // Found a task; do some bookkeeping and return a task description
- val task = tasks(index)
- val taskId = sched.newTaskId()
- // Figure out whether this should count as a preferred launch
- logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
- taskSet.id, index, taskId, execId, host, taskLocality))
- // Do various bookkeeping
- copiesRunning(index) += 1
- val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
- taskInfos(taskId) = info
- taskAttempts(index) = info :: taskAttempts(index)
- // Update our locality level for delay scheduling
- currentLocalityIndex = getLocalityIndex(taskLocality)
- lastLaunchTime = curTime
- // Serialize and return the task
- val startTime = clock.getTime()
- // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
- // we assume the task can be serialized without exceptions.
- val serializedTask = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- val timeTaken = clock.getTime() - startTime
- addRunningTask(taskId)
- logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
- taskSet.id, index, serializedTask.limit, timeTaken))
- val taskName = "task %s:%d".format(taskSet.id, index)
- if (taskAttempts(index).size == 1)
- taskStarted(task,info)
- return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
- }
- case _ =>
- }
- }
- return None
- }
-
- /**
- * Get the level we can launch tasks according to delay scheduling, based on current wait time.
- */
- private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
- while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
- currentLocalityIndex < myLocalityLevels.length - 1)
- {
- // Jump to the next locality level, and remove our waiting time for the current one since
- // we don't want to count it again on the next one
- lastLaunchTime += localityWaits(currentLocalityIndex)
- currentLocalityIndex += 1
- }
- myLocalityLevels(currentLocalityIndex)
- }
-
- /**
- * Find the index in myLocalityLevels for a given locality. This is also designed to work with
- * localities that are not in myLocalityLevels (in case we somehow get those) by returning the
- * next-biggest level we have. Uses the fact that the last value in myLocalityLevels is ANY.
- */
- def getLocalityIndex(locality: TaskLocality.TaskLocality): Int = {
- var index = 0
- while (locality > myLocalityLevels(index)) {
- index += 1
- }
- index
- }
-
- private def taskStarted(task: Task[_], info: TaskInfo) {
- sched.dagScheduler.taskStarted(task, info)
- }
-
- def handleTaskGettingResult(tid: Long) = {
- val info = taskInfos(tid)
- info.markGettingResult()
- sched.dagScheduler.taskGettingResult(tasks(info.index), info)
- }
-
- /**
- * Marks the task as successful and notifies the DAGScheduler that a task has ended.
- */
- def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
- val info = taskInfos(tid)
- val index = info.index
- info.markSuccessful()
- removeRunningTask(tid)
- if (!successful(index)) {
- logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
- tid, info.duration, info.host, tasksSuccessful, numTasks))
- sched.dagScheduler.taskEnded(
- tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
-
- // Mark successful and stop if all the tasks have succeeded.
- tasksSuccessful += 1
- successful(index) = true
- if (tasksSuccessful == numTasks) {
- sched.taskSetFinished(this)
- }
- } else {
- logInfo("Ignorning task-finished event for TID " + tid + " because task " +
- index + " has already completed successfully")
- }
- }
-
- /**
- * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
- * DAG Scheduler.
- */
- def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
- val info = taskInfos(tid)
- if (info.failed) {
- return
- }
- removeRunningTask(tid)
- val index = info.index
- info.markFailed()
- if (!successful(index)) {
- logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
- copiesRunning(index) -= 1
- // Check if the problem is a map output fetch failure. In that case, this
- // task will never succeed on any node, so tell the scheduler about it.
- reason.foreach {
- case fetchFailed: FetchFailed =>
- logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress)
- sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null)
- successful(index) = true
- tasksSuccessful += 1
- sched.taskSetFinished(this)
- removeAllRunningTasks()
- return
-
- case TaskKilled =>
- logWarning("Task %d was killed.".format(tid))
- sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null)
- return
-
- case ef: ExceptionFailure =>
- sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null))
- val key = ef.description
- val now = clock.getTime()
- val (printFull, dupCount) = {
- if (recentExceptions.contains(key)) {
- val (dupCount, printTime) = recentExceptions(key)
- if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
- recentExceptions(key) = (0, now)
- (true, 0)
- } else {
- recentExceptions(key) = (dupCount + 1, printTime)
- (false, dupCount + 1)
- }
- } else {
- recentExceptions(key) = (0, now)
- (true, 0)
- }
- }
- if (printFull) {
- val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logWarning("Loss was due to %s\n%s\n%s".format(
- ef.className, ef.description, locs.mkString("\n")))
- } else {
- logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
- }
-
- case TaskResultLost =>
- logWarning("Lost result for TID %s on host %s".format(tid, info.host))
- sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
-
- case _ => {}
- }
- // On non-fetch failures, re-enqueue the task as pending for a max number of retries
- addPendingTask(index)
- if (state != TaskState.KILLED) {
- numFailures(index) += 1
- if (numFailures(index) > MAX_TASK_FAILURES) {
- logError("Task %s:%d failed more than %d times; aborting job".format(
- taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
- }
- }
- } else {
- logInfo("Ignoring task-lost event for TID " + tid +
- " because task " + index + " is already finished")
- }
- }
-
- override def error(message: String) {
- // Save the error message
- abort("Error: " + message)
- }
-
- def abort(message: String) {
- failed = true
- causeOfFailure = message
- // TODO: Kill running tasks if we were not terminated due to a Mesos error
- sched.dagScheduler.taskSetFailed(taskSet, message)
- removeAllRunningTasks()
- sched.taskSetFinished(this)
- }
-
- /** If the given task ID is not in the set of running tasks, adds it.
- *
- * Used to keep track of the number of running tasks, for enforcing scheduling policies.
- */
- def addRunningTask(tid: Long) {
- if (runningTasksSet.add(tid) && parent != null) {
- parent.increaseRunningTasks(1)
- }
- runningTasks = runningTasksSet.size
- }
-
- /** If the given task ID is in the set of running tasks, removes it. */
- def removeRunningTask(tid: Long) {
- if (runningTasksSet.remove(tid) && parent != null) {
- parent.decreaseRunningTasks(1)
- }
- runningTasks = runningTasksSet.size
- }
-
- private def removeAllRunningTasks() {
- val numRunningTasks = runningTasksSet.size
- runningTasksSet.clear()
- if (parent != null) {
- parent.decreaseRunningTasks(numRunningTasks)
- }
- runningTasks = 0
- }
-
- override def getSchedulableByName(name: String): Schedulable = {
- return null
- }
-
- override def addSchedulable(schedulable: Schedulable) {}
-
- override def removeSchedulable(schedulable: Schedulable) {}
-
- override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
- var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
- sortedTaskSetQueue += this
- return sortedTaskSetQueue
- }
-
- /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */
- override def executorLost(execId: String, host: String) {
- logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
-
- // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
- // task that used to have locations on only this host might now go to the no-prefs list. Note
- // that it's okay if we add a task to the same queue twice (if it had multiple preferred
- // locations), because findTaskFromList will skip already-running tasks.
- for (index <- getPendingTasksForExecutor(execId)) {
- addPendingTask(index, readding=true)
- }
- for (index <- getPendingTasksForHost(host)) {
- addPendingTask(index, readding=true)
- }
-
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
- for ((tid, info) <- taskInfos if info.executorId == execId) {
- val index = taskInfos(tid).index
- if (successful(index)) {
- successful(index) = false
- copiesRunning(index) -= 1
- tasksSuccessful -= 1
- addPendingTask(index)
- // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
- // stage finishes when a total of tasks.size tasks finish.
- sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null)
- }
- }
- }
- // Also re-enqueue any tasks that were running on the node
- for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- handleFailedTask(tid, TaskState.KILLED, None)
- }
- }
-
- /**
- * Check for tasks to be speculated and return true if there are any. This is called periodically
- * by the ClusterScheduler.
- *
- * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
- * we don't scan the whole task set. It might also help to make this sorted by launch time.
- */
- override def checkSpeculatableTasks(): Boolean = {
- // Can't speculate if we only have one task, or if all tasks have finished.
- if (numTasks == 1 || tasksSuccessful == numTasks) {
- return false
- }
- var foundTasks = false
- val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
- logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
- val time = clock.getTime()
- val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
- Arrays.sort(durations)
- val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
- val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
- // TODO: Threshold should also look at standard deviation of task durations and have a lower
- // bound based on that.
- logDebug("Task length threshold for speculation: " + threshold)
- for ((tid, info) <- taskInfos) {
- val index = info.index
- if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
- !speculatableTasks.contains(index)) {
- logInfo(
- "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
- taskSet.id, index, info.host, threshold))
- speculatableTasks += index
- foundTasks = true
- }
- }
- }
- return foundTasks
- }
-
- override def hasPendingTasks(): Boolean = {
- numTasks > 0 && tasksSuccessful < numTasks
- }
-
- private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
- val defaultWait = System.getProperty("spark.locality.wait", "3000")
- level match {
- case TaskLocality.PROCESS_LOCAL =>
- System.getProperty("spark.locality.wait.process", defaultWait).toLong
- case TaskLocality.NODE_LOCAL =>
- System.getProperty("spark.locality.wait.node", defaultWait).toLong
- case TaskLocality.RACK_LOCAL =>
- System.getProperty("spark.locality.wait.rack", defaultWait).toLong
- case TaskLocality.ANY =>
- 0L
- }
- }
-
- /**
- * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
- * added to queues using addPendingTask.
- */
- private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
- import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
- val levels = new ArrayBuffer[TaskLocality.TaskLocality]
- if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
- levels += PROCESS_LOCAL
- }
- if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
- levels += NODE_LOCAL
- }
- if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
- levels += RACK_LOCAL
- }
- levels += ANY
- logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
- levels.toArray
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index a45bee536c..3bb715e7d0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -29,7 +29,8 @@ import akka.util.Duration
import akka.util.duration._
import org.apache.spark.{SparkException, Logging, TaskState}
-import org.apache.spark.scheduler.TaskDescription
+import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, ClusterScheduler,
+ WorkerOffer}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index 6b91935400..cec02e945c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -19,7 +19,9 @@ package org.apache.spark.scheduler.cluster
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
+
import org.apache.spark.{Logging, SparkContext}
+import org.apache.spark.scheduler.ClusterScheduler
private[spark] class SimrSchedulerBackend(
scheduler: ClusterScheduler,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index cefa970bb9..acf15dbc40 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -17,10 +17,12 @@
package org.apache.spark.scheduler.cluster
+import scala.collection.mutable.HashMap
+
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.deploy.client.{Client, ClientListener}
import org.apache.spark.deploy.{Command, ApplicationDescription}
-import scala.collection.mutable.HashMap
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, ClusterScheduler}
import org.apache.spark.util.Utils
private[spark] class SparkDeploySchedulerBackend(
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 300fe693f1..226ea46cc7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -30,7 +30,8 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{SparkException, Logging, SparkContext, TaskState}
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedulerBackend}
+import org.apache.spark.scheduler.ClusterScheduler
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 50cbc2ca92..3acad1bb46 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -30,9 +30,8 @@ import org.apache.mesos._
import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
import org.apache.spark.{Logging, SparkException, SparkContext, TaskState}
-import org.apache.spark.scheduler.TaskDescription
-import org.apache.spark.scheduler.cluster.{ClusterScheduler, ExecutorExited, ExecutorLossReason}
-import org.apache.spark.scheduler.cluster.{SchedulerBackend, SlaveLost, WorkerOffer}
+import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SchedulerBackend, SlaveLost,
+ TaskDescription, ClusterScheduler, WorkerOffer}
import org.apache.spark.util.Utils
/**
@@ -210,7 +209,7 @@ private[spark] class MesosSchedulerBackend(
getResource(offer.getResourcesList, "cpus").toInt)
}
- // Call into the ClusterScheduler
+ // Call into the TaskScheduler
val taskLists = scheduler.resourceOffers(offerableWorkers)
// Build a list of Mesos tasks for each slave
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
new file mode 100644
index 0000000000..3e9d31cd5e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.local
+
+import java.nio.ByteBuffer
+
+import akka.actor.{Actor, ActorRef, Props}
+
+import org.apache.spark.{SparkContext, SparkEnv, TaskState}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.executor.{Executor, ExecutorBackend}
+import org.apache.spark.scheduler.{SchedulerBackend, ClusterScheduler, WorkerOffer}
+
+/**
+ * LocalBackend is used when running a local version of Spark where the executor, backend, and
+ * master all run in the same JVM. It sits behind a ClusterScheduler and handles launching tasks
+ * on a single Executor (created by the LocalBackend) running locally.
+ *
+ * THREADING: Because methods can be called both from the Executor and the TaskScheduler, and
+ * because the Executor class is not thread safe, all methods are synchronized.
+ */
+private[spark] class LocalBackend(scheduler: ClusterScheduler, private val totalCores: Int)
+ extends SchedulerBackend with ExecutorBackend {
+
+ private var freeCores = totalCores
+
+ private val localExecutorId = "localhost"
+ private val localExecutorHostname = "localhost"
+
+ val executor = new Executor(localExecutorId, localExecutorHostname, Seq.empty, isLocal = true)
+
+ override def start() {
+ }
+
+ override def stop() {
+ }
+
+ override def reviveOffers() = synchronized {
+ val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
+ for (task <- scheduler.resourceOffers(offers).flatten) {
+ freeCores -= 1
+ executor.launchTask(this, task.taskId, task.serializedTask)
+ }
+ }
+
+ override def defaultParallelism() = totalCores
+
+ override def killTask(taskId: Long, executorId: String) = synchronized {
+ executor.killTask(taskId)
+ }
+
+ override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) = synchronized {
+ scheduler.statusUpdate(taskId, state, serializedData)
+ if (TaskState.isFinished(state)) {
+ freeCores += 1
+ reviveOffers()
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
deleted file mode 100644
index 2699f0b33e..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ /dev/null
@@ -1,219 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.nio.ByteBuffer
-import java.util.concurrent.atomic.AtomicInteger
-
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-
-import akka.actor._
-
-import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.executor.{Executor, ExecutorBackend}
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
-
-
-/**
- * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
- * the scheduler also allows each task to fail up to maxFailures times, which is useful for
- * testing fault recovery.
- */
-
-private[local]
-case class LocalReviveOffers()
-
-private[local]
-case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
-
-private[local]
-case class KillTask(taskId: Long)
-
-private[spark]
-class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
- extends Actor with Logging {
-
- val executor = new Executor("localhost", "localhost", Seq.empty, isLocal = true)
-
- def receive = {
- case LocalReviveOffers =>
- launchTask(localScheduler.resourceOffer(freeCores))
-
- case LocalStatusUpdate(taskId, state, serializeData) =>
- if (TaskState.isFinished(state)) {
- freeCores += 1
- launchTask(localScheduler.resourceOffer(freeCores))
- }
-
- case KillTask(taskId) =>
- executor.killTask(taskId)
- }
-
- private def launchTask(tasks: Seq[TaskDescription]) {
- for (task <- tasks) {
- freeCores -= 1
- executor.launchTask(localScheduler, task.taskId, task.serializedTask)
- }
- }
-}
-
-private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
- extends TaskScheduler
- with ExecutorBackend
- with Logging {
-
- val env = SparkEnv.get
- val attemptId = new AtomicInteger
- var dagScheduler: DAGScheduler = null
-
- // Application dependencies (added through SparkContext) that we've fetched so far on this node.
- // Each map holds the master's timestamp for the version of that file or JAR we got.
- val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
- val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
-
- var schedulableBuilder: SchedulableBuilder = null
- var rootPool: Pool = null
- val schedulingMode: SchedulingMode = SchedulingMode.withName(
- System.getProperty("spark.scheduler.mode", "FIFO"))
- val activeTaskSets = new HashMap[String, LocalTaskSetManager]
- val taskIdToTaskSetId = new HashMap[Long, String]
- val taskSetTaskIds = new HashMap[String, HashSet[Long]]
-
- var localActor: ActorRef = null
-
- override def start() {
- // temporarily set rootPool name to empty
- rootPool = new Pool("", schedulingMode, 0, 0)
- schedulableBuilder = {
- schedulingMode match {
- case SchedulingMode.FIFO =>
- new FIFOSchedulableBuilder(rootPool)
- case SchedulingMode.FAIR =>
- new FairSchedulableBuilder(rootPool)
- }
- }
- schedulableBuilder.buildPools()
-
- localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
- }
-
- override def setDAGScheduler(dagScheduler: DAGScheduler) {
- this.dagScheduler = dagScheduler
- }
-
- override def submitTasks(taskSet: TaskSet) {
- synchronized {
- val manager = new LocalTaskSetManager(this, taskSet)
- schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
- activeTaskSets(taskSet.id) = manager
- taskSetTaskIds(taskSet.id) = new HashSet[Long]()
- localActor ! LocalReviveOffers
- }
- }
-
- override def cancelTasks(stageId: Int): Unit = synchronized {
- logInfo("Cancelling stage " + stageId)
- logInfo("Cancelling stage " + activeTaskSets.map(_._2.stageId))
- activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
- // There are two possible cases here:
- // 1. The task set manager has been created and some tasks have been scheduled.
- // In this case, send a kill signal to the executors to kill the task and then abort
- // the stage.
- // 2. The task set manager has been created but no tasks has been scheduled. In this case,
- // simply abort the stage.
- val taskIds = taskSetTaskIds(tsm.taskSet.id)
- if (taskIds.size > 0) {
- taskIds.foreach { tid =>
- localActor ! KillTask(tid)
- }
- }
- tsm.error("Stage %d was cancelled".format(stageId))
- }
- }
-
- def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
- synchronized {
- var freeCpuCores = freeCores
- val tasks = new ArrayBuffer[TaskDescription](freeCores)
- val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
- for (manager <- sortedTaskSetQueue) {
- logDebug("parentName:%s,name:%s,runningTasks:%s".format(
- manager.parent.name, manager.name, manager.runningTasks))
- }
-
- var launchTask = false
- for (manager <- sortedTaskSetQueue) {
- do {
- launchTask = false
- manager.resourceOffer(null, null, freeCpuCores, null) match {
- case Some(task) =>
- tasks += task
- taskIdToTaskSetId(task.taskId) = manager.taskSet.id
- taskSetTaskIds(manager.taskSet.id) += task.taskId
- freeCpuCores -= 1
- launchTask = true
- case None => {}
- }
- } while(launchTask)
- }
- return tasks
- }
- }
-
- def taskSetFinished(manager: TaskSetManager) {
- synchronized {
- activeTaskSets -= manager.taskSet.id
- manager.parent.removeSchedulable(manager)
- logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
- taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
- taskSetTaskIds -= manager.taskSet.id
- }
- }
-
- override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) {
- if (TaskState.isFinished(state)) {
- synchronized {
- taskIdToTaskSetId.get(taskId) match {
- case Some(taskSetId) =>
- val taskSetManager = activeTaskSets(taskSetId)
- taskSetTaskIds(taskSetId) -= taskId
-
- state match {
- case TaskState.FINISHED =>
- taskSetManager.taskEnded(taskId, state, serializedData)
- case TaskState.FAILED =>
- taskSetManager.taskFailed(taskId, state, serializedData)
- case TaskState.KILLED =>
- taskSetManager.error("Task %d was killed".format(taskId))
- case _ => {}
- }
- case None =>
- logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
- }
- }
- localActor ! LocalStatusUpdate(taskId, state, serializedData)
- }
- }
-
- override def stop() {
- }
-
- override def defaultParallelism() = threads
-}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
deleted file mode 100644
index 53bf78267e..0000000000
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ /dev/null
@@ -1,191 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.nio.ByteBuffer
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-
-import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState}
-import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Pool, Schedulable, Task,
- TaskDescription, TaskInfo, TaskLocality, TaskResult, TaskSet, TaskSetManager}
-
-
-private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
- extends TaskSetManager with Logging {
-
- var parent: Pool = null
- var weight: Int = 1
- var minShare: Int = 0
- var runningTasks: Int = 0
- var priority: Int = taskSet.priority
- var stageId: Int = taskSet.stageId
- var name: String = "TaskSet_" + taskSet.stageId.toString
-
- var failCount = new Array[Int](taskSet.tasks.size)
- val taskInfos = new HashMap[Long, TaskInfo]
- val numTasks = taskSet.tasks.size
- var numFinished = 0
- val env = SparkEnv.get
- val ser = env.closureSerializer.newInstance()
- val copiesRunning = new Array[Int](numTasks)
- val finished = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
- val MAX_TASK_FAILURES = sched.maxFailures
-
- def increaseRunningTasks(taskNum: Int): Unit = {
- runningTasks += taskNum
- if (parent != null) {
- parent.increaseRunningTasks(taskNum)
- }
- }
-
- def decreaseRunningTasks(taskNum: Int): Unit = {
- runningTasks -= taskNum
- if (parent != null) {
- parent.decreaseRunningTasks(taskNum)
- }
- }
-
- override def addSchedulable(schedulable: Schedulable): Unit = {
- // nothing
- }
-
- override def removeSchedulable(schedulable: Schedulable): Unit = {
- // nothing
- }
-
- override def getSchedulableByName(name: String): Schedulable = {
- return null
- }
-
- override def executorLost(executorId: String, host: String): Unit = {
- // nothing
- }
-
- override def checkSpeculatableTasks() = true
-
- override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
- var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
- sortedTaskSetQueue += this
- return sortedTaskSetQueue
- }
-
- override def hasPendingTasks() = true
-
- def findTask(): Option[Int] = {
- for (i <- 0 to numTasks-1) {
- if (copiesRunning(i) == 0 && !finished(i)) {
- return Some(i)
- }
- }
- return None
- }
-
- override def resourceOffer(
- execId: String,
- host: String,
- availableCpus: Int,
- maxLocality: TaskLocality.TaskLocality)
- : Option[TaskDescription] =
- {
- SparkEnv.set(sched.env)
- logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format(
- availableCpus.toInt, numFinished, numTasks))
- if (availableCpus > 0 && numFinished < numTasks) {
- findTask() match {
- case Some(index) =>
- val taskId = sched.attemptId.getAndIncrement()
- val task = taskSet.tasks(index)
- val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1",
- TaskLocality.NODE_LOCAL)
- taskInfos(taskId) = info
- // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
- // we assume the task can be serialized without exceptions.
- val bytes = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
- val taskName = "task %s:%d".format(taskSet.id, index)
- copiesRunning(index) += 1
- increaseRunningTasks(1)
- taskStarted(task, info)
- return Some(new TaskDescription(taskId, null, taskName, index, bytes))
- case None => {}
- }
- }
- return None
- }
-
- def taskStarted(task: Task[_], info: TaskInfo) {
- sched.dagScheduler.taskStarted(task, info)
- }
-
- def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- val index = info.index
- val task = taskSet.tasks(index)
- info.markSuccessful()
- val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match {
- case directResult: DirectTaskResult[_] => directResult
- case IndirectTaskResult(blockId) => {
- throw new SparkException("Expect only DirectTaskResults when using LocalScheduler")
- }
- }
- result.metrics.resultSize = serializedData.limit()
- sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info,
- result.metrics)
- numFinished += 1
- decreaseRunningTasks(1)
- finished(index) = true
- if (numFinished == numTasks) {
- sched.taskSetFinished(this)
- }
- }
-
- def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- val index = info.index
- val task = taskSet.tasks(index)
- info.markFailed()
- decreaseRunningTasks(1)
- val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](
- serializedData, getClass.getClassLoader)
- sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null))
- if (!finished(index)) {
- copiesRunning(index) -= 1
- numFailures(index) += 1
- val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s\n%s".format(
- reason.className, reason.description, locs.mkString("\n")))
- if (numFailures(index) > MAX_TASK_FAILURES) {
- val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(
- taskSet.id, index, MAX_TASK_FAILURES, reason.description)
- decreaseRunningTasks(runningTasks)
- sched.dagScheduler.taskSetFailed(taskSet, errorMessage)
- // need to delete failed Taskset from schedule queue
- sched.taskSetFinished(this)
- }
- }
- }
-
- override def error(message: String) {
- sched.dagScheduler.taskSetFailed(taskSet, message)
- sched.taskSetFinished(this)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index af448fcb37..2f7d6dff38 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark
-import org.scalatest.FunSuite
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
import SparkContext._
import org.apache.spark.util.NonSerializable
@@ -37,12 +37,20 @@ object FailureSuiteState {
}
}
-class FailureSuite extends FunSuite with LocalSparkContext {
+class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAll {
+
+ override def beforeAll {
+ System.setProperty("spark.task.maxFailures", "1")
+ }
+
+ override def afterAll {
+ System.clearProperty("spark.task.maxFailures")
+ }
// Run a 3-task map job in which task 1 deterministically fails once, and check
// whether the job completes successfully and we ran 4 tasks in total.
test("failure in a single-stage job") {
- sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1]", "test")
val results = sc.makeRDD(1 to 3, 3).map { x =>
FailureSuiteState.synchronized {
FailureSuiteState.tasksRun += 1
@@ -62,7 +70,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
// Run a map-reduce job in which a reduce task deterministically fails once.
test("failure in a two-stage job") {
- sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1]", "test")
val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
case (k, v) =>
FailureSuiteState.synchronized {
@@ -82,7 +90,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
}
test("failure because task results are not serializable") {
- sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1]", "test")
val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
val thrown = intercept[SparkException] {
@@ -95,7 +103,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
}
test("failure because task closure is not serializable") {
- sc = new SparkContext("local[1,1]", "test")
+ sc = new SparkContext("local[1]", "test")
val a = new NonSerializable
// Non-serializable closure in the final result stage
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
index 95d3553d91..96adcf7198 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
@@ -15,14 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark._
-import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster._
import scala.collection.mutable.ArrayBuffer
import java.util.Properties
@@ -31,9 +29,9 @@ class FakeTaskSetManager(
initPriority: Int,
initStageId: Int,
initNumTasks: Int,
- clusterScheduler: ClusterScheduler,
+ taskScheduler: ClusterScheduler,
taskSet: TaskSet)
- extends ClusterTaskSetManager(clusterScheduler, taskSet) {
+ extends TaskSetManager(taskScheduler, taskSet, 1) {
parent = null
weight = 1
@@ -132,8 +130,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
}
test("FIFO Scheduler Test") {
- sc = new SparkContext("local", "ClusterSchedulerSuite")
- val clusterScheduler = new ClusterScheduler(sc)
+ sc = new SparkContext("local", "TaskSchedulerSuite")
+ val taskScheduler = new ClusterScheduler(sc)
var tasks = ArrayBuffer[Task[_]]()
val task = new FakeTask(0)
tasks += task
@@ -143,9 +141,9 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
schedulableBuilder.buildPools()
- val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet)
- val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet)
- val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet)
+ val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet)
+ val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet)
+ val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet)
schedulableBuilder.addTaskSetManager(taskSetManager0, null)
schedulableBuilder.addTaskSetManager(taskSetManager1, null)
schedulableBuilder.addTaskSetManager(taskSetManager2, null)
@@ -159,8 +157,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
}
test("Fair Scheduler Test") {
- sc = new SparkContext("local", "ClusterSchedulerSuite")
- val clusterScheduler = new ClusterScheduler(sc)
+ sc = new SparkContext("local", "TaskSchedulerSuite")
+ val taskScheduler = new ClusterScheduler(sc)
var tasks = ArrayBuffer[Task[_]]()
val task = new FakeTask(0)
tasks += task
@@ -188,15 +186,15 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
val properties2 = new Properties()
properties2.setProperty("spark.scheduler.pool","2")
- val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet)
- val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet)
- val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet)
+ val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet)
+ val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet)
+ val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet)
schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
- val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet)
- val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet)
+ val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet)
+ val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet)
schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
@@ -216,8 +214,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
}
test("Nested Pool Test") {
- sc = new SparkContext("local", "ClusterSchedulerSuite")
- val clusterScheduler = new ClusterScheduler(sc)
+ sc = new SparkContext("local", "TaskSchedulerSuite")
+ val taskScheduler = new ClusterScheduler(sc)
var tasks = ArrayBuffer[Task[_]]()
val task = new FakeTask(0)
tasks += task
@@ -239,23 +237,23 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
pool1.addSchedulable(pool10)
pool1.addSchedulable(pool11)
- val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet)
- val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet)
+ val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet)
+ val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet)
pool00.addSchedulable(taskSetManager000)
pool00.addSchedulable(taskSetManager001)
- val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet)
- val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet)
+ val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet)
+ val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet)
pool01.addSchedulable(taskSetManager010)
pool01.addSchedulable(taskSetManager011)
- val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet)
- val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet)
+ val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet)
+ val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet)
pool10.addSchedulable(taskSetManager100)
pool10.addSchedulable(taskSetManager101)
- val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet)
- val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet)
+ val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet)
+ val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet)
pool11.addSchedulable(taskSetManager110)
pool11.addSchedulable(taskSetManager111)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 00f2fdd657..24689a7093 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -34,6 +34,25 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
/**
+ * TaskScheduler that records the task sets that the DAGScheduler requested executed.
+ */
+class TaskSetRecordingTaskScheduler(sc: SparkContext,
+ mapOutputTrackerMaster: MapOutputTrackerMaster) extends ClusterScheduler(sc) {
+ /** Set of TaskSets the DAGScheduler has requested executed. */
+ val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+ override def start() = {}
+ override def stop() = {}
+ override def submitTasks(taskSet: TaskSet) = {
+ // normally done by TaskSetManager
+ taskSet.tasks.foreach(_.epoch = mapOutputTrackerMaster.getEpoch)
+ taskSets += taskSet
+ }
+ override def cancelTasks(stageId: Int) {}
+ override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
+ override def defaultParallelism() = 2
+}
+
+/**
* Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
* rather than spawning an event loop thread as happens in the real code. They use EasyMock
* to mock out two classes that DAGScheduler interacts with: TaskScheduler (to which TaskSets are
@@ -46,24 +65,7 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
* and capturing the resulting TaskSets from the mock TaskScheduler.
*/
class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
-
- /** Set of TaskSets the DAGScheduler has requested executed. */
- val taskSets = scala.collection.mutable.Buffer[TaskSet]()
- val taskScheduler = new TaskScheduler() {
- override def rootPool: Pool = null
- override def schedulingMode: SchedulingMode = SchedulingMode.NONE
- override def start() = {}
- override def stop() = {}
- override def submitTasks(taskSet: TaskSet) = {
- // normally done by TaskSetManager
- taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
- taskSets += taskSet
- }
- override def cancelTasks(stageId: Int) {}
- override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
- override def defaultParallelism() = 2
- }
-
+ var taskScheduler: TaskSetRecordingTaskScheduler = null
var mapOutputTracker: MapOutputTrackerMaster = null
var scheduler: DAGScheduler = null
@@ -96,10 +98,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
before {
sc = new SparkContext("local", "DAGSchedulerSuite")
- taskSets.clear()
+ mapOutputTracker = new MapOutputTrackerMaster()
+ taskScheduler = new TaskSetRecordingTaskScheduler(sc, mapOutputTracker)
+ taskScheduler.taskSets.clear()
cacheLocations.clear()
results.clear()
- mapOutputTracker = new MapOutputTrackerMaster()
scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
override def runLocally(job: ActiveJob) {
// don't bother with the thread while unit testing
@@ -204,7 +207,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
test("run trivial job") {
val rdd = makeRdd(1, Nil)
submit(rdd, Array(0))
- complete(taskSets(0), List((Success, 42)))
+ complete(taskScheduler.taskSets(0), List((Success, 42)))
assert(results === Map(0 -> 42))
}
@@ -225,7 +228,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val baseRdd = makeRdd(1, Nil)
val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
submit(finalRdd, Array(0))
- complete(taskSets(0), Seq((Success, 42)))
+ complete(taskScheduler.taskSets(0), Seq((Success, 42)))
assert(results === Map(0 -> 42))
}
@@ -235,7 +238,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
cacheLocations(baseRdd.id -> 0) =
Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
submit(finalRdd, Array(0))
- val taskSet = taskSets(0)
+ val taskSet = taskScheduler.taskSets(0)
assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
complete(taskSet, Seq((Success, 42)))
assert(results === Map(0 -> 42))
@@ -243,7 +246,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
test("trivial job failure") {
submit(makeRdd(1, Nil), Array(0))
- failed(taskSets(0), "some failure")
+ failed(taskScheduler.taskSets(0), "some failure")
assert(failure.getMessage === "Job aborted: some failure")
}
@@ -253,12 +256,12 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val shuffleId = shuffleDep.shuffleId
val reduceRdd = makeRdd(1, List(shuffleDep))
submit(reduceRdd, Array(0))
- complete(taskSets(0), Seq(
+ complete(taskScheduler.taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))))
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
- complete(taskSets(1), Seq((Success, 42)))
+ complete(taskScheduler.taskSets(1), Seq((Success, 42)))
assert(results === Map(0 -> 42))
}
@@ -268,11 +271,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val shuffleId = shuffleDep.shuffleId
val reduceRdd = makeRdd(2, List(shuffleDep))
submit(reduceRdd, Array(0, 1))
- complete(taskSets(0), Seq(
+ complete(taskScheduler.taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))))
// the 2nd ResultTask failed
- complete(taskSets(1), Seq(
+ complete(taskScheduler.taskSets(1), Seq(
(Success, 42),
(FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)))
// this will get called
@@ -280,10 +283,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
// ask the scheduler to try it again
scheduler.resubmitFailedStages()
// have the 2nd attempt pass
- complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+ complete(taskScheduler.taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
// we can see both result blocks now
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
- complete(taskSets(3), Seq((Success, 43)))
+ complete(taskScheduler.taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
}
@@ -299,7 +302,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val newEpoch = mapOutputTracker.getEpoch
assert(newEpoch > oldEpoch)
val noAccum = Map[Long, Any]()
- val taskSet = taskSets(0)
+ val taskSet = taskScheduler.taskSets(0)
// should be ignored for being too old
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
// should work because it's a non-failed host
@@ -311,7 +314,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
- complete(taskSets(1), Seq((Success, 42), (Success, 43)))
+ complete(taskScheduler.taskSets(1), Seq((Success, 42), (Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
}
@@ -326,14 +329,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
runEvent(ExecutorLost("exec-hostA"))
// DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
// rather than marking it is as failed and waiting.
- complete(taskSets(0), Seq(
+ complete(taskScheduler.taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))))
// have hostC complete the resubmitted task
- complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
+ complete(taskScheduler.taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
- complete(taskSets(2), Seq((Success, 42)))
+ complete(taskScheduler.taskSets(2), Seq((Success, 42)))
assert(results === Map(0 -> 42))
}
@@ -345,23 +348,23 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val finalRdd = makeRdd(1, List(shuffleDepTwo))
submit(finalRdd, Array(0))
// have the first stage complete normally
- complete(taskSets(0), Seq(
+ complete(taskScheduler.taskSets(0), Seq(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))))
// have the second stage complete normally
- complete(taskSets(1), Seq(
+ complete(taskScheduler.taskSets(1), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostC", 1))))
// fail the third stage because hostA went down
- complete(taskSets(2), Seq(
+ complete(taskScheduler.taskSets(2), Seq(
(FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
// TODO assert this:
// blockManagerMaster.removeExecutor("exec-hostA")
// have DAGScheduler try again
scheduler.resubmitFailedStages()
- complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
- complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
- complete(taskSets(5), Seq((Success, 42)))
+ complete(taskScheduler.taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
+ complete(taskScheduler.taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
+ complete(taskScheduler.taskSets(5), Seq((Success, 42)))
assert(results === Map(0 -> 42))
}
@@ -375,24 +378,24 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
// complete stage 2
- complete(taskSets(0), Seq(
+ complete(taskScheduler.taskSets(0), Seq(
(Success, makeMapStatus("hostA", 2)),
(Success, makeMapStatus("hostB", 2))))
// complete stage 1
- complete(taskSets(1), Seq(
+ complete(taskScheduler.taskSets(1), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))))
// pretend stage 0 failed because hostA went down
- complete(taskSets(2), Seq(
+ complete(taskScheduler.taskSets(2), Seq(
(FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
// TODO assert this:
// blockManagerMaster.removeExecutor("exec-hostA")
// DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
scheduler.resubmitFailedStages()
- assertLocations(taskSets(3), Seq(Seq("hostD")))
+ assertLocations(taskScheduler.taskSets(3), Seq(Seq("hostD")))
// allow hostD to recover
- complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
- complete(taskSets(4), Seq((Success, 42)))
+ complete(taskScheduler.taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
+ complete(taskScheduler.taskSets(4), Seq((Success, 42)))
assert(results === Map(0 -> 42))
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
index 0f01515179..0b90c4e74c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/FakeTask.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala
@@ -15,10 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import org.apache.spark.TaskContext
-import org.apache.spark.scheduler.{TaskLocation, Task}
class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) {
override def runTask(context: TaskContext): Int = 0
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 1fd76420ea..f3e592bf5c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -19,23 +19,26 @@ package org.apache.spark.scheduler
import scala.collection.mutable.{Buffer, HashSet}
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
- with BeforeAndAfterAll {
+ with BeforeAndAfter with BeforeAndAfterAll {
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
+ before {
+ sc = new SparkContext("local", "SparkListenerSuite")
+ }
+
override def afterAll {
System.clearProperty("spark.akka.frameSize")
}
test("basic creation of StageInfo") {
- sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -56,7 +59,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("StageInfo with fewer tasks than partitions") {
- sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -72,7 +74,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("local metrics") {
- sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@@ -135,10 +136,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("onTaskGettingResult() called when result fetched remotely") {
- // Need to use local cluster mode here, because results are not ever returned through the
- // block manager when using the LocalScheduler.
- sc = new SparkContext("local-cluster[1,1,512]", "test")
-
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
@@ -157,10 +154,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("onTaskGettingResult() not called when result sent directly") {
- // Need to use local cluster mode here, because results are not ever returned through the
- // block manager when using the LocalScheduler.
- sc = new SparkContext("local-cluster[1,1,512]", "test")
-
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index ee150a3107..2ac2d7a36a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -15,14 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import java.nio.ByteBuffer
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv}
-import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
import org.apache.spark.storage.TaskResultBlockId
/**
@@ -36,7 +35,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSched
var removedResult = false
override def enqueueSuccessfulTask(
- taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+ taskSetManager: TaskSetManager, tid: Long, serializedData: ByteBuffer) {
if (!removedResult) {
// Only remove the result once, since we'd like to test the case where the task eventually
// succeeds.
@@ -66,9 +65,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
}
before {
- // Use local-cluster mode because results are returned differently when running with the
- // LocalScheduler.
- sc = new SparkContext("local-cluster[1,1,512]", "test")
+ sc = new SparkContext("local", "test")
}
override def afterAll {
@@ -98,7 +95,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
case clusterScheduler: ClusterScheduler =>
clusterScheduler
case _ =>
- assert(false, "Expect local cluster to use ClusterScheduler")
+ assert(false, "Expect local cluster to use TaskScheduler")
throw new ClassCastException
}
scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index b97f2b19b5..592bb11364 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.scheduler.cluster
+package org.apache.spark.scheduler
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable
@@ -23,12 +23,11 @@ import scala.collection.mutable
import org.scalatest.FunSuite
import org.apache.spark._
-import org.apache.spark.scheduler._
import org.apache.spark.executor.TaskMetrics
import java.nio.ByteBuffer
import org.apache.spark.util.{Utils, FakeClock}
-class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) {
+class FakeDAGScheduler(taskScheduler: FakeTaskScheduler) extends DAGScheduler(taskScheduler) {
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
taskScheduler.startedTasks += taskInfo.index
}
@@ -53,12 +52,12 @@ class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler
}
/**
- * A mock ClusterScheduler implementation that just remembers information about tasks started and
+ * A mock TaskScheduler implementation that just remembers information about tasks started and
* feedback received from the TaskSetManagers. Note that it's important to initialize this with
* a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost
- * to work, and these are required for locality in ClusterTaskSetManager.
+ * to work, and these are required for locality in TaskSetManager.
*/
-class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
+class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
extends ClusterScheduler(sc)
{
val startedTasks = new ArrayBuffer[Long]
@@ -79,16 +78,17 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host)
}
-class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
+class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL}
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
+ val MAX_TASK_FAILURES = 4
test("TaskSet with no preferences") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(1)
- val manager = new ClusterTaskSetManager(sched, taskSet)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
// Offer a host with no CPUs
assert(manager.resourceOffer("exec1", "host1", 0, ANY) === None)
@@ -112,9 +112,9 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
test("multiple offers with no preferences") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(3)
- val manager = new ClusterTaskSetManager(sched, taskSet)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
// First three offers should all find tasks
for (i <- 0 until 3) {
@@ -143,7 +143,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
test("basic delay scheduling") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
val taskSet = createTaskSet(4,
Seq(TaskLocation("host1", "exec1")),
Seq(TaskLocation("host2", "exec2")),
@@ -151,7 +151,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
Seq() // Last task has no locality prefs
)
val clock = new FakeClock
- val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
// First offer host1, exec1: first task should be chosen
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -187,7 +187,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
test("delay scheduling with fallback") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc,
+ val sched = new FakeTaskScheduler(sc,
("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
val taskSet = createTaskSet(5,
Seq(TaskLocation("host1")),
@@ -197,7 +197,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
Seq(TaskLocation("host2"))
)
val clock = new FakeClock
- val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
// First offer host1: first task should be chosen
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -227,14 +227,14 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
test("delay scheduling with failed hosts") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
val taskSet = createTaskSet(3,
Seq(TaskLocation("host1")),
Seq(TaskLocation("host2")),
Seq(TaskLocation("host3"))
)
val clock = new FakeClock
- val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
// First offer host1: first task should be chosen
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -259,10 +259,10 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
test("task result lost") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(1)
val clock = new FakeClock
- val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
@@ -276,20 +276,20 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
test("repeated failures lead to task set abortion") {
sc = new SparkContext("local", "test")
- val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+ val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
val taskSet = createTaskSet(1)
val clock = new FakeClock
- val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
// Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
// after the last failure.
- (0 until manager.MAX_TASK_FAILURES).foreach { index =>
+ (0 until MAX_TASK_FAILURES).foreach { index =>
val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY)
assert(offerResult != None,
"Expect resource offer on iteration %s to return a task".format(index))
assert(offerResult.get.index === 0)
manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost))
- if (index < manager.MAX_TASK_FAILURES) {
+ if (index < MAX_TASK_FAILURES) {
assert(!sched.taskSetsFailed.contains(taskSet.id))
} else {
assert(sched.taskSetsFailed.contains(taskSet.id))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
deleted file mode 100644
index 1e676c1719..0000000000
--- a/core/src/test/scala/org/apache/spark/scheduler/local/LocalSchedulerSuite.scala
+++ /dev/null
@@ -1,227 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.scheduler.local
-
-import java.util.concurrent.Semaphore
-import java.util.concurrent.CountDownLatch
-
-import scala.collection.mutable.HashMap
-
-import org.scalatest.{BeforeAndAfterEach, FunSuite}
-
-import org.apache.spark._
-
-
-class Lock() {
- var finished = false
- def jobWait() = {
- synchronized {
- while(!finished) {
- this.wait()
- }
- }
- }
-
- def jobFinished() = {
- synchronized {
- finished = true
- this.notifyAll()
- }
- }
-}
-
-object TaskThreadInfo {
- val threadToLock = HashMap[Int, Lock]()
- val threadToRunning = HashMap[Int, Boolean]()
- val threadToStarted = HashMap[Int, CountDownLatch]()
-}
-
-/*
- * 1. each thread contains one job.
- * 2. each job contains one stage.
- * 3. each stage only contains one task.
- * 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure
- * it will get cpu core resource, and will wait to finished after user manually
- * release "Lock" and then cluster will contain another free cpu cores.
- * 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue,
- * thus it will be scheduled later when cluster has free cpu cores.
- */
-class LocalSchedulerSuite extends FunSuite with LocalSparkContext with BeforeAndAfterEach {
-
- override def afterEach() {
- super.afterEach()
- System.clearProperty("spark.scheduler.mode")
- }
-
- def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
-
- TaskThreadInfo.threadToRunning(threadIndex) = false
- val nums = sc.parallelize(threadIndex to threadIndex, 1)
- TaskThreadInfo.threadToLock(threadIndex) = new Lock()
- TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
- new Thread {
- if (poolName != null) {
- sc.setLocalProperty("spark.scheduler.pool", poolName)
- }
- override def run() {
- val ans = nums.map(number => {
- TaskThreadInfo.threadToRunning(number) = true
- TaskThreadInfo.threadToStarted(number).countDown()
- TaskThreadInfo.threadToLock(number).jobWait()
- TaskThreadInfo.threadToRunning(number) = false
- number
- }).collect()
- assert(ans.toList === List(threadIndex))
- sem.release()
- }
- }.start()
- }
-
- test("Local FIFO scheduler end-to-end test") {
- System.setProperty("spark.scheduler.mode", "FIFO")
- sc = new SparkContext("local[4]", "test")
- val sem = new Semaphore(0)
-
- createThread(1,null,sc,sem)
- TaskThreadInfo.threadToStarted(1).await()
- createThread(2,null,sc,sem)
- TaskThreadInfo.threadToStarted(2).await()
- createThread(3,null,sc,sem)
- TaskThreadInfo.threadToStarted(3).await()
- createThread(4,null,sc,sem)
- TaskThreadInfo.threadToStarted(4).await()
- // thread 5 and 6 (stage pending)must meet following two points
- // 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager
- // queue before executing TaskThreadInfo.threadToLock(1).jobFinished()
- // 2. priority of stage in thread 5 should be prior to priority of stage in thread 6
- // So I just use "sleep" 1s here for each thread.
- // TODO: any better solution?
- createThread(5,null,sc,sem)
- Thread.sleep(1000)
- createThread(6,null,sc,sem)
- Thread.sleep(1000)
-
- assert(TaskThreadInfo.threadToRunning(1) === true)
- assert(TaskThreadInfo.threadToRunning(2) === true)
- assert(TaskThreadInfo.threadToRunning(3) === true)
- assert(TaskThreadInfo.threadToRunning(4) === true)
- assert(TaskThreadInfo.threadToRunning(5) === false)
- assert(TaskThreadInfo.threadToRunning(6) === false)
-
- TaskThreadInfo.threadToLock(1).jobFinished()
- TaskThreadInfo.threadToStarted(5).await()
-
- assert(TaskThreadInfo.threadToRunning(1) === false)
- assert(TaskThreadInfo.threadToRunning(2) === true)
- assert(TaskThreadInfo.threadToRunning(3) === true)
- assert(TaskThreadInfo.threadToRunning(4) === true)
- assert(TaskThreadInfo.threadToRunning(5) === true)
- assert(TaskThreadInfo.threadToRunning(6) === false)
-
- TaskThreadInfo.threadToLock(3).jobFinished()
- TaskThreadInfo.threadToStarted(6).await()
-
- assert(TaskThreadInfo.threadToRunning(1) === false)
- assert(TaskThreadInfo.threadToRunning(2) === true)
- assert(TaskThreadInfo.threadToRunning(3) === false)
- assert(TaskThreadInfo.threadToRunning(4) === true)
- assert(TaskThreadInfo.threadToRunning(5) === true)
- assert(TaskThreadInfo.threadToRunning(6) === true)
-
- TaskThreadInfo.threadToLock(2).jobFinished()
- TaskThreadInfo.threadToLock(4).jobFinished()
- TaskThreadInfo.threadToLock(5).jobFinished()
- TaskThreadInfo.threadToLock(6).jobFinished()
- sem.acquire(6)
- }
-
- test("Local fair scheduler end-to-end test") {
- System.setProperty("spark.scheduler.mode", "FAIR")
- val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
- System.setProperty("spark.scheduler.allocation.file", xmlPath)
-
- sc = new SparkContext("local[8]", "LocalSchedulerSuite")
- val sem = new Semaphore(0)
-
- createThread(10,"1",sc,sem)
- TaskThreadInfo.threadToStarted(10).await()
- createThread(20,"2",sc,sem)
- TaskThreadInfo.threadToStarted(20).await()
- createThread(30,"3",sc,sem)
- TaskThreadInfo.threadToStarted(30).await()
-
- assert(TaskThreadInfo.threadToRunning(10) === true)
- assert(TaskThreadInfo.threadToRunning(20) === true)
- assert(TaskThreadInfo.threadToRunning(30) === true)
-
- createThread(11,"1",sc,sem)
- TaskThreadInfo.threadToStarted(11).await()
- createThread(21,"2",sc,sem)
- TaskThreadInfo.threadToStarted(21).await()
- createThread(31,"3",sc,sem)
- TaskThreadInfo.threadToStarted(31).await()
-
- assert(TaskThreadInfo.threadToRunning(11) === true)
- assert(TaskThreadInfo.threadToRunning(21) === true)
- assert(TaskThreadInfo.threadToRunning(31) === true)
-
- createThread(12,"1",sc,sem)
- TaskThreadInfo.threadToStarted(12).await()
- createThread(22,"2",sc,sem)
- TaskThreadInfo.threadToStarted(22).await()
- createThread(32,"3",sc,sem)
-
- assert(TaskThreadInfo.threadToRunning(12) === true)
- assert(TaskThreadInfo.threadToRunning(22) === true)
- assert(TaskThreadInfo.threadToRunning(32) === false)
-
- TaskThreadInfo.threadToLock(10).jobFinished()
- TaskThreadInfo.threadToStarted(32).await()
-
- assert(TaskThreadInfo.threadToRunning(32) === true)
-
- //1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager
- // queue so that cluster will assign free cpu core to stage 23 after stage 11 finished.
- //2. priority of 23 and 33 will be meaningless as using fair scheduler here.
- createThread(23,"2",sc,sem)
- createThread(33,"3",sc,sem)
- Thread.sleep(1000)
-
- TaskThreadInfo.threadToLock(11).jobFinished()
- TaskThreadInfo.threadToStarted(23).await()
-
- assert(TaskThreadInfo.threadToRunning(23) === true)
- assert(TaskThreadInfo.threadToRunning(33) === false)
-
- TaskThreadInfo.threadToLock(12).jobFinished()
- TaskThreadInfo.threadToStarted(33).await()
-
- assert(TaskThreadInfo.threadToRunning(33) === true)
-
- TaskThreadInfo.threadToLock(20).jobFinished()
- TaskThreadInfo.threadToLock(21).jobFinished()
- TaskThreadInfo.threadToLock(22).jobFinished()
- TaskThreadInfo.threadToLock(23).jobFinished()
- TaskThreadInfo.threadToLock(30).jobFinished()
- TaskThreadInfo.threadToLock(31).jobFinished()
- TaskThreadInfo.threadToLock(32).jobFinished()
- TaskThreadInfo.threadToLock(33).jobFinished()
-
- sem.acquire(11)
- }
-}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index 29b3f22e13..e873400680 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -17,16 +17,20 @@
package org.apache.spark.scheduler.cluster
+import org.apache.hadoop.conf.Configuration
+
import org.apache.spark._
import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
+import org.apache.spark.scheduler.TaskScheduler
import org.apache.spark.util.Utils
-import org.apache.hadoop.conf.Configuration
/**
*
- * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done
+ * This is a simple extension to TaskScheduler - to ensure that appropriate initialization of
+ * ApplicationMaster, etc. is done
*/
-private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
+ extends TaskScheduler(sc) {
logInfo("Created YarnClusterScheduler")