author     Harold Lim <harold@cs.duke.edu>      2013-03-04 16:37:27 -0500
committer  Andrew xia <junluan.xia@intel.com>   2013-03-12 13:31:27 +0800
commit     c07087364bac672ed7ded6dfeef00bab628c2f9b (patch)
tree       e91fcb4cef176b50a433c001f6520bdb5179b941 /core/src
parent     cbf8f0d4dda41ffd45855eab8401fda9b64168cd (diff)
Changed SparkContext to hold a DynamicVariable for setting local properties that can be passed down the stack, and added an implementation of a fair scheduler (FairClusterScheduler).
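
As a rough illustration of the user-facing flow this commit enables (not part of the patch itself; the master URL, allocation file path, and pool name below are hypothetical), a driver would opt into the new scheduler through the spark.cluster.taskscheduler system property and then tag its jobs with a fair-scheduler pool via the new local properties:

    import spark.SparkContext

    object FairSchedulingSketch {
      def main(args: Array[String]) {
        // Select the fair scheduler implementation added by this patch; SparkContext reads
        // this property before instantiating the cluster scheduler.
        System.setProperty("spark.cluster.taskscheduler",
          "spark.scheduler.cluster.fair.FairClusterScheduler")
        // Illustrative path; without the file the scheduler creates a single "default" pool.
        System.setProperty("mapred.fairscheduler.allocation.file", "/etc/spark/fair-pools.xml")

        val sc = new SparkContext("spark://master:7077", "FairSchedulingSketch")
        // Jobs submitted from this thread are now accounted against the "production" pool.
        sc.addLocalProperties("spark.scheduler.cluster.fair.pool", "production")
        val count = sc.parallelize(1 to 1000).count()
        println("Counted " + count + " elements under pool 'production'")
        sc.stop()
      }
    }
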
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala | 38
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGScheduler.scala | 36
-rw-r--r--  core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala | 5
-rw-r--r--  core/src/main/scala/spark/scheduler/Stage.scala | 5
-rw-r--r--  core/src/main/scala/spark/scheduler/TaskSet.scala | 4
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala | 341
-rw-r--r--  core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala | 130
7 files changed, 530 insertions(+), 29 deletions(-)
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 4957a54c1b..bd2261cf0d 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -3,11 +3,13 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
import java.net.URI
+import java.util.Properties
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
+import scala.util.DynamicVariable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
@@ -72,6 +74,11 @@ class SparkContext(
if (System.getProperty("spark.driver.port") == null) {
System.setProperty("spark.driver.port", "0")
}
+
+ //Set the default task scheduler
+ if (System.getProperty("spark.cluster.taskscheduler") == null) {
+ System.setProperty("spark.cluster.taskscheduler", "spark.scheduler.cluster.ClusterScheduler")
+ }
private val isLocal = (master == "local" || master.startsWith("local["))
@@ -112,7 +119,7 @@ class SparkContext(
}
}
executorEnvs ++= environment
-
+
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
@@ -137,7 +144,7 @@ class SparkContext(
new LocalScheduler(threads.toInt, maxFailures.toInt, this)
case SPARK_REGEX(sparkUrl) =>
- val scheduler = new ClusterScheduler(this)
+ val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler]
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
scheduler.initialize(backend)
scheduler
@@ -153,7 +160,7 @@ class SparkContext(
memoryPerSlaveInt, sparkMemEnvInt))
}
- val scheduler = new ClusterScheduler(this)
+ val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler]
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
@@ -169,7 +176,7 @@ class SparkContext(
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
}
MesosNativeLibrary.load()
- val scheduler = new ClusterScheduler(this)
+ val scheduler = Class.forName(System.getProperty("spark.cluster.taskscheduler")).getConstructors()(0).newInstance(Array[AnyRef](this):_*).asInstanceOf[ClusterScheduler]
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
@@ -206,6 +213,20 @@ class SparkContext(
}
private[spark] var checkpointDir: Option[String] = None
+
+ // Thread Local variable that can be used by users to pass information down the stack
+ private val localProperties = new DynamicVariable[Properties](null)
+
+ def initLocalProperties() {
+ localProperties.value = new Properties()
+ }
+
+ def addLocalProperties(key: String, value: String) {
+ if(localProperties.value == null) {
+ localProperties.value = new Properties()
+ }
+ localProperties.value.setProperty(key,value)
+ }
// Methods for creating RDDs
@@ -578,7 +599,7 @@ class SparkContext(
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
+ val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,localProperties.value)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
@@ -649,7 +670,7 @@ class SparkContext(
val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler)
}
-
+
/**
* Run a job that can return approximate results.
*/
@@ -657,12 +678,11 @@ class SparkContext(
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
- timeout: Long
- ): PartialResult[R] = {
+ timeout: Long): PartialResult[R] = {
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
- val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout)
+ val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result
}
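
For context on the SparkContext changes above: scala.util.DynamicVariable keeps a per-thread binding (and a per-scope one through withValue), which is what lets properties set on the driver thread travel with the jobs that thread submits. A small standalone sketch of that behavior, independent of Spark:

    import java.util.Properties
    import scala.util.DynamicVariable

    object DynamicVariableSketch {
      // Mirrors the localProperties field added to SparkContext above.
      val localProperties = new DynamicVariable[Properties](null)

      def currentPool(): String =
        if (localProperties.value == null) "default"
        else localProperties.value.getProperty("spark.scheduler.cluster.fair.pool", "default")

      def main(args: Array[String]) {
        println(currentPool())      // "default": nothing bound yet

        val props = new Properties()
        props.setProperty("spark.scheduler.cluster.fair.pool", "etl")

        // withValue binds the properties only for the enclosed block...
        localProperties.withValue(props) {
          println(currentPool())    // "etl"
        }
        println(currentPool())      // "default" again outside the block

        // ...while direct assignment, as addLocalProperties does, rebinds the value
        // for the current thread until it is changed again.
        localProperties.value = props
        println(currentPool())      // "etl"
      }
    }
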
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index c54dce51d7..2ad73f3232 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -4,6 +4,7 @@ import cluster.TaskInfo
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.TimeUnit
+import java.util.Properties
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
@@ -128,11 +129,11 @@ class DAGScheduler(
* The priority value passed in will be used if the stage doesn't already exist with
* a lower priority (we assume that priorities always increase across jobs for now).
*/
- private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int): Stage = {
+ private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], priority: Int, properties: Properties): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
- val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority)
+ val stage = newStage(shuffleDep.rdd, Some(shuffleDep), priority, properties)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@@ -143,7 +144,7 @@ class DAGScheduler(
* as a result stage for the final RDD used directly in an action. The stage will also be given
* the provided priority.
*/
- private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int): Stage = {
+ private def newStage(rdd: RDD[_], shuffleDep: Option[ShuffleDependency[_,_]], priority: Int, properties: Properties): Stage = {
if (shuffleDep != None) {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of partitions is unknown
@@ -151,7 +152,7 @@ class DAGScheduler(
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
- val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
+ val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority, properties), priority, properties)
idToStage(id) = stage
stageToInfos(stage) = StageInfo(stage)
stage
@@ -161,7 +162,7 @@ class DAGScheduler(
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided priority if they haven't already been created with a lower priority.
*/
- private def getParentStages(rdd: RDD[_], priority: Int): List[Stage] = {
+ private def getParentStages(rdd: RDD[_], priority: Int, properties: Properties): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
def visit(r: RDD[_]) {
@@ -172,7 +173,7 @@ class DAGScheduler(
for (dep <- r.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
- parents += getShuffleMapStage(shufDep, priority)
+ parents += getShuffleMapStage(shufDep, priority, properties)
case _ =>
visit(dep.rdd)
}
@@ -193,7 +194,7 @@ class DAGScheduler(
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
- val mapStage = getShuffleMapStage(shufDep, stage.priority)
+ val mapStage = getShuffleMapStage(shufDep, stage.priority, stage.properties)
if (!mapStage.isAvailable) {
missing += mapStage
}
@@ -221,13 +222,14 @@ class DAGScheduler(
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
- resultHandler: (Int, U) => Unit)
+ resultHandler: (Int, U) => Unit,
+ properties: Properties = null)
: (JobSubmitted, JobWaiter[U]) =
{
assert(partitions.size > 0)
val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
- val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter)
+ val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
return (toSubmit, waiter)
}
@@ -237,13 +239,13 @@ class DAGScheduler(
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
- resultHandler: (Int, U) => Unit)
+ resultHandler: (Int, U) => Unit, properties: Properties = null)
{
if (partitions.size == 0) {
return
}
val (toSubmit, waiter) = prepareJob(
- finalRdd, func, partitions, callSite, allowLocal, resultHandler)
+ finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
eventQueue.put(toSubmit)
waiter.awaitResult() match {
case JobSucceeded => {}
@@ -258,13 +260,13 @@ class DAGScheduler(
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
callSite: String,
- timeout: Long)
+ timeout: Long, properties: Properties = null)
: PartialResult[R] =
{
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
- eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener))
+ eventQueue.put(JobSubmitted(rdd, func2, partitions, false, callSite, listener, properties))
return listener.awaitResult() // Will throw an exception if the job fails
}
@@ -274,9 +276,9 @@ class DAGScheduler(
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
- case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener) =>
+ case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
val runId = nextRunId.getAndIncrement()
- val finalStage = newStage(finalRDD, None, runId)
+ val finalStage = newStage(finalRDD, None, runId, properties)
val job = new ActiveJob(runId, finalStage, func, partitions, callSite, listener)
clearCacheLocs()
logInfo("Got job " + job.runId + " (" + callSite + ") with " + partitions.length +
@@ -458,7 +460,7 @@ class DAGScheduler(
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
- new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority))
+ new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.priority, stage.properties))
if (!stage.submissionTime.isDefined) {
stage.submissionTime = Some(System.currentTimeMillis())
}
@@ -663,7 +665,7 @@ class DAGScheduler(
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_,_] =>
- val mapStage = getShuffleMapStage(shufDep, stage.priority)
+ val mapStage = getShuffleMapStage(shufDep, stage.priority, stage.properties)
if (!mapStage.isAvailable) {
visitedStages += mapStage
visit(mapStage.rdd)
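
A small aside on the DAGScheduler signature changes above: because the new parameter defaults to null, existing call sites compile unchanged and their jobs simply fall into the default pool. A tiny, hypothetical illustration (the submit helper below is not Spark code):

    import java.util.Properties

    object DefaultPropertiesSketch {
      // Stand-in for the runJob / runApproximateJob signature change above.
      def submit(callSite: String, properties: Properties = null): String = {
        val pool =
          if (properties == null) "default"
          else properties.getProperty("spark.scheduler.cluster.fair.pool", "default")
        callSite + " -> pool " + pool
      }

      def main(args: Array[String]) {
        // Old-style caller: still compiles, falls back to the default pool.
        println(submit("count at Example.scala:12"))
        // New-style caller threading the properties through.
        val props = new Properties()
        props.setProperty("spark.scheduler.cluster.fair.pool", "adhoc")
        println(submit("collect at Example.scala:20", props))
      }
    }
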
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index ed0b9bf178..79588891e7 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -1,5 +1,8 @@
package spark.scheduler
+
+import java.util.Properties
+
import spark.scheduler.cluster.TaskInfo
import scala.collection.mutable.Map
@@ -20,7 +23,7 @@ private[spark] case class JobSubmitted(
partitions: Array[Int],
allowLocal: Boolean,
callSite: String,
- listener: JobListener)
+ listener: JobListener, properties: Properties)
extends DAGSchedulerEvent
private[spark] case class CompletionEvent(
diff --git a/core/src/main/scala/spark/scheduler/Stage.scala b/core/src/main/scala/spark/scheduler/Stage.scala
index 552061e46b..97afa27a60 100644
--- a/core/src/main/scala/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/spark/scheduler/Stage.scala
@@ -1,10 +1,12 @@
package spark.scheduler
import java.net.URI
+import java.util.Properties
import spark._
import spark.storage.BlockManagerId
+
/**
* A stage is a set of independent tasks all computing the same function that need to run as part
* of a Spark job, where all the tasks have the same shuffle dependencies. Each DAG of tasks run
@@ -24,7 +26,8 @@ private[spark] class Stage(
val rdd: RDD[_],
val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
- val priority: Int)
+ val priority: Int,
+ val properties: Properties = null)
extends Logging {
val isShuffleMap = shuffleDep != None
diff --git a/core/src/main/scala/spark/scheduler/TaskSet.scala b/core/src/main/scala/spark/scheduler/TaskSet.scala
index a3002ca477..2498e8a5aa 100644
--- a/core/src/main/scala/spark/scheduler/TaskSet.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSet.scala
@@ -1,10 +1,12 @@
package spark.scheduler
+import java.util.Properties
+
/**
* A set of tasks submitted together to the low-level TaskScheduler, usually representing
* missing partitions of a particular stage.
*/
-private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int) {
+private[spark] class TaskSet(val tasks: Array[Task[_]], val stageId: Int, val attempt: Int, val priority: Int, val properties: Properties) {
val id: String = stageId + "." + attempt
override def toString: String = "TaskSet " + id
diff --git a/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala
new file mode 100644
index 0000000000..37d98ccb2a
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/fair/FairClusterScheduler.scala
@@ -0,0 +1,341 @@
+package spark.scheduler.cluster.fair
+
+import java.io.{File, FileInputStream, FileOutputStream}
+import java.util.{TimerTask, Timer}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.util.control.Breaks._
+import scala.xml._
+
+import spark._
+import spark.TaskState.TaskState
+import spark.scheduler._
+import spark.scheduler.cluster._
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicLong
+import scala.io.Source
+
+/**
+ * An implementation of a fair TaskScheduler, for running tasks on a cluster. Clients should
+ * first call start(), then submit task sets through the submitTasks method.
+ *
+ * The current implementation makes the following assumptions: each pool has a fixed weight,
+ * and scheduling within a pool is FIFO.
+ * Pools are also assumed to be statically defined.
+ * Minimum shares are not currently supported.
+ */
+private[spark] class FairClusterScheduler(override val sc: SparkContext)
+ extends ClusterScheduler(sc)
+ with Logging {
+
+
+ val schedulerAllocFile = System.getProperty("mapred.fairscheduler.allocation.file","unspecified")
+
+ val poolNameToPool= new HashMap[String, Pool]
+ var pools = new ArrayBuffer[Pool]
+
+ loadPoolProperties()
+
+ def loadPoolProperties() {
+ //first check if the file exists
+ val file = new File(schedulerAllocFile)
+ if(!file.exists()) {
+ //if file does not exist, we just create 1 pool, default
+ val pool = new Pool("default",100)
+ pools += pool
+ poolNameToPool("default") = pool
+ logInfo("Created a default pool with weight = 100")
+ }
+ else {
+ val xml = XML.loadFile(file)
+ for (poolNode <- (xml \\ "pool")) {
+ if((poolNode \ "weight").text != ""){
+ val pool = new Pool((poolNode \ "@name").text,(poolNode \ "weight").text.toInt)
+ pools += pool
+ poolNameToPool((poolNode \ "@name").text) = pool
+ logInfo("Created pool "+ pool.name +"with weight = "+pool.weight)
+ } else {
+ val pool = new Pool((poolNode \ "@name").text,100)
+ pools += pool
+ poolNameToPool((poolNode \ "@name").text) = pool
+ logInfo("Created pool "+ pool.name +"with weight = 100")
+ }
+ }
+ if(!poolNameToPool.contains("default")) {
+ val pool = new Pool("default", 100)
+ pools += pool
+ poolNameToPool("default") = pool
+ logInfo("Created a default pool with weight = 100")
+ }
+
+ }
+ }
+
+ def taskFinished(manager: TaskSetManager) {
+ var poolName = "default"
+ if(manager.taskSet.properties != null)
+ poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default")
+
+ this.synchronized {
+ //have to check that poolName exists
+ if(poolNameToPool.contains(poolName))
+ {
+ poolNameToPool(poolName).numRunningTasks -= 1
+ }
+ else
+ {
+ poolNameToPool("default").numRunningTasks -= 1
+ }
+ }
+ }
+
+ override def submitTasks(taskSet: TaskSet) {
+ val tasks = taskSet.tasks
+
+
+ var poolName = "default"
+ if(taskSet.properties != null)
+ poolName = taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default")
+
+ this.synchronized {
+ if(poolNameToPool.contains(poolName))
+ {
+ val manager = new FairTaskSetManager(this, taskSet)
+ poolNameToPool(poolName).activeTaskSetsQueue += manager
+ activeTaskSets(taskSet.id) = manager
+ //activeTaskSetsQueue += manager
+ taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+ logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks to pool "+poolName)
+ }
+ else //If the pool name does not exist, fall back to the default pool
+ {
+ val manager = new FairTaskSetManager(this, taskSet)
+ poolNameToPool("default").activeTaskSetsQueue += manager
+ activeTaskSets(taskSet.id) = manager
+ //activeTaskSetsQueue += manager
+ taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+ logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks to pool default")
+ }
+ if (hasReceivedTask == false) {
+ starvationTimer.scheduleAtFixedRate(new TimerTask() {
+ override def run() {
+ if (!hasLaunchedTask) {
+ logWarning("Initial job has not accepted any resources; " +
+ "check your cluster UI to ensure that workers are registered")
+ } else {
+ this.cancel()
+ }
+ }
+ }, STARVATION_TIMEOUT, STARVATION_TIMEOUT)
+ }
+ hasReceivedTask = true;
+
+ }
+ backend.reviveOffers()
+ }
+
+ override def taskSetFinished(manager: TaskSetManager) {
+
+ var poolName = "default"
+ if(manager.taskSet.properties != null)
+ poolName = manager.taskSet.properties.getProperty("spark.scheduler.cluster.fair.pool","default")
+
+
+ this.synchronized {
+ //have to check that poolName exists
+ if(poolNameToPool.contains(poolName))
+ {
+ poolNameToPool(poolName).activeTaskSetsQueue -= manager
+ }
+ else
+ {
+ poolNameToPool("default").activeTaskSetsQueue -= manager
+ }
+ //activeTaskSetsQueue -= manager
+ activeTaskSets -= manager.taskSet.id
+ taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+ taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id)
+ taskSetTaskIds.remove(manager.taskSet.id)
+ }
+ //backend.reviveOffers()
+ }
+
+ /**
+ * This is the comparison function used for sorting to determine which
+ * pool to allocate next based on fairness.
+ * The algorithm is as follows: we sort pools by their running-tasks-to-weight ratio
+ * (pool's number of running tasks / pool's weight).
+ */
+ def poolFairCompFn(pool1: Pool, pool2: Pool): Boolean = {
+ val tasksToWeightRatio1 = pool1.numRunningTasks.toDouble / pool1.weight.toDouble
+ val tasksToWeightRatio2 = pool2.numRunningTasks.toDouble / pool2.weight.toDouble
+ var res = Math.signum(tasksToWeightRatio1 - tasksToWeightRatio2)
+ if (res == 0) {
+ //Jobs are tied in fairness ratio. We break the tie by name
+ res = pool1.name.compareTo(pool2.name)
+ }
+ if (res < 0)
+ return true
+ else
+ return false
+ }
+
+ /**
+ * Called by cluster manager to offer resources on slaves. We respond by asking our active task
+ * sets for tasks in order of priority. We fill each node with tasks in a fair manner so
+ * that tasks are balanced across the cluster.
+ */
+ override def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = {
+ synchronized {
+ SparkEnv.set(sc.env)
+ // Mark each slave as alive and remember its hostname
+ for (o <- offers) {
+ executorIdToHost(o.executorId) = o.hostname
+ if (!executorsByHost.contains(o.hostname)) {
+ executorsByHost(o.hostname) = new HashSet()
+ }
+ }
+ // Build a list of tasks to assign to each slave
+ val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+ val availableCpus = offers.map(o => o.cores).toArray
+ var launchedTask = false
+
+ for (i <- 0 until offers.size) { //we loop through the list of offers
+ val execId = offers(i).executorId
+ val host = offers(i).hostname
+ var breakOut = false
+ while(availableCpus(i) > 0 && !breakOut) {
+ breakable{
+ launchedTask = false
+ for (pool <- pools.sortWith(poolFairCompFn)) { //we loop through the list of pools
+ if(!pool.activeTaskSetsQueue.isEmpty) {
+ //sort the TaskSetManagers in the pool by priority, then stage id (sortBy returns a new buffer, so reassign)
+ pool.activeTaskSetsQueue = pool.activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))
+ for(manager <- pool.activeTaskSetsQueue) { //we loop through the activeTaskSets in this pool
+// val manager = pool.activeTaskSetsQueue.head
+ //Make an offer
+ manager.slaveOffer(execId, host, availableCpus(i)) match {
+ case Some(task) =>
+ tasks(i) += task
+ val tid = task.taskId
+ taskIdToTaskSetId(tid) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += tid
+ taskIdToExecutorId(tid) = execId
+ activeExecutorIds += execId
+ executorsByHost(host) += execId
+ availableCpus(i) -= 1
+ pool.numRunningTasks += 1
+ launchedTask = true
+ logInfo("launched task for pool"+pool.name);
+ break
+ case None => {}
+ }
+ }
+ }
+ }
+ //If no pool could assign a task, exit the outer loop and continue to the next offer
+ if(!launchedTask){
+ breakOut = true
+ }
+ }
+ }
+ }
+ if (tasks.size > 0) {
+ hasLaunchedTask = true
+ }
+ return tasks
+ }
+ }
+
+ override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ var taskSetToUpdate: Option[TaskSetManager] = None
+ var failedExecutor: Option[String] = None
+ var taskFailed = false
+ synchronized {
+ try {
+ if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) {
+ // We lost this entire executor, so remember that it's gone
+ val execId = taskIdToExecutorId(tid)
+ if (activeExecutorIds.contains(execId)) {
+ removeExecutor(execId)
+ failedExecutor = Some(execId)
+ }
+ }
+ taskIdToTaskSetId.get(tid) match {
+ case Some(taskSetId) =>
+ if (activeTaskSets.contains(taskSetId)) {
+ taskSetToUpdate = Some(activeTaskSets(taskSetId))
+ }
+ if (TaskState.isFinished(state)) {
+ taskIdToTaskSetId.remove(tid)
+ if (taskSetTaskIds.contains(taskSetId)) {
+ taskSetTaskIds(taskSetId) -= tid
+ }
+ taskIdToExecutorId.remove(tid)
+ }
+ if (state == TaskState.FAILED) {
+ taskFailed = true
+ }
+ case None =>
+ logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+ }
+ } catch {
+ case e: Exception => logError("Exception in statusUpdate", e)
+ }
+ }
+ // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock
+ if (taskSetToUpdate != None) {
+ taskSetToUpdate.get.statusUpdate(tid, state, serializedData)
+ }
+ if (failedExecutor != None) {
+ listener.executorLost(failedExecutor.get)
+ backend.reviveOffers()
+ }
+ if (taskFailed) {
+ // Also revive offers if a task had failed for some reason other than host lost
+ backend.reviveOffers()
+ }
+ }
+
+ // Check for speculatable tasks in all our active jobs.
+ override def checkSpeculatableTasks() {
+ var shouldRevive = false
+ synchronized {
+ for (pool <- pools) {
+ for (ts <- pool.activeTaskSetsQueue) {
+ shouldRevive |= ts.checkSpeculatableTasks()
+ }
+ }
+ }
+ if (shouldRevive) {
+ backend.reviveOffers()
+ }
+ }
+
+ /** Remove an executor from all our data structures and mark it as lost */
+ private def removeExecutor(executorId: String) {
+ activeExecutorIds -= executorId
+ val host = executorIdToHost(executorId)
+ val execs = executorsByHost.getOrElse(host, new HashSet)
+ execs -= executorId
+ if (execs.isEmpty) {
+ executorsByHost -= host
+ }
+ executorIdToHost -= executorId
+ for (pool <- pools) {
+ pool.activeTaskSetsQueue.foreach(_.executorLost(executorId, host))
+ }
+ }
+
+}
+
+/**
+ * An internal representation of a pool. It holds the pool's active TaskSetManagers, its weight, and the number of running tasks.
+ */
+class Pool(val name: String, val weight: Int)
+{
+ var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
+ var numRunningTasks: Int = 0
+}
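
For reference, a sketch of the allocation file format that loadPoolProperties() above parses; the pool names and weights here are made-up examples. A pool without a <weight> element gets weight 100, and a "default" pool is added if the file omits one:

    import scala.collection.mutable.HashMap
    import scala.xml.XML

    object AllocationFileSketch {
      // Hypothetical contents for the file named by mapred.fairscheduler.allocation.file.
      val allocXml =
        """<allocations>
          |  <pool name="production">
          |    <weight>300</weight>
          |  </pool>
          |  <pool name="adhoc"/>
          |</allocations>""".stripMargin

      def main(args: Array[String]) {
        val xml = XML.loadString(allocXml)
        val weights = new HashMap[String, Int]
        for (poolNode <- (xml \\ "pool")) {
          val name = (poolNode \ "@name").text
          val weightText = (poolNode \ "weight").text
          weights(name) = if (weightText != "") weightText.toInt else 100
        }
        if (!weights.contains("default")) weights("default") = 100
        // Yields weights production -> 300, adhoc -> 100, default -> 100.
        println(weights)
      }
    }
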
diff --git a/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala
new file mode 100644
index 0000000000..4b0277d2d5
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/fair/FairTaskSetManager.scala
@@ -0,0 +1,130 @@
+package spark.scheduler.cluster.fair
+
+import scala.collection.mutable.ArrayBuffer
+
+import spark._
+import spark.scheduler._
+import spark.scheduler.cluster._
+import spark.TaskState.TaskState
+import java.nio.ByteBuffer
+
+/**
+ * Schedules the tasks within a single TaskSet in the FairClusterScheduler.
+ */
+private[spark] class FairTaskSetManager(sched: FairClusterScheduler, override val taskSet: TaskSet) extends TaskSetManager(sched, taskSet) with Logging {
+
+ // Add a task to all the pending-task lists that it should be on.
+ private def addPendingTask(index: Int) {
+ val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
+ if (locations.size == 0) {
+ pendingTasksWithNoPrefs += index
+ } else {
+ for (host <- locations) {
+ val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
+ list += index
+ }
+ }
+ allPendingTasks += index
+ }
+
+ override def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
+ // or even from Mesos itself when acks get delayed.
+ return
+ }
+ val index = info.index
+ info.markSuccessful()
+ sched.taskFinished(this)
+ if (!finished(index)) {
+ tasksFinished += 1
+ logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
+ tid, info.duration, tasksFinished, numTasks))
+ // Deserialize task result and pass it to the scheduler
+ val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
+ result.metrics.resultSize = serializedData.limit()
+ sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+ // Mark finished and stop if we've finished all the tasks
+ finished(index) = true
+ if (tasksFinished == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ } else {
+ logInfo("Ignoring task-finished event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+
+ override def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ if (info.failed) {
+ // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
+ // or even from Mesos itself when acks get delayed.
+ return
+ }
+ val index = info.index
+ info.markFailed()
+ //Bookkeeping necessary for the pools in the scheduler
+ sched.taskFinished(this)
+ if (!finished(index)) {
+ logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
+ copiesRunning(index) -= 1
+ // Check if the problem is a map output fetch failure. In that case, this
+ // task will never succeed on any node, so tell the scheduler about it.
+ if (serializedData != null && serializedData.limit() > 0) {
+ val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
+ reason match {
+ case fetchFailed: FetchFailed =>
+ logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
+ sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
+ finished(index) = true
+ tasksFinished += 1
+ sched.taskSetFinished(this)
+ return
+
+ case ef: ExceptionFailure =>
+ val key = ef.exception.toString
+ val now = System.currentTimeMillis
+ val (printFull, dupCount) = {
+ if (recentExceptions.contains(key)) {
+ val (dupCount, printTime) = recentExceptions(key)
+ if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ } else {
+ recentExceptions(key) = (dupCount + 1, printTime)
+ (false, dupCount + 1)
+ }
+ } else {
+ recentExceptions(key) = (0, now)
+ (true, 0)
+ }
+ }
+ if (printFull) {
+ val locs = ef.exception.getStackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s".format(ef.exception.toString, locs.mkString("\n")))
+ } else {
+ logInfo("Loss was due to %s [duplicate %d]".format(ef.exception.toString, dupCount))
+ }
+
+ case _ => {}
+ }
+ }
+ // On non-fetch failures, re-enqueue the task as pending for a max number of retries
+ addPendingTask(index)
+ // Count failed attempts only on FAILED and LOST state (not on KILLED)
+ if (state == TaskState.FAILED || state == TaskState.LOST) {
+ numFailures(index) += 1
+ if (numFailures(index) > MAX_TASK_FAILURES) {
+ logError("Task %s:%d failed more than %d times; aborting job".format(
+ taskSet.id, index, MAX_TASK_FAILURES))
+ abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
+ }
+ }
+ } else {
+ logInfo("Ignoring task-lost event for TID " + tid +
+ " because task " + index + " is already finished")
+ }
+ }
+}
\ No newline at end of file
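
Finally, a small worked example of the ordering that poolFairCompFn implements: pools are sorted by runningTasks / weight, lowest (most underserved) first, with ties broken by pool name. The pool snapshots below are invented numbers, not taken from the patch:

    object FairnessOrderingSketch {
      case class PoolSnapshot(name: String, weight: Int, numRunningTasks: Int)

      // Same comparison as poolFairCompFn above, over plain snapshots.
      def lessLoaded(p1: PoolSnapshot, p2: PoolSnapshot): Boolean = {
        val ratio1 = p1.numRunningTasks.toDouble / p1.weight.toDouble
        val ratio2 = p2.numRunningTasks.toDouble / p2.weight.toDouble
        var res: Double = math.signum(ratio1 - ratio2)
        if (res == 0) res = p1.name.compareTo(p2.name)
        res < 0
      }

      def main(args: Array[String]) {
        val pools = Seq(
          PoolSnapshot("default", 100, 40),  // ratio 0.40
          PoolSnapshot("etl",     200, 50),  // ratio 0.25 -> most underserved, offered first
          PoolSnapshot("adhoc",    50, 20))  // ratio 0.40 -> ties with "default"; name breaks the tie
        println(pools.sortWith(lessLoaded).map(_.name))  // List(etl, adhoc, default)
      }
    }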