author     Andrew xia <junluan.xia@intel.com>  2013-05-30 20:49:40 +0800
committer  Andrew xia <junluan.xia@intel.com>  2013-05-30 20:49:40 +0800
commit     c3db3ea55467c3fb053453c8c567db357d939640
tree       3bd62a650096e109fba7a8e058794d4a51b3e3af
parent     ecceb101d3019ef511c42a8a8a3bb0e46520ffef
1. Add unit test for local scheduler
2. Move LocalTaskSetManager to a new file
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalScheduler.scala       241
-rw-r--r--  core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala  173
-rw-r--r--  core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala        171
3 files changed, 385 insertions(+), 200 deletions(-)
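For orientation, the new LocalSchedulerSuite (full diff below) drives the local scheduler by selecting a scheduling algorithm through the spark.cluster.schedulingmode system property and then running small jobs on a local[N] SparkContext. A minimal sketch of that setup, using only the property name and SparkContext calls that appear in the test below (the wrapper object is illustrative):

    import spark.SparkContext

    // Minimal sketch: run a trivial job on the local scheduler in FIFO mode,
    // mirroring the setup used by LocalSchedulerSuite below.
    object LocalFifoExample {
      def main(args: Array[String]) {
        // The suite selects the scheduling algorithm through this system property.
        System.setProperty("spark.cluster.schedulingmode", "FIFO")
        val sc = new SparkContext("local[4]", "example")   // 4 local worker threads
        val result = sc.parallelize(1 to 4, 4).map(_ * 2).collect()
        assert(result.toList == List(2, 4, 6, 8))
        sc.stop()
      }
    }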
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 664dc9e886..69dacfc2bd 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -15,7 +15,7 @@ import spark.scheduler.cluster._
import akka.actor._
/**
- * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
+ * A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
@@ -26,10 +26,8 @@ private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, seri
private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
def receive = {
case LocalReviveOffers =>
- logInfo("LocalReviveOffers")
launchTask(localScheduler.resourceOffer(freeCores))
case LocalStatusUpdate(taskId, state, serializeData) =>
- logInfo("LocalStatusUpdate")
freeCores += 1
localScheduler.statusUpdate(taskId, state, serializeData)
launchTask(localScheduler.resourceOffer(freeCores))
@@ -48,168 +46,6 @@ private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: I
}
}
-private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
- var parent: Schedulable = null
- var weight: Int = 1
- var minShare: Int = 0
- var runningTasks: Int = 0
- var priority: Int = taskSet.priority
- var stageId: Int = taskSet.stageId
- var name: String = "TaskSet_"+taskSet.stageId.toString
-
-
- var failCount = new Array[Int](taskSet.tasks.size)
- val taskInfos = new HashMap[Long, TaskInfo]
- val numTasks = taskSet.tasks.size
- var numFinished = 0
- val ser = SparkEnv.get.closureSerializer.newInstance()
- val copiesRunning = new Array[Int](numTasks)
- val finished = new Array[Boolean](numTasks)
- val numFailures = new Array[Int](numTasks)
-
- def increaseRunningTasks(taskNum: Int): Unit = {
- runningTasks += taskNum
- if (parent != null) {
- parent.increaseRunningTasks(taskNum)
- }
- }
-
- def decreaseRunningTasks(taskNum: Int): Unit = {
- runningTasks -= taskNum
- if (parent != null) {
- parent.decreaseRunningTasks(taskNum)
- }
- }
-
- def addSchedulable(schedulable: Schedulable): Unit = {
- //nothing
- }
-
- def removeSchedulable(schedulable: Schedulable): Unit = {
- //nothing
- }
-
- def getSchedulableByName(name: String): Schedulable = {
- return null
- }
-
- def executorLost(executorId: String, host: String): Unit = {
- //nothing
- }
-
- def checkSpeculatableTasks(): Boolean = {
- return true
- }
-
- def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
- var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
- sortedTaskSetQueue += this
- return sortedTaskSetQueue
- }
-
- def hasPendingTasks(): Boolean = {
- return true
- }
-
- def findTask(): Option[Int] = {
- for (i <- 0 to numTasks-1) {
- if (copiesRunning(i) == 0 && !finished(i)) {
- return Some(i)
- }
- }
- return None
- }
-
- def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
- Thread.currentThread().setContextClassLoader(sched.classLoader)
- SparkEnv.set(sched.env)
- logInfo("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
- if (availableCpus > 0 && numFinished < numTasks) {
- findTask() match {
- case Some(index) =>
- logInfo(taskSet.tasks(index).toString)
- val taskId = sched.attemptId.getAndIncrement()
- val task = taskSet.tasks(index)
- logInfo("taskId:%d,task:%s".format(index,task))
- val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
- taskInfos(taskId) = info
- val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
- logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
- val taskName = "task %s:%d".format(taskSet.id, index)
- copiesRunning(index) += 1
- increaseRunningTasks(1)
- return Some(new TaskDescription(taskId, null, taskName, bytes))
- case None => {}
- }
- }
- return None
- }
-
- def numPendingTasksForHostPort(hostPort: String): Int = {
- return 0
- }
-
- def numRackLocalPendingTasksForHost(hostPort :String): Int = {
- return 0
- }
-
- def numPendingTasksForHost(hostPort: String): Int = {
- return 0
- }
-
- def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- state match {
- case TaskState.FINISHED =>
- taskEnded(tid, state, serializedData)
- case TaskState.FAILED =>
- taskFailed(tid, state, serializedData)
- case _ => {}
- }
- }
-
- def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- val index = info.index
- val task = taskSet.tasks(index)
- info.markSuccessful()
- val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
- result.metrics.resultSize = serializedData.limit()
- sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
- numFinished += 1
- decreaseRunningTasks(1)
- finished(index) = true
- if (numFinished == numTasks) {
- sched.taskSetFinished(this)
- }
- }
-
- def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
- val info = taskInfos(tid)
- val index = info.index
- val task = taskSet.tasks(index)
- info.markFailed()
- decreaseRunningTasks(1)
- val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
- if (!finished(index)) {
- copiesRunning(index) -= 1
- numFailures(index) += 1
- val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
- logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
- if (numFailures(index) > 4) {
- val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
- //logError(errorMessage)
- //sched.listener.taskEnded(task, reason, null, null, info, null)
- sched.listener.taskSetFailed(taskSet, errorMessage)
- sched.taskSetFinished(this)
- decreaseRunningTasks(runningTasks)
- }
- }
- }
-
- def error(message: String) {
- }
-}
-
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
with Logging {
@@ -233,7 +69,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
var localActor: ActorRef = null
- // TODO: Need to take into account stage priority in scheduling
override def start() {
//default scheduler is FIFO
@@ -250,7 +85,6 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
schedulableBuilder.buildPools()
- //val properties = new ArrayBuffer[(String, String)]
localActor = env.actorSystem.actorOf(
Props(new LocalActor(this, threads)), "Test")
}
@@ -260,51 +94,56 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
override def submitTasks(taskSet: TaskSet) {
- var manager = new LocalTaskSetManager(this, taskSet)
- schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
- activeTaskSets(taskSet.id) = manager
- taskSetTaskIds(taskSet.id) = new HashSet[Long]()
- localActor ! LocalReviveOffers
+ synchronized {
+ var manager = new LocalTaskSetManager(this, taskSet)
+ schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
+ activeTaskSets(taskSet.id) = manager
+ taskSetTaskIds(taskSet.id) = new HashSet[Long]()
+ localActor ! LocalReviveOffers
+ }
}
def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
- var freeCpuCores = freeCores
- val tasks = new ArrayBuffer[TaskDescription](freeCores)
- val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
- for (manager <- sortedTaskSetQueue) {
- logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
- }
+ synchronized {
+ var freeCpuCores = freeCores
+ val tasks = new ArrayBuffer[TaskDescription](freeCores)
+ val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
+ for (manager <- sortedTaskSetQueue) {
+ logInfo("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
+ }
- var launchTask = false
- for (manager <- sortedTaskSetQueue) {
+ var launchTask = false
+ for (manager <- sortedTaskSetQueue) {
do {
launchTask = false
- logInfo("freeCores is" + freeCpuCores)
manager.slaveOffer(null,null,freeCpuCores) match {
case Some(task) =>
- tasks += task
- taskIdToTaskSetId(task.taskId) = manager.taskSet.id
- taskSetTaskIds(manager.taskSet.id) += task.taskId
- freeCpuCores -= 1
- launchTask = true
+ tasks += task
+ taskIdToTaskSetId(task.taskId) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += task.taskId
+ freeCpuCores -= 1
+ launchTask = true
case None => {}
- }
+ }
} while(launchTask)
+ }
+ return tasks
}
- return tasks
}
def taskSetFinished(manager: TaskSetManager) {
- activeTaskSets -= manager.taskSet.id
- manager.parent.removeSchedulable(manager)
- logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
- taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
- taskSetTaskIds -= manager.taskSet.id
+ synchronized {
+ activeTaskSets -= manager.taskSet.id
+ manager.parent.removeSchedulable(manager)
+ logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
+ taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
+ taskSetTaskIds -= manager.taskSet.id
+ }
}
def runTask(taskId: Long, bytes: ByteBuffer) {
logInfo("Running " + taskId)
- val info = new TaskInfo(taskId, 0 , System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
+ val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
val ser = SparkEnv.get.closureSerializer.newInstance()
@@ -344,8 +183,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
case t: Throwable => {
val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
- }
}
+ }
}
/**
@@ -376,11 +215,13 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
}
}
- def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer)
- {
- val taskSetId = taskIdToTaskSetId(taskId)
- val taskSetManager = activeTaskSets(taskSetId)
- taskSetManager.statusUpdate(taskId, state, serializedData)
+ def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
+ synchronized {
+ val taskSetId = taskIdToTaskSetId(taskId)
+ val taskSetManager = activeTaskSets(taskSetId)
+ taskSetTaskIds(taskSetId) -= taskId
+ taskSetManager.statusUpdate(taskId, state, serializedData)
+ }
}
override def stop() {
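The LocalScheduler hunks above wrap submitTasks, resourceOffer, taskSetFinished and statusUpdate in synchronized blocks, so the actor thread and the job-submitting threads mutate activeTaskSets, taskIdToTaskSetId and taskSetTaskIds under a single lock. A stripped-down sketch of that locking pattern (the class and method names here are illustrative, not the scheduler's own):

    import scala.collection.mutable.{HashMap, HashSet}

    // Illustrative sketch of the locking pattern adopted above: every method that
    // touches the shared bookkeeping maps synchronizes on the owning instance.
    class SharedBookkeeping {
      private val taskIdToSetId = new HashMap[Long, String]
      private val setIdToTaskIds = new HashMap[String, HashSet[Long]]

      def register(setId: String): Unit = synchronized {
        setIdToTaskIds(setId) = new HashSet[Long]()
      }

      def assign(taskId: Long, setId: String): Unit = synchronized {
        taskIdToSetId(taskId) = setId
        setIdToTaskIds(setId) += taskId
      }

      def finish(setId: String): Unit = synchronized {
        taskIdToSetId --= setIdToTaskIds(setId)
        setIdToTaskIds -= setId
      }
    }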
diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
new file mode 100644
index 0000000000..f2e07d162a
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
@@ -0,0 +1,173 @@
+package spark.scheduler.local
+
+import java.io.File
+import java.util.concurrent.atomic.AtomicInteger
+import java.nio.ByteBuffer
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+
+import spark._
+import spark.TaskState.TaskState
+import spark.scheduler._
+import spark.scheduler.cluster._
+
+private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
+ var parent: Schedulable = null
+ var weight: Int = 1
+ var minShare: Int = 0
+ var runningTasks: Int = 0
+ var priority: Int = taskSet.priority
+ var stageId: Int = taskSet.stageId
+ var name: String = "TaskSet_"+taskSet.stageId.toString
+
+
+ var failCount = new Array[Int](taskSet.tasks.size)
+ val taskInfos = new HashMap[Long, TaskInfo]
+ val numTasks = taskSet.tasks.size
+ var numFinished = 0
+ val ser = SparkEnv.get.closureSerializer.newInstance()
+ val copiesRunning = new Array[Int](numTasks)
+ val finished = new Array[Boolean](numTasks)
+ val numFailures = new Array[Int](numTasks)
+ val MAX_TASK_FAILURES = sched.maxFailures
+
+ def increaseRunningTasks(taskNum: Int): Unit = {
+ runningTasks += taskNum
+ if (parent != null) {
+ parent.increaseRunningTasks(taskNum)
+ }
+ }
+
+ def decreaseRunningTasks(taskNum: Int): Unit = {
+ runningTasks -= taskNum
+ if (parent != null) {
+ parent.decreaseRunningTasks(taskNum)
+ }
+ }
+
+ def addSchedulable(schedulable: Schedulable): Unit = {
+ //nothing
+ }
+
+ def removeSchedulable(schedulable: Schedulable): Unit = {
+ //nothing
+ }
+
+ def getSchedulableByName(name: String): Schedulable = {
+ return null
+ }
+
+ def executorLost(executorId: String, host: String): Unit = {
+ //nothing
+ }
+
+ def checkSpeculatableTasks(): Boolean = {
+ return true
+ }
+
+ def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
+ var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+ sortedTaskSetQueue += this
+ return sortedTaskSetQueue
+ }
+
+ def hasPendingTasks(): Boolean = {
+ return true
+ }
+
+ def findTask(): Option[Int] = {
+ for (i <- 0 to numTasks-1) {
+ if (copiesRunning(i) == 0 && !finished(i)) {
+ return Some(i)
+ }
+ }
+ return None
+ }
+
+ def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+ SparkEnv.set(sched.env)
+ logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
+ if (availableCpus > 0 && numFinished < numTasks) {
+ findTask() match {
+ case Some(index) =>
+ logInfo(taskSet.tasks(index).toString)
+ val taskId = sched.attemptId.getAndIncrement()
+ val task = taskSet.tasks(index)
+ val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
+ taskInfos(taskId) = info
+ val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
+ val taskName = "task %s:%d".format(taskSet.id, index)
+ copiesRunning(index) += 1
+ increaseRunningTasks(1)
+ return Some(new TaskDescription(taskId, null, taskName, bytes))
+ case None => {}
+ }
+ }
+ return None
+ }
+
+ def numPendingTasksForHostPort(hostPort: String): Int = {
+ return 0
+ }
+
+ def numRackLocalPendingTasksForHost(hostPort :String): Int = {
+ return 0
+ }
+
+ def numPendingTasksForHost(hostPort: String): Int = {
+ return 0
+ }
+
+ def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ state match {
+ case TaskState.FINISHED =>
+ taskEnded(tid, state, serializedData)
+ case TaskState.FAILED =>
+ taskFailed(tid, state, serializedData)
+ case _ => {}
+ }
+ }
+
+ def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ val index = info.index
+ val task = taskSet.tasks(index)
+ info.markSuccessful()
+ val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
+ result.metrics.resultSize = serializedData.limit()
+ sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
+ numFinished += 1
+ decreaseRunningTasks(1)
+ finished(index) = true
+ if (numFinished == numTasks) {
+ sched.taskSetFinished(this)
+ }
+ }
+
+ def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+ val info = taskInfos(tid)
+ val index = info.index
+ val task = taskSet.tasks(index)
+ info.markFailed()
+ decreaseRunningTasks(1)
+ val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
+ if (!finished(index)) {
+ copiesRunning(index) -= 1
+ numFailures(index) += 1
+ val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
+ logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
+ if (numFailures(index) > MAX_TASK_FAILURES) {
+ val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
+ decreaseRunningTasks(runningTasks)
+ sched.listener.taskSetFailed(taskSet, errorMessage)
+ // need to delete failed Taskset from schedule queue
+ sched.taskSetFinished(this)
+ }
+ }
+ }
+
+ def error(message: String) {
+ }
+}
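taskFailed above counts failures per task index and aborts the whole task set once a task has failed more than MAX_TASK_FAILURES times, now taken from the scheduler's maxFailures rather than the previous hard-coded 4. A standalone sketch of that retry/abort bookkeeping (the FailureTracker name is illustrative):

    // Simplified, standalone model of the retry/abort decision in taskFailed:
    // each task index gets up to maxFailures failures before the set is aborted.
    class FailureTracker(numTasks: Int, maxFailures: Int) {
      private val numFailuresPerTask = new Array[Int](numTasks)

      /** Returns true if the task set should be aborted after this failure. */
      def recordFailure(index: Int): Boolean = {
        numFailuresPerTask(index) += 1
        numFailuresPerTask(index) > maxFailures
      }
    }

    // e.g. with maxFailures = 4, the 5th failure of the same task triggers an abort:
    // val t = new FailureTracker(numTasks = 8, maxFailures = 4)
    // (1 to 4).foreach(_ => assert(!t.recordFailure(0)))
    // assert(t.recordFailure(0))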
diff --git a/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala
new file mode 100644
index 0000000000..37d14ed113
--- /dev/null
+++ b/core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala
@@ -0,0 +1,171 @@
+package spark.scheduler
+
+import org.scalatest.FunSuite
+import org.scalatest.BeforeAndAfter
+
+import spark._
+import spark.scheduler._
+import spark.scheduler.cluster._
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ConcurrentMap, HashMap}
+import java.util.concurrent.Semaphore
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.AtomicInteger
+
+import java.util.Properties
+
+class Lock() {
+ var finished = false
+ def jobWait() = {
+ synchronized {
+ while(!finished) {
+ this.wait()
+ }
+ }
+ }
+
+ def jobFinished() = {
+ synchronized {
+ finished = true
+ this.notifyAll()
+ }
+ }
+}
+
+object TaskThreadInfo {
+ val threadToLock = HashMap[Int, Lock]()
+ val threadToRunning = HashMap[Int, Boolean]()
+}
+
+
+class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
+
+ def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
+
+ TaskThreadInfo.threadToRunning(threadIndex) = false
+ val nums = sc.parallelize(threadIndex to threadIndex, 1)
+ TaskThreadInfo.threadToLock(threadIndex) = new Lock()
+ new Thread {
+ if (poolName != null) {
+ sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName)
+ }
+ override def run() {
+ val ans = nums.map(number => {
+ TaskThreadInfo.threadToRunning(number) = true
+ TaskThreadInfo.threadToLock(number).jobWait()
+ number
+ }).collect()
+ assert(ans.toList === List(threadIndex))
+ sem.release()
+ TaskThreadInfo.threadToRunning(threadIndex) = false
+ }
+ }.start()
+ Thread.sleep(2000)
+ }
+
+ test("Local FIFO scheduler end-to-end test") {
+ System.setProperty("spark.cluster.schedulingmode", "FIFO")
+ sc = new SparkContext("local[4]", "test")
+ val sem = new Semaphore(0)
+
+ createThread(1,null,sc,sem)
+ createThread(2,null,sc,sem)
+ createThread(3,null,sc,sem)
+ createThread(4,null,sc,sem)
+ createThread(5,null,sc,sem)
+ createThread(6,null,sc,sem)
+ assert(TaskThreadInfo.threadToRunning(1) === true)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === true)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === false)
+ assert(TaskThreadInfo.threadToRunning(6) === false)
+
+ TaskThreadInfo.threadToLock(1).jobFinished()
+ Thread.sleep(1000)
+
+ assert(TaskThreadInfo.threadToRunning(1) === false)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === true)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === true)
+ assert(TaskThreadInfo.threadToRunning(6) === false)
+
+ TaskThreadInfo.threadToLock(3).jobFinished()
+ Thread.sleep(1000)
+
+ assert(TaskThreadInfo.threadToRunning(1) === false)
+ assert(TaskThreadInfo.threadToRunning(2) === true)
+ assert(TaskThreadInfo.threadToRunning(3) === false)
+ assert(TaskThreadInfo.threadToRunning(4) === true)
+ assert(TaskThreadInfo.threadToRunning(5) === true)
+ assert(TaskThreadInfo.threadToRunning(6) === true)
+
+ TaskThreadInfo.threadToLock(2).jobFinished()
+ TaskThreadInfo.threadToLock(4).jobFinished()
+ TaskThreadInfo.threadToLock(5).jobFinished()
+ TaskThreadInfo.threadToLock(6).jobFinished()
+ sem.acquire(6)
+ }
+
+ test("Local fair scheduler end-to-end test") {
+ sc = new SparkContext("local[8]", "LocalSchedulerSuite")
+ val sem = new Semaphore(0)
+ System.setProperty("spark.cluster.schedulingmode", "FAIR")
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
+ System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
+
+ createThread(10,"1",sc,sem)
+ createThread(20,"2",sc,sem)
+ createThread(30,"3",sc,sem)
+
+ assert(TaskThreadInfo.threadToRunning(10) === true)
+ assert(TaskThreadInfo.threadToRunning(20) === true)
+ assert(TaskThreadInfo.threadToRunning(30) === true)
+
+ createThread(11,"1",sc,sem)
+ createThread(21,"2",sc,sem)
+ createThread(31,"3",sc,sem)
+
+ assert(TaskThreadInfo.threadToRunning(11) === true)
+ assert(TaskThreadInfo.threadToRunning(21) === true)
+ assert(TaskThreadInfo.threadToRunning(31) === true)
+
+ createThread(12,"1",sc,sem)
+ createThread(22,"2",sc,sem)
+ createThread(32,"3",sc,sem)
+
+ assert(TaskThreadInfo.threadToRunning(12) === true)
+ assert(TaskThreadInfo.threadToRunning(22) === true)
+ assert(TaskThreadInfo.threadToRunning(32) === false)
+
+ TaskThreadInfo.threadToLock(10).jobFinished()
+ Thread.sleep(1000)
+ assert(TaskThreadInfo.threadToRunning(32) === true)
+
+ createThread(23,"2",sc,sem)
+ createThread(33,"3",sc,sem)
+
+ TaskThreadInfo.threadToLock(11).jobFinished()
+ Thread.sleep(1000)
+
+ assert(TaskThreadInfo.threadToRunning(23) === true)
+ assert(TaskThreadInfo.threadToRunning(33) === false)
+
+ TaskThreadInfo.threadToLock(12).jobFinished()
+ Thread.sleep(1000)
+
+ assert(TaskThreadInfo.threadToRunning(33) === true)
+
+ TaskThreadInfo.threadToLock(20).jobFinished()
+ TaskThreadInfo.threadToLock(21).jobFinished()
+ TaskThreadInfo.threadToLock(22).jobFinished()
+ TaskThreadInfo.threadToLock(23).jobFinished()
+ TaskThreadInfo.threadToLock(30).jobFinished()
+ TaskThreadInfo.threadToLock(31).jobFinished()
+ TaskThreadInfo.threadToLock(32).jobFinished()
+ TaskThreadInfo.threadToLock(33).jobFinished()
+
+ sem.acquire(11)
+ }
+}
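The suite above keeps each task blocked on a Lock until the test releases it, which is how it can assert exactly which jobs each pool is running at a given moment. A self-contained sketch of that wait/notify handshake, with names mirroring the test code (the wrapper object is illustrative):

    // Standalone illustration of the Lock handshake used by LocalSchedulerSuite:
    // a worker thread blocks in jobWait() until the driver calls jobFinished().
    object LockHandshakeExample {
      class Lock {
        private var finished = false
        def jobWait(): Unit = synchronized {
          while (!finished) { wait() }
        }
        def jobFinished(): Unit = synchronized {
          finished = true
          notifyAll()
        }
      }

      def main(args: Array[String]) {
        val lock = new Lock
        val worker = new Thread { override def run() { lock.jobWait(); println("released") } }
        worker.start()
        Thread.sleep(100)   // worker is now parked in jobWait()
        lock.jobFinished()  // release it
        worker.join()
      }
    }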