From c07087364bac672ed7ded6dfeef00bab628c2f9b Mon Sep 17 00:00:00 2001 From: Harold Lim Date: Mon, 4 Mar 2013 16:37:27 -0500 Subject: Made changes to the SparkContext to have a DynamicVariable for setting local properties that can be passed down the stack. Added an implementation of the fair scheduler --- core/src/main/scala/spark/SparkContext.scala | 38 ++- .../main/scala/spark/scheduler/DAGScheduler.scala | 36 ++- .../scala/spark/scheduler/DAGSchedulerEvent.scala | 5 +- core/src/main/scala/spark/scheduler/Stage.scala | 5 +- core/src/main/scala/spark/scheduler/TaskSet.scala | 4 +- .../cluster/fair/FairClusterScheduler.scala | 341 +++++++++++++++++++++ .../cluster/fair/FairTaskSetManager.scala | 130 ++++++++ 7 files changed, 530 insertions(+), 29 deletions(-) create mode 100644 core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala create mode 100644 core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala (limited to 'core/src') diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 4957a54c1b..bd2261cf0d 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -3,11 +3,13 @@ package spark import java.io._ import java.util.concurrent.atomic.AtomicInteger import java.net.URI +import java.util.Properties import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.HashMap import scala.collection.JavaConversions._ +import scala.util.DynamicVariable import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration @@ -72,6 +74,11 @@ class SparkContext( if (System.getProperty("spark.driver.port") == null) { System.setProperty("spark.driver.port", "0") } + + //Set the default task scheduler + if (System.getProperty("spark.cluster.taskscheduler") == null) { + System.setProperty("spark.cluster.taskscheduler", "spark.scheduler.cluster.ClusterScheduler") + } private val isLocal = (master == "local" || master.startsWith("local[")) @@ -112,7 +119,7 @@ class SparkContext( } } executorEnvs ++= environment - + // Create and start the scheduler private var taskScheduler: TaskScheduler = { // Regular expression used for local[N] master format @@ -137,7 +144,7 @@ class SparkContext( new LocalScheduler(threads.toInt, maxFailures.toInt, this) case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(this) + val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName) scheduler.initialize(backend) scheduler @@ -153,7 +160,7 @@ class SparkContext( memoryPerSlaveInt, sparkMemEnvInt)) } - val scheduler = new ClusterScheduler(this) + val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) val sparkUrl = localCluster.start() @@ -169,7 +176,7 @@ class SparkContext( logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master)) } MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(this) + val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler] val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos:// val backend = if (coarseGrained) { @@ -206,6 +213,20 @@ class SparkContext( } private[spark] var checkpointDir: Option[String] = None + + // Thread Local variable that can be used by users to pass information down the stack + private val localProperties = new DynamicVariable[Properties](null) + + def initLocalProperties() { + localProperties.value = new Properties() + } + + def addLocalProperties(key: String, value: String) { + if(localProperties.value == null) { + localProperties.value = new Properties() + } + localProperties.value.setProperty(key,value) + } // Methods for creating RDDs @@ -578,7 +599,7 @@ class SparkContext( val callSite = Utils.getSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler) + val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,localProperties.value) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") rdd.doCheckpoint() result @@ -649,7 +670,7 @@ class SparkContext( val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter) runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler) } - + /** * Run a job that can return approximate results. */ @@ -657,12 +678,11 @@ class SparkContext( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], - timeout: Long - ): PartialResult[R] = { + timeout: Long): PartialResult[R] = { val callSite = Utils.getSparkCallSite logInfo("Starting job: " + callSite) val start = System.nanoTime - val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout) + val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value) logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s") result } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index c54dce51d7..2ad73f3232 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -4,6 +4,7 @@ import cluster.TaskInfo import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.TimeUnit +import java.util.Properties import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} @@ -128,11 +129,11 @@ class DAGScheduler( * The priority value passed in will be used if the stage doesn't already exist with * a lower priority (we assume that priorities always increase across jobs for now). */ - private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = { + private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int, properties: Properties): Stage = { shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority) + val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority, properties) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -143,7 +144,7 @@ class DAGScheduler( * as a result stage for the final RDD used directly in an action. The stage will also be given * the provided priority. */ - private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = { + private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int, properties: Properties): Stage = { if (shuffleDep != None) { // Kind of ugly: need to register RDDs with the cache and map output tracker here // since we can't do it in the RDD constructor because # of partitions is unknown @@ -151,7 +152,7 @@ class DAGScheduler( mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) } val id = nextStageId.getAndIncrement() - val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority) + val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority, properties), priority, properties) idToStage(id) = stage stageToInfos(stage) = StageInfo(stage) stage @@ -161,7 +162,7 @@ class DAGScheduler( * Get or create the list of parent stages for a given RDD. The stages will be assigned the * provided priority if they haven't already been created with a lower priority. */ - private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = { + private def getParentStages(rdd: RDD[_], priority: Int, properties: Properties): List[Stage] = { val parents = new HashSet[Stage] val visited = new HashSet[RDD[_]] def visit(r: RDD[_]) { @@ -172,7 +173,7 @@ class DAGScheduler( for (dep <- r.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => - parents += getShuffleMapStage(shufDep, priority) + parents += getShuffleMapStage(shufDep, priority, properties) case _ => visit(dep.rdd) } @@ -193,7 +194,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.priority) + val mapStage = getShuffleMapStage(shufDep, stage.priority, stage.properties) if (!mapStage.isAvailable) { missing += mapStage } @@ -221,13 +222,14 @@ class DAGScheduler( partitions: Seq[Int], callSite: String, allowLocal: Boolean, - resultHandler: (Int, U) => Unit) + resultHandler: (Int, U) => Unit, + properties: Properties = null) : (JobSubmitted, JobWaiter[U]) = { assert(partitions.size > 0) val waiter = new JobWaiter(partitions.size, resultHandler) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] - val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter) + val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties) return (toSubmit, waiter) } @@ -237,13 +239,13 @@ class DAGScheduler( partitions: Seq[Int], callSite: String, allowLocal: Boolean, - resultHandler: (Int, U) => Unit) + resultHandler: (Int, U) => Unit, properties: Properties = null) { if (partitions.size == 0) { return } val (toSubmit, waiter) = prepareJob( - finalRdd, func, partitions, callSite, allowLocal, resultHandler) + finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties) eventQueue.put(toSubmit) waiter.awaitResult() match { case JobSucceeded => {} @@ -258,13 +260,13 @@ class DAGScheduler( func: (TaskContext, Iterator[T]) => U, evaluator: ApproximateEvaluator[U, R], callSite: String, - timeout: Long) + timeout: Long, properties: Properties = null) : PartialResult[R] = { val listener = new ApproximateActionListener(rdd, func, evaluator, timeout) val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _] val partitions = (0 until rdd.partitions.size).toArray - eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener)) + eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener, properties)) return listener.awaitResult() // Will throw an exception if the job fails } @@ -274,9 +276,9 @@ class DAGScheduler( */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { - case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) => + case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) => val runId = nextRunId.getAndIncrement() - val finalStage = newStage(finalRDD, None, runId) + val finalStage = newStage(finalRDD, None, runId, properties) val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener) clearCacheLocs() logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length + @@ -458,7 +460,7 @@ class DAGScheduler( myPending ++= tasks logDebug("New pending tasks: " + myPending) taskSched.submitTasks( - new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority)) + new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority, stage.properties)) if (!stage.submissionTime.isDefined) { stage.submissionTime = Some(System.currentTimeMillis()) } @@ -663,7 +665,7 @@ class DAGScheduler( for (dep <- rdd.dependencies) { dep match { case shufDep: ShuffleDependency[_,_] => - val mapStage = getShuffleMapStage(shufDep, stage.priority) + val mapStage = getShuffleMapStage(shufDep, stage.priority, stage.properties) if (!mapStage.isAvailable) { visitedStages += mapStage visit(mapStage.rdd) diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala index ed0b9bf178..79588891e7 100644 --- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala @@ -1,5 +1,8 @@ package spark.scheduler + +import java.util.Properties + import spark.scheduler.cluster.TaskInfo import scala.collection.mutable.Map @@ -20,7 +23,7 @@ private[spark] case class JobSubmitted( partitions: Array[Int], allowLocal: Boolean, callSite: String, - listener: JobListener) + listener: JobListener, properties: Properties) extends DAGSchedulerEvent private[spark] case class CompletionEvent( diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala index 552061e46b..97afa27a60 100644 --- a/core/src/main/scala/spark/scheduler/Stage.scala +++ b/core/src/main/scala/spark/scheduler/Stage.scala @@ -1,10 +1,12 @@ package spark.scheduler import java.net.URI +import java.util.Properties import spark._ import spark.storage.BlockManagerId + /** * A stage is a set of independent tasks all computing the same function that need to run as part * of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run @@ -24,7 +26,8 @@ private[spark] class Stage( val rdd: RDD[_], val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage val parents: List[Stage], - val priority: Int) + val priority: Int, + val properties: Properties = null) extends Logging { val isShuffleMap = shuffleDep != None diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala index a3002ca477..2498e8a5aa 100644 --- a/core/src/main/scala/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/spark/scheduler/TaskSet.scala @@ -1,10 +1,12 @@ package spark.scheduler +import java.util.Properties + /** * A set of tasks submitted together to the low-level TaskScheduler, usually representing * missing partitions of a particular stage. */ -private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) { +private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int, val properties: Properties) { val id: String = stageId + "." + attempt override def toString: String = "TaskSet " + id diff --git a/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala new file mode 100644 index 0000000000..37d98ccb2a --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala @@ -0,0 +1,341 @@ +package spark.scheduler.cluster.fair + +import java.io.{File, FileInputStream, FileOutputStream} +import java.util.{TimerTask, Timer} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.collection.mutable.HashSet +import scala.util.control.Breaks._ +import scala.xml._ + +import spark._ +import spark.TaskState.TaskState +import spark.scheduler._ +import spark.scheduler.cluster._ +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicLong +import scala.io.Source + +/** + * An implementation of a fair TaskScheduler, for running tasks on a cluster. Clients should first call + * start(), then submit task sets through the runTasks method. + * + * The current implementation makes the following assumptions: A pool has a fixed configuration of weight. + * Within a pool, it just uses FIFO. + * Also, currently we assume that pools are statically defined + * We currently don't support min shares + */ +private[spark] class FairClusterScheduler(override val sc: SparkContext) + extends ClusterScheduler(sc) + with Logging { + + + val schedulerAllocFile = System.getProperty("mapred.fairscheduler.allocation.file","unspecified") + + val poolNameToPool= new HashMap[String, Pool] + var pools = new ArrayBuffer[Pool] + + loadPoolProperties() + + def loadPoolProperties() { + //first check if the file exists + val file = new File(schedulerAllocFile) + if(!file.exists()) { + //if file does not exist, we just create 1 pool, default + val pool = new Pool("default",100) + pools += pool + poolNameToPool("default") = pool + logInfo("Created a default pool with weight = 100") + } + else { + val xml = XML.loadFile(file) + for (poolNode <- (xml \\ "pool")) { + if((poolNode \ "weight").text != ""){ + val pool = new Pool((poolNode \ "@name").text,(poolNode \ "weight").text.toInt) + pools += pool + poolNameToPool((poolNode \ "@name").text) = pool + logInfo("Created pool "+ pool.name +"with weight = "+pool.weight) + } else { + val pool = new Pool((poolNode \ "@name").text,100) + pools += pool + poolNameToPool((poolNode \ "@name").text) = pool + logInfo("Created pool "+ pool.name +"with weight = 100") + } + } + if(!poolNameToPool.contains("default")) { + val pool = new Pool("default", 100) + pools += pool + poolNameToPool("default") = pool + logInfo("Created a default pool with weight = 100") + } + + } + } + + def taskFinished(manager: TaskSetManager) { + var poolName = "default" + if(manager.taskSet.properties != null) + poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") + + this.synchronized { + //have to check that poolName exists + if(poolNameToPool.contains(poolName)) + { + poolNameToPool(poolName).numRunningTasks -= 1 + } + else + { + poolNameToPool("default").numRunningTasks -= 1 + } + } + } + + override def submitTasks(taskSet: TaskSet) { + val tasks = taskSet.tasks + + + var poolName = "default" + if(taskSet.properties != null) + poolName = taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") + + this.synchronized { + if(poolNameToPool.contains(poolName)) + { + val manager = new FairTaskSetManager(this, taskSet) + poolNameToPool(poolName).activeTaskSetsQueue += manager + activeTaskSets(taskSet.id) = manager + //activeTaskSetsQueue += manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks to pool "+poolName) + } + else //If the pool name does not exists, where do we put them? We put them in default + { + val manager = new FairTaskSetManager(this, taskSet) + poolNameToPool("default").activeTaskSetsQueue += manager + activeTaskSets(taskSet.id) = manager + //activeTaskSetsQueue += manager + taskSetTaskIds(taskSet.id) = new HashSet[Long]() + logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks to pool default") + } + if (hasReceivedTask == false) { + starvationTimer.scheduleAtFixedRate(new TimerTask() { + override def run() { + if (!hasLaunchedTask) { + logWarning("Initial job has not accepted any resources; " + + "check your cluster UI to ensure that workers are registered") + } else { + this.cancel() + } + } + }, STARVATION_TIMEOUT, STARVATION_TIMEOUT) + } + hasReceivedTask = true; + + } + backend.reviveOffers() + } + + override def taskSetFinished(manager: TaskSetManager) { + + var poolName = "default" + if(manager.taskSet.properties != null) + poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default") + + + this.synchronized { + //have to check that poolName exists + if(poolNameToPool.contains(poolName)) + { + poolNameToPool(poolName).activeTaskSetsQueue -= manager + } + else + { + poolNameToPool("default").activeTaskSetsQueue -= manager + } + //activeTaskSetsQueue -= manager + activeTaskSets -= manager.taskSet.id + taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) + taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) + taskSetTaskIds.remove(manager.taskSet.id) + } + //backend.reviveOffers() + } + + /** + * This is the comparison function used for sorting to determine which + * pool to allocate next based on fairness. + * The algorithm is as follows: we sort by the pool's running tasks to weight ratio + * (pools number running tast / pool's weight) + */ + def poolFairCompFn(pool1: Pool, pool2: Pool): Boolean = { + val tasksToWeightRatio1 = pool1.numRunningTasks.toDouble / pool1.weight.toDouble + val tasksToWeightRatio2 = pool2.numRunningTasks.toDouble / pool2.weight.toDouble + var res = Math.signum(tasksToWeightRatio1 - tasksToWeightRatio2) + if (res == 0) { + //Jobs are tied in fairness ratio. We break the tie by name + res = pool1.name.compareTo(pool2.name) + } + if (res < 0) + return true + else + return false + } + + /** + * Called by cluster manager to offer resources on slaves. We respond by asking our active task + * sets for tasks in order of priority. We fill each node with tasks in a fair manner so + * that tasks are balanced across the cluster. + */ + override def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { + synchronized { + SparkEnv.set(sc.env) + // Mark each slave as alive and remember its hostname + for (o <- offers) { + executorIdToHost(o.executorId) = o.hostname + if (!executorsByHost.contains(o.hostname)) { + executorsByHost(o.hostname) = new HashSet() + } + } + // Build a list of tasks to assign to each slave + val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val availableCpus = offers.map(o => o.cores).toArray + var launchedTask = false + + for (i <- 0 until offers.size) { //we loop through the list of offers + val execId = offers(i).executorId + val host = offers(i).hostname + var breakOut = false + while(availableCpus(i) > 0 && !breakOut) { + breakable{ + launchedTask = false + for (pool <- pools.sortWith(poolFairCompFn)) { //we loop through the list of pools + if(!pool.activeTaskSetsQueue.isEmpty) { + //sort the tasksetmanager in the pool + pool.activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId)) + for(manager <- pool.activeTaskSetsQueue) { //we loop through the activeTaskSets in this pool +// val manager = pool.activeTaskSetsQueue.head + //Make an offer + manager.slaveOffer(execId, host, availableCpus(i)) match { + case Some(task) => + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = manager.taskSet.id + taskSetTaskIds(manager.taskSet.id) += tid + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + availableCpus(i) -= 1 + pool.numRunningTasks += 1 + launchedTask = true + logInfo("launched task for pool"+pool.name); + break + case None => {} + } + } + } + } + //If there is not one pool that can assign the task then we have to exit the outer loop and continue to the next offer + if(!launchedTask){ + breakOut = true + } + } + } + } + if (tasks.size > 0) { + hasLaunchedTask = true + } + return tasks + } + } + + override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { + var taskSetToUpdate: Option[TaskSetManager] = None + var failedExecutor: Option[String] = None + var taskFailed = false + synchronized { + try { + if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { + // We lost this entire executor, so remember that it's gone + val execId = taskIdToExecutorId(tid) + if (activeExecutorIds.contains(execId)) { + removeExecutor(execId) + failedExecutor = Some(execId) + } + } + taskIdToTaskSetId.get(tid) match { + case Some(taskSetId) => + if (activeTaskSets.contains(taskSetId)) { + taskSetToUpdate = Some(activeTaskSets(taskSetId)) + } + if (TaskState.isFinished(state)) { + taskIdToTaskSetId.remove(tid) + if (taskSetTaskIds.contains(taskSetId)) { + taskSetTaskIds(taskSetId) -= tid + } + taskIdToExecutorId.remove(tid) + } + if (state == TaskState.FAILED) { + taskFailed = true + } + case None => + logInfo("Ignoring update from TID " + tid + " because its task set is gone") + } + } catch { + case e: Exception => logError("Exception in statusUpdate", e) + } + } + // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock + if (taskSetToUpdate != None) { + taskSetToUpdate.get.statusUpdate(tid, state, serializedData) + } + if (failedExecutor != None) { + listener.executorLost(failedExecutor.get) + backend.reviveOffers() + } + if (taskFailed) { + // Also revive offers if a task had failed for some reason other than host lost + backend.reviveOffers() + } + } + + // Check for speculatable tasks in all our active jobs. + override def checkSpeculatableTasks() { + var shouldRevive = false + synchronized { + for (pool <- pools) { + for (ts <- pool.activeTaskSetsQueue) { + shouldRevive |= ts.checkSpeculatableTasks() + } + } + } + if (shouldRevive) { + backend.reviveOffers() + } + } + + /** Remove an executor from all our data structures and mark it as lost */ + private def removeExecutor(executorId: String) { + activeExecutorIds -= executorId + val host = executorIdToHost(executorId) + val execs = executorsByHost.getOrElse(host, new HashSet) + execs -= executorId + if (execs.isEmpty) { + executorsByHost -= host + } + executorIdToHost -= executorId + for (pool <- pools) { + pool.activeTaskSetsQueue.foreach(_.executorLost(executorId, host)) + } + } + +} + +/** + * An internal representation of a pool. It contains an ArrayBuffer of TaskSets and also weight and minshare + */ +class Pool(val name: String, val weight: Int) +{ + var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager] + var numRunningTasks: Int = 0 +} diff --git a/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala new file mode 100644 index 0000000000..4b0277d2d5 --- /dev/null +++ b/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala @@ -0,0 +1,130 @@ +package spark.scheduler.cluster.fair + +import scala.collection.mutable.ArrayBuffer + +import spark._ +import spark.scheduler._ +import spark.scheduler.cluster._ +import spark.TaskState.TaskState +import java.nio.ByteBuffer + +/** + * Schedules the tasks within a single TaskSet in the FairClusterScheduler. + */ +private[spark] class FairTaskSetManager(sched: FairClusterScheduler, override val taskSet: TaskSet) extends TaskSetManager(sched, taskSet) with Logging { + + // Add a task to all the pending-task lists that it should be on. + private def addPendingTask(index: Int) { + val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive + if (locations.size == 0) { + pendingTasksWithNoPrefs += index + } else { + for (host <- locations) { + val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer()) + list += index + } + } + allPendingTasks += index + } + + override def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markSuccessful() + sched.taskFinished(this) + if (!finished(index)) { + tasksFinished += 1 + logInfo("Finished TID %s in %d ms (progress: %d/%d)".format( + tid, info.duration, tasksFinished, numTasks)) + // Deserialize task result and pass it to the scheduler + val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) + result.metrics.resultSize = serializedData.limit() + sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) + // Mark finished and stop if we've finished all the tasks + finished(index) = true + if (tasksFinished == numTasks) { + sched.taskSetFinished(this) + } + } else { + logInfo("Ignoring task-finished event for TID " + tid + + " because task " + index + " is already finished") + } + } + + override def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) { + val info = taskInfos(tid) + if (info.failed) { + // We might get two task-lost messages for the same task in coarse-grained Mesos mode, + // or even from Mesos itself when acks get delayed. + return + } + val index = info.index + info.markFailed() + //Bookkeeping necessary for the pools in the scheduler + sched.taskFinished(this) + if (!finished(index)) { + logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) + copiesRunning(index) -= 1 + // Check if the problem is a map output fetch failure. In that case, this + // task will never succeed on any node, so tell the scheduler about it. + if (serializedData != null && serializedData.limit() > 0) { + val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader) + reason match { + case fetchFailed: FetchFailed => + logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress) + sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) + finished(index) = true + tasksFinished += 1 + sched.taskSetFinished(this) + return + + case ef: ExceptionFailure => + val key = ef.exception.toString + val now = System.currentTimeMillis + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { + recentExceptions(key) = (0, now) + (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) + } + } else { + recentExceptions(key) = (0, now) + (true, 0) + } + } + if (printFull) { + val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString)) + logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount)) + } + + case _ => {} + } + } + // On non-fetch failures, re-enqueue the task as pending for a max number of retries + addPendingTask(index) + // Count failed attempts only on FAILED and LOST state (not on KILLED) + if (state == TaskState.FAILED || state == TaskState.LOST) { + numFailures(index) += 1 + if (numFailures(index) > MAX_TASK_FAILURES) { + logError("Task %s:%d failed more than %d times; aborting job".format( + taskSet.id, index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + } + } + } else { + logInfo("Ignoring task-lost event for TID " + tid + + " because task " + index + " is already finished") + } + } +} \ No newline at end of file -- cgit v1.2.3