aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-07-28 22:02:01 -0400
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-08-18 19:51:06 -0700
commit90a04dab8d9a2a9a372cea7cdf46cc0fd0f2f76c (patch)
tree2f03c8d7b586d27ea1b0949337e05bd7ad8536cf /core
parent8fa0747978e2600705657ba21f8d1ccef37dc722 (diff)
downloadspark-90a04dab8d9a2a9a372cea7cdf46cc0fd0f2f76c.tar.gz
spark-90a04dab8d9a2a9a372cea7cdf46cc0fd0f2f76c.tar.bz2
spark-90a04dab8d9a2a9a372cea7cdf46cc0fd0f2f76c.zip
Initial work towards scheduler refactoring:
- Replace use of hostPort vs host in Task.preferredLocations with a TaskLocation class that contains either an executorId and a host or just a host. This is part of a bigger effort to eliminate hostPort based data structures and just use executorID, since the hostPort vs host stuff is confusing (and not checkable with static typing, leading to ugly debug code), and hostPorts are not provided by Mesos. - Replaced most hostPort-based data structures and fields as above. - Simplified ClusterTaskSetManager to deal with preferred locations in a more concise way and generally be more concise. - Updated the way ClusterTaskSetManager handles racks: instead of enqueueing a task to a separate queue for all the hosts in the rack, which would create lots of large queues, have one queue per rack name. - Removed non-local fallback stuff in ClusterScheduler that tried to launch less-local tasks on a node once the local ones were all assigned. This change didn't work because many cluster schedulers send offers for just one node at a time (even the standalone and YARN ones do so as nodes join the cluster one by one). Thus, lots of non-local tasks would be assigned even though a node with locality for them would be able to receive tasks just a short time later. - Renamed MapOutputTracker "generations" to "epochs".
Diffstat (limited to 'core')
-rw-r--r--core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala7
-rw-r--r--core/src/main/scala/spark/MapOutputTracker.scala62
-rw-r--r--core/src/main/scala/spark/RDD.scala4
-rw-r--r--core/src/main/scala/spark/executor/Executor.scala12
-rw-r--r--core/src/main/scala/spark/rdd/BlockRDD.scala7
-rw-r--r--core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/DAGScheduler.scala79
-rw-r--r--core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/ResultTask.scala15
-rw-r--r--core/src/main/scala/spark/scheduler/ShuffleMapTask.scala15
-rw-r--r--core/src/main/scala/spark/scheduler/Task.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/TaskLocation.scala32
-rw-r--r--core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala2
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala283
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala510
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala15
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala4
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala32
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala13
-rw-r--r--core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala3
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalScheduler.scala3
-rw-r--r--core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala21
-rw-r--r--core/src/main/scala/spark/storage/BlockManager.scala60
-rw-r--r--core/src/main/scala/spark/ui/jobs/StagePage.scala2
-rw-r--r--core/src/test/scala/spark/MapOutputTrackerSuite.scala12
-rw-r--r--core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala6
-rw-r--r--core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala26
27 files changed, 484 insertions, 751 deletions
diff --git a/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
index 307d96111c..bb58353e0c 100644
--- a/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/core/src/hadoop2-yarn/scala/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -41,13 +41,6 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
if (retval != null) Some(retval) else None
}
- // By default, if rack is unknown, return nothing
- override def getCachedHostsForRack(rack: String): Option[Set[String]] = {
- if (rack == None || rack == null) return None
-
- YarnAllocationHandler.fetchCachedHostsForRack(rack)
- }
-
override def postStartHook() {
val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
if (sparkContextInitialized){
diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 2c417e31db..0cd0341a72 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -64,11 +64,11 @@ private[spark] class MapOutputTracker extends Logging {
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
- private var generation: Long = 0
- private val generationLock = new java.lang.Object
+ private var epoch: Long = 0
+ private val epochLock = new java.lang.Object
// Cache a serialized version of the output statuses for each shuffle to send them out faster
- var cacheGeneration = generation
+ var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
@@ -108,10 +108,10 @@ private[spark] class MapOutputTracker extends Logging {
def registerMapOutputs(
shuffleId: Int,
statuses: Array[MapStatus],
- changeGeneration: Boolean = false) {
+ changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
- if (changeGeneration) {
- incrementGeneration()
+ if (changeEpoch) {
+ incrementEpoch()
}
}
@@ -124,7 +124,7 @@ private[spark] class MapOutputTracker extends Logging {
array(mapId) = null
}
}
- incrementGeneration()
+ incrementEpoch()
} else {
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
@@ -206,58 +206,58 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}
- // Called on master to increment the generation number
- def incrementGeneration() {
- generationLock.synchronized {
- generation += 1
- logDebug("Increasing generation to " + generation)
+ // Called on master to increment the epoch number
+ def incrementEpoch() {
+ epochLock.synchronized {
+ epoch += 1
+ logDebug("Increasing epoch to " + epoch)
}
}
- // Called on master or workers to get current generation number
- def getGeneration: Long = {
- generationLock.synchronized {
- return generation
+ // Called on master or workers to get current epoch number
+ def getEpoch: Long = {
+ epochLock.synchronized {
+ return epoch
}
}
- // Called on workers to update the generation number, potentially clearing old outputs
- // because of a fetch failure. (Each Mesos task calls this with the latest generation
+ // Called on workers to update the epoch number, potentially clearing old outputs
+ // because of a fetch failure. (Each worker task calls this with the latest epoch
// number on the master at the time it was created.)
- def updateGeneration(newGen: Long) {
- generationLock.synchronized {
- if (newGen > generation) {
- logInfo("Updating generation to " + newGen + " and clearing cache")
+ def updateEpoch(newEpoch: Long) {
+ epochLock.synchronized {
+ if (newEpoch > epoch) {
+ logInfo("Updating epoch to " + newEpoch + " and clearing cache")
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
mapStatuses.clear()
- generation = newGen
+ epoch = newEpoch
}
}
}
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
- var generationGotten: Long = -1
- generationLock.synchronized {
- if (generation > cacheGeneration) {
+ var epochGotten: Long = -1
+ epochLock.synchronized {
+ if (epoch > cacheEpoch) {
cachedSerializedStatuses.clear()
- cacheGeneration = generation
+ cacheEpoch = epoch
}
cachedSerializedStatuses.get(shuffleId) match {
case Some(bytes) =>
return bytes
case None =>
statuses = mapStatuses(shuffleId)
- generationGotten = generation
+ epochGotten = epoch
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
val bytes = serializeStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
- // Add them into the table only if the generation hasn't changed while we were working
- generationLock.synchronized {
- if (generation == generationGotten) {
+ // Add them into the table only if the epoch hasn't changed while we were working
+ epochLock.synchronized {
+ if (epoch == epochGotten) {
cachedSerializedStatuses(shuffleId) = bytes
}
}
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 503ea6ccbf..f5767a3858 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -221,8 +221,8 @@ abstract class RDD[T: ClassManifest](
}
/**
- * Get the preferred location of a split, taking into account whether the
- * RDD is checkpointed or not.
+ * Get the preferred locations of a partition (as hostnames), taking into account whether the
+ * RDD is checkpointed.
*/
final def preferredLocations(split: Partition): Seq[String] = {
checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 05a960d7c5..036c7191ad 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -32,8 +32,12 @@ import spark._
/**
* The Mesos executor for Spark.
*/
-private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging {
-
+private[spark] class Executor(
+ executorId: String,
+ slaveHostname: String,
+ properties: Seq[(String, String)])
+ extends Logging
+{
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -125,8 +129,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
updateDependencies(taskFiles, taskJars)
val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
attemptedTask = Some(task)
- logInfo("Its generation is " + task.generation)
- env.mapOutputTracker.updateGeneration(task.generation)
+ logInfo("Its epoch is " + task.epoch)
+ env.mapOutputTracker.updateEpoch(task.epoch)
taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala
index 0ebb722d73..03800584ae 100644
--- a/core/src/main/scala/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/spark/rdd/BlockRDD.scala
@@ -28,13 +28,12 @@ private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc, Nil) {
- @transient lazy val locations_ = BlockManager.blockIdsToExecutorLocations(blockIds, SparkEnv.get)
+ @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]
}).toArray
-
override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
val blockId = split.asInstanceOf[BlockRDDPartition].blockId
@@ -45,8 +44,8 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St
}
}
- override def getPreferredLocations(split: Partition): Seq[String] =
+ override def getPreferredLocations(split: Partition): Seq[String] = {
locations_(split.asInstanceOf[BlockRDDPartition].blockId)
-
+ }
}
diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
index 6a4fa13ad6..51f5cc3251 100644
--- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
@@ -55,6 +55,8 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
}
override def getPreferredLocations(s: Partition): Seq[String] = {
+ // TODO(matei): Fix this for hostPort
+
// Note that as number of rdd's increase and/or number of slaves in cluster increase, the computed preferredLocations below
// become diminishingly small : so we might need to look at alternate strategies to alleviate this.
// If there are no (or very small number of preferred locations), we will end up transferred the blocks to 'any' node in the
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index fbf3f4c807..2f7e6d98f8 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -72,8 +72,8 @@ class DAGScheduler(
}
// Called by TaskScheduler when a host is added
- override def executorGained(execId: String, hostPort: String) {
- eventQueue.put(ExecutorGained(execId, hostPort))
+ override def executorGained(execId: String, host: String) {
+ eventQueue.put(ExecutorGained(execId, host))
}
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
@@ -104,15 +104,16 @@ class DAGScheduler(
private val listenerBus = new SparkListenerBus()
- var cacheLocs = new HashMap[Int, Array[List[String]]]
+ // Contains the locations that each RDD's partitions are cached on
+ private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
- // For tracking failed nodes, we use the MapOutputTracker's generation number, which is
- // sent with every task. When we detect a node failing, we note the current generation number
- // and failed executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask
- // results.
- // TODO: Garbage collect information about failure generations when we know there are no more
+ // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
+ // every task. When we detect a node failing, we note the current epoch number and failed
+ // executor, increment it for new tasks, and use this to ignore stray ShuffleMapTask results.
+ //
+ // TODO: Garbage collect information about failure epochs when we know there are no more
// stray messages to detect.
- val failedGeneration = new HashMap[String, Long]
+ val failedEpoch = new HashMap[String, Long]
val idToActiveJob = new HashMap[Int, ActiveJob]
@@ -141,11 +142,13 @@ class DAGScheduler(
listenerBus.addListener(listener)
}
- private def getCacheLocs(rdd: RDD[_]): Array[List[String]] = {
+ private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
- val locs = BlockManager.blockIdsToExecutorLocations(blockIds, env, blockManagerMaster)
- cacheLocs(rdd.id) = blockIds.map(locs.getOrElse(_, Nil))
+ val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
+ cacheLocs(rdd.id) = blockIds.map { id =>
+ locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
+ }
}
cacheLocs(rdd.id)
}
@@ -345,8 +348,8 @@ class DAGScheduler(
submitStage(finalStage)
}
- case ExecutorGained(execId, hostPort) =>
- handleExecutorGained(execId, hostPort)
+ case ExecutorGained(execId, host) =>
+ handleExecutorGained(execId, host)
case ExecutorLost(execId) =>
handleExecutorLost(execId)
@@ -508,7 +511,7 @@ class DAGScheduler(
} else {
// This is a final stage; figure out its job's missing partitions
val job = resultStageToJob(stage)
- for (id <- 0 until job.numPartitions if (!job.finished(id))) {
+ for (id <- 0 until job.numPartitions if !job.finished(id)) {
val partition = job.partitions(id)
val locs = getPreferredLocs(stage.rdd, partition)
tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
@@ -518,7 +521,7 @@ class DAGScheduler(
// should be "StageSubmitted" first and then "JobEnded"
val properties = idToActiveJob(stage.priority).properties
listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
-
+
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
// exception here because it would be fairly hard to catch the non-serializable exception
@@ -599,7 +602,7 @@ class DAGScheduler(
val status = event.result.asInstanceOf[MapStatus]
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
- if (failedGeneration.contains(execId) && smt.generation <= failedGeneration(execId)) {
+ if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
} else {
stage.addOutputLoc(smt.partition, status)
@@ -611,11 +614,11 @@ class DAGScheduler(
logInfo("waiting: " + waiting)
logInfo("failed: " + failed)
if (stage.shuffleDep != None) {
- // We supply true to increment the generation number here in case this is a
+ // We supply true to increment the epoch number here in case this is a
// recomputation of the map outputs. In that case, some nodes may have cached
// locations with holes (from when we detected the error) and will need the
- // generation incremented to refetch them.
- // TODO: Only increment the generation number if this is not the first time
+ // epoch incremented to refetch them.
+ // TODO: Only increment the epoch number if this is not the first time
// we registered these map outputs.
mapOutputTracker.registerMapOutputs(
stage.shuffleDep.get.shuffleId,
@@ -674,7 +677,7 @@ class DAGScheduler(
lastFetchFailureTime = System.currentTimeMillis() // TODO: Use pluggable clock
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, Some(task.generation))
+ handleExecutorLost(bmAddress.executorId, Some(task.epoch))
}
case ExceptionFailure(className, description, stackTrace, metrics) =>
@@ -690,14 +693,14 @@ class DAGScheduler(
* Responds to an executor being lost. This is called inside the event loop, so it assumes it can
* modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
*
- * Optionally the generation during which the failure was caught can be passed to avoid allowing
+ * Optionally the epoch during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- private def handleExecutorLost(execId: String, maybeGeneration: Option[Long] = None) {
- val currentGeneration = maybeGeneration.getOrElse(mapOutputTracker.getGeneration)
- if (!failedGeneration.contains(execId) || failedGeneration(execId) < currentGeneration) {
- failedGeneration(execId) = currentGeneration
- logInfo("Executor lost: %s (generation %d)".format(execId, currentGeneration))
+ private def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) {
+ val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
+ if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
+ failedEpoch(execId) = currentEpoch
+ logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
blockManagerMaster.removeExecutor(execId)
// TODO: This will be really slow if we keep accumulating shuffle map stages
for ((shuffleId, stage) <- shuffleToMapStage) {
@@ -706,20 +709,20 @@ class DAGScheduler(
mapOutputTracker.registerMapOutputs(shuffleId, locs, true)
}
if (shuffleToMapStage.isEmpty) {
- mapOutputTracker.incrementGeneration()
+ mapOutputTracker.incrementEpoch()
}
clearCacheLocs()
} else {
logDebug("Additional executor lost message for " + execId +
- "(generation " + currentGeneration + ")")
+ "(epoch " + currentEpoch + ")")
}
}
- private def handleExecutorGained(execId: String, hostPort: String) {
- // remove from failedGeneration(execId) ?
- if (failedGeneration.contains(execId)) {
- logInfo("Host gained which was in lost list earlier: " + hostPort)
- failedGeneration -= execId
+ private def handleExecutorGained(execId: String, host: String) {
+ // remove from failedEpoch(execId) ?
+ if (failedEpoch.contains(execId)) {
+ logInfo("Host gained which was in lost list earlier: " + host)
+ failedEpoch -= execId
}
}
@@ -774,16 +777,16 @@ class DAGScheduler(
visitedRdds.contains(target.rdd)
}
- private def getPreferredLocs(rdd: RDD[_], partition: Int): List[String] = {
+ private def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
// If the partition is cached, return the cache locations
val cached = getCacheLocs(rdd)(partition)
- if (cached != Nil) {
+ if (!cached.isEmpty) {
return cached
}
// If the RDD has some placement preferences (as is the case for input RDDs), get those
val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
- if (rddPrefs != Nil) {
- return rddPrefs
+ if (!rddPrefs.isEmpty) {
+ return rddPrefs.map(host => TaskLocation(host))
}
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
// that has any placement preferences. Ideally we would choose based on transfer sizes,
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index 3b4ee6287a..b8ba0e9239 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -54,9 +54,7 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
-private[spark] case class ExecutorGained(execId: String, hostPort: String) extends DAGSchedulerEvent {
- Utils.checkHostPort(hostPort, "Required hostport")
-}
+private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 832ca18b8c..d066df5dc1 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -73,7 +73,7 @@ private[spark] class ResultTask[T, U](
var rdd: RDD[T],
var func: (TaskContext, Iterator[T]) => U,
var partition: Int,
- @transient locs: Seq[String],
+ @transient locs: Seq[TaskLocation],
val outputId: Int)
extends Task[U](stageId) with Externalizable {
@@ -85,11 +85,8 @@ private[spark] class ResultTask[T, U](
rdd.partitions(partition)
}
- private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq
-
- {
- // DEBUG code
- preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs))
+ @transient private val preferredLocs: Seq[TaskLocation] = {
+ if (locs == null) Nil else locs.toSet.toSeq
}
override def run(attemptId: Long): U = {
@@ -102,7 +99,7 @@ private[spark] class ResultTask[T, U](
}
}
- override def preferredLocations: Seq[String] = preferredLocs
+ override def preferredLocations: Seq[TaskLocation] = preferredLocs
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
@@ -116,7 +113,7 @@ private[spark] class ResultTask[T, U](
out.write(bytes)
out.writeInt(partition)
out.writeInt(outputId)
- out.writeLong(generation)
+ out.writeLong(epoch)
out.writeObject(split)
}
}
@@ -131,7 +128,7 @@ private[spark] class ResultTask[T, U](
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
partition = in.readInt()
val outputId = in.readInt()
- generation = in.readLong()
+ epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
}
}
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index e3bb6d1e60..2dbaef24ac 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -88,18 +88,15 @@ private[spark] class ShuffleMapTask(
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
var partition: Int,
- @transient private var locs: Seq[String])
+ @transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId)
with Externalizable
with Logging {
protected def this() = this(0, null, null, 0, null)
- @transient private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.toSet.toSeq
-
- {
- // DEBUG code
- preferredLocs.foreach (hostPort => Utils.checkHost(Utils.parseHostPort(hostPort)._1, "preferredLocs : " + preferredLocs))
+ @transient private val preferredLocs: Seq[TaskLocation] = {
+ if (locs == null) Nil else locs.toSet.toSeq
}
var split = if (rdd == null) null else rdd.partitions(partition)
@@ -112,7 +109,7 @@ private[spark] class ShuffleMapTask(
out.writeInt(bytes.length)
out.write(bytes)
out.writeInt(partition)
- out.writeLong(generation)
+ out.writeLong(epoch)
out.writeObject(split)
}
}
@@ -126,7 +123,7 @@ private[spark] class ShuffleMapTask(
rdd = rdd_
dep = dep_
partition = in.readInt()
- generation = in.readLong()
+ epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
}
@@ -186,7 +183,7 @@ private[spark] class ShuffleMapTask(
}
}
- override def preferredLocations: Seq[String] = preferredLocs
+ override def preferredLocations: Seq[TaskLocation] = preferredLocs
override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
}
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index 50768d43e0..0ab2ae6cfe 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -30,9 +30,9 @@ import spark.executor.TaskMetrics
*/
private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
def run(attemptId: Long): T
- def preferredLocations: Seq[String] = Nil
+ def preferredLocations: Seq[TaskLocation] = Nil
- var generation: Long = -1 // Map output tracker generation. Will be set by TaskScheduler.
+ var epoch: Long = -1 // Map output tracker epoch. Will be set by TaskScheduler.
var metrics: Option[TaskMetrics] = None
diff --git a/core/src/main/scala/spark/scheduler/TaskLocation.scala b/core/src/main/scala/spark/scheduler/TaskLocation.scala
new file mode 100644
index 0000000000..0e97c61188
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/TaskLocation.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.scheduler
+
+/**
+ * A location where a task should run. This can either be a host or a (host, executorID) pair.
+ * In the latter case, we will prefer to launch the task on that executorID, but our next level
+ * of preference will be executors on the same host if this is not possible.
+ */
+private[spark]
+class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable
+
+private[spark] object TaskLocation {
+ def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId))
+
+ def apply(host: String) = new TaskLocation(host, None)
+}
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index 2cdeb1c8c0..64be50b2d0 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -35,7 +35,7 @@ private[spark] trait TaskSchedulerListener {
taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
// A node was added to the cluster.
- def executorGained(execId: String, hostPort: String): Unit
+ def executorGained(execId: String, host: String): Unit
// A node was lost from the cluster.
def executorLost(execId: String): Unit
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
index 96568e0d27..036e36bca0 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterScheduler.scala
@@ -37,18 +37,22 @@ import java.util.{TimerTask, Timer}
*/
private[spark] class ClusterScheduler(val sc: SparkContext)
extends TaskScheduler
- with Logging {
-
+ with Logging
+{
// How often to check for speculative tasks
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
+
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
+
// How often to revive offers in case there are pending tasks - that is how often to try to get
// tasks scheduled in case there are nodes available : default 0 is to disable it - to preserve existing behavior
- // Note that this is required due to delayed scheduling due to data locality waits, etc.
- // TODO: rename property ?
- val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong
+ // Note that this is required due to delay scheduling due to data locality waits, etc.
+ // TODO(matei): move to StandaloneSchedulerBackend?
+ val TASK_REVIVAL_INTERVAL = System.getProperty("spark.scheduler.revival.interval", "1000").toLong
+ // TODO(matei): replace this with something that only affects levels past PROCESS_LOCAL;
+ // basically it can be a "cliff" for locality
/*
This property controls how aggressive we should be to modulate waiting for node local task scheduling.
To elaborate, currently there is a time limit (3 sec def) to ensure that spark attempts to wait for node locality of tasks before
@@ -71,7 +75,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
If cluster is rack aware, then setting it to RACK_LOCAL gives best tradeoff and a 3x - 4x performance improvement while minimizing IO impact.
Also, it brings down the variance in running time drastically.
*/
- val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "NODE_LOCAL"))
+ val TASK_SCHEDULING_AGGRESSION = TaskLocality.withName(
+ System.getProperty("spark.tasks.schedule.aggression", "NODE_LOCAL"))
val activeTaskSets = new HashMap[String, TaskSetManager]
@@ -89,16 +94,11 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// Which executor IDs we have executors on
val activeExecutorIds = new HashSet[String]
- // TODO: We might want to remove this and merge it with execId datastructures - but later.
- // Which hosts in the cluster are alive (contains hostPort's) - used for process local and node local task locality.
- private val hostPortsAlive = new HashSet[String]
- private val hostToAliveHostPorts = new HashMap[String, HashSet[String]]
-
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
- private val executorsByHostPort = new HashMap[String, HashSet[String]]
+ private val executorsByHost = new HashMap[String, HashSet[String]]
- private val executorIdToHostPort = new HashMap[String, String]
+ private val executorIdToHost = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
@@ -138,7 +138,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
schedulableBuilder.buildPools()
// resolve executorId to hostPort mapping.
def executorToHostPort(executorId: String, defaultHostPort: String): String = {
- executorIdToHostPort.getOrElse(executorId, defaultHostPort)
+ executorIdToHost.getOrElse(executorId, defaultHostPort)
}
// Unfortunately, this means that SparkEnv is indirectly referencing ClusterScheduler
@@ -146,13 +146,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
SparkEnv.get.executorIdToHostPort = Some(executorToHostPort)
}
-
def newTaskId(): Long = nextTaskId.getAndIncrement()
override def start() {
backend.start()
- if (JBoolean.getBoolean("spark.speculation")) {
+ if (System.getProperty("spark.speculation", "false").toBoolean) {
new Thread("ClusterScheduler speculation check") {
setDaemon(true)
@@ -172,6 +171,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// Change to always run with some default if TASK_REVIVAL_INTERVAL <= 0 ?
+ // TODO(matei): remove this thread
if (TASK_REVIVAL_INTERVAL > 0) {
new Thread("ClusterScheduler task offer revival check") {
setDaemon(true)
@@ -201,7 +201,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
- if (hasReceivedTask == false) {
+ if (!hasReceivedTask) {
starvationTimer.scheduleAtFixedRate(new TimerTask() {
override def run() {
if (!hasLaunchedTask) {
@@ -214,7 +214,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}, STARVATION_TIMEOUT, STARVATION_TIMEOUT)
}
- hasReceivedTask = true;
+ hasReceivedTask = true
}
backend.reviveOffers()
}
@@ -235,172 +235,53 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
* sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
* that tasks are balanced across the cluster.
*/
- def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = {
- synchronized {
- SparkEnv.set(sc.env)
- // Mark each slave as alive and remember its hostname
- for (o <- offers) {
- // DEBUG Code
- Utils.checkHostPort(o.hostPort)
-
- executorIdToHostPort(o.executorId) = o.hostPort
- if (! executorsByHostPort.contains(o.hostPort)) {
- executorsByHostPort(o.hostPort) = new HashSet[String]()
- }
-
- hostPortsAlive += o.hostPort
- hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(o.hostPort)._1, new HashSet[String]).add(o.hostPort)
- executorGained(o.executorId, o.hostPort)
- }
- // Build a list of tasks to assign to each slave
- val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
- // merge availableCpus into nodeToAvailableCpus block ?
- val availableCpus = offers.map(o => o.cores).toArray
- val nodeToAvailableCpus = {
- val map = new HashMap[String, Int]()
- for (offer <- offers) {
- val hostPort = offer.hostPort
- val cores = offer.cores
- // DEBUG code
- Utils.checkHostPort(hostPort)
-
- val host = Utils.parseHostPort(hostPort)._1
-
- map.put(host, map.getOrElse(host, 0) + cores)
- }
-
- map
+ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized {
+ SparkEnv.set(sc.env)
+
+ // Mark each slave as alive and remember its hostname
+ for (o <- offers) {
+ executorIdToHost(o.executorId) = o.host
+ if (!executorsByHost.contains(o.host)) {
+ executorsByHost(o.host) = new HashSet[String]()
+ executorGained(o.executorId, o.host)
}
- var launchedTask = false
- val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
-
- for (manager <- sortedTaskSetQueue) {
- logDebug("parentName:%s, name:%s, runningTasks:%s".format(
- manager.parent.name, manager.name, manager.runningTasks))
- }
-
- for (manager <- sortedTaskSetQueue) {
+ }
- // Split offers based on node local, rack local and off-rack tasks.
- val processLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
- val nodeLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
- val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
- val otherOffers = new HashMap[String, ArrayBuffer[Int]]()
+ // Build a list of tasks to assign to each slave
+ val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
+ val availableCpus = offers.map(o => o.cores).toArray
+ val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
+ for (manager <- sortedTaskSetQueue) {
+ logDebug("parentName: %s, name: %s, runningTasks: %s".format(
+ manager.parent.name, manager.name, manager.runningTasks))
+ }
+ var launchedTask = false
+ for (manager <- sortedTaskSetQueue; offer <- offers) {
+ do {
+ launchedTask = false
for (i <- 0 until offers.size) {
- val hostPort = offers(i).hostPort
- // DEBUG code
- Utils.checkHostPort(hostPort)
-
- val numProcessLocalTasks = math.max(0, math.min(manager.numPendingTasksForHostPort(hostPort), availableCpus(i)))
- if (numProcessLocalTasks > 0){
- val list = processLocalOffers.getOrElseUpdate(hostPort, new ArrayBuffer[Int])
- for (j <- 0 until numProcessLocalTasks) list += i
+ val execId = offers(i).executorId
+ val host = offers(i).host
+ for (task <- manager.resourceOffer(execId, host, availableCpus(i))) {
+ tasks(i) += task
+ val tid = task.taskId
+ taskIdToTaskSetId(tid) = manager.taskSet.id
+ taskSetTaskIds(manager.taskSet.id) += tid
+ taskIdToExecutorId(tid) = execId
+ activeExecutorIds += execId
+ executorsByHost(host) += execId
+ availableCpus(i) -= 1
+ launchedTask = true
}
-
- val host = Utils.parseHostPort(hostPort)._1
- val numNodeLocalTasks = math.max(0,
- // Remove process local tasks (which are also host local btw !) from this
- math.min(manager.numPendingTasksForHost(hostPort) - numProcessLocalTasks, nodeToAvailableCpus(host)))
- if (numNodeLocalTasks > 0){
- val list = nodeLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
- for (j <- 0 until numNodeLocalTasks) list += i
- }
-
- val numRackLocalTasks = math.max(0,
- // Remove node local tasks (which are also rack local btw !) from this
- math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numProcessLocalTasks - numNodeLocalTasks, nodeToAvailableCpus(host)))
- if (numRackLocalTasks > 0){
- val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
- for (j <- 0 until numRackLocalTasks) list += i
- }
- if (numNodeLocalTasks <= 0 && numRackLocalTasks <= 0){
- // add to others list - spread even this across cluster.
- val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
- list += i
- }
- }
-
- val offersPriorityList = new ArrayBuffer[Int](
- processLocalOffers.size + nodeLocalOffers.size + rackLocalOffers.size + otherOffers.size)
-
- // First process local, then host local, then rack, then others
-
- // numNodeLocalOffers contains count of both process local and host offers.
- val numNodeLocalOffers = {
- val processLocalPriorityList = ClusterScheduler.prioritizeContainers(processLocalOffers)
- offersPriorityList ++= processLocalPriorityList
-
- val nodeLocalPriorityList = ClusterScheduler.prioritizeContainers(nodeLocalOffers)
- offersPriorityList ++= nodeLocalPriorityList
-
- processLocalPriorityList.size + nodeLocalPriorityList.size
- }
- val numRackLocalOffers = {
- val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers)
- offersPriorityList ++= rackLocalPriorityList
- rackLocalPriorityList.size
}
- offersPriorityList ++= ClusterScheduler.prioritizeContainers(otherOffers)
-
- var lastLoop = false
- val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match {
- case TaskLocality.NODE_LOCAL => numNodeLocalOffers
- case TaskLocality.RACK_LOCAL => numRackLocalOffers + numNodeLocalOffers
- case TaskLocality.ANY => offersPriorityList.size
- }
-
- do {
- launchedTask = false
- var loopCount = 0
- for (i <- offersPriorityList) {
- val execId = offers(i).executorId
- val hostPort = offers(i).hostPort
-
- // If last loop and within the lastLoopIndex, expand scope - else use null (which will use default/existing)
- val overrideLocality = if (lastLoop && loopCount < lastLoopIndex) TASK_SCHEDULING_AGGRESSION else null
-
- // If last loop, override waiting for host locality - we scheduled all local tasks already and there might be more available ...
- loopCount += 1
-
- manager.slaveOffer(execId, hostPort, availableCpus(i), overrideLocality) match {
- case Some(task) =>
- tasks(i) += task
- val tid = task.taskId
- taskIdToTaskSetId(tid) = manager.taskSet.id
- taskSetTaskIds(manager.taskSet.id) += tid
- taskIdToExecutorId(tid) = execId
- activeExecutorIds += execId
- executorsByHostPort(hostPort) += execId
- availableCpus(i) -= 1
- launchedTask = true
-
- case None => {}
- }
- }
- // Loop once more - when lastLoop = true, then we try to schedule task on all nodes irrespective of
- // data locality (we still go in order of priority : but that would not change anything since
- // if data local tasks had been available, we would have scheduled them already)
- if (lastLoop) {
- // prevent more looping
- launchedTask = false
- } else if (!lastLoop && !launchedTask) {
- // Do this only if TASK_SCHEDULING_AGGRESSION != NODE_LOCAL
- if (TASK_SCHEDULING_AGGRESSION != TaskLocality.NODE_LOCAL) {
- // fudge launchedTask to ensure we loop once more
- launchedTask = true
- // dont loop anymore
- lastLoop = true
- }
- }
- } while (launchedTask)
- }
+ } while (launchedTask)
+ }
- if (tasks.size > 0) {
- hasLaunchedTask = true
- }
- return tasks
+ if (tasks.size > 0) {
+ hasLaunchedTask = true
}
+ return tasks
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
@@ -514,7 +395,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
synchronized {
if (activeExecutorIds.contains(executorId)) {
- val hostPort = executorIdToHostPort(executorId)
+ val hostPort = executorIdToHost(executorId)
logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
removeExecutor(executorId)
failedExecutor = Some(executorId)
@@ -535,52 +416,37 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
/** Remove an executor from all our data structures and mark it as lost */
private def removeExecutor(executorId: String) {
+ // TODO(matei): remove HostPort
activeExecutorIds -= executorId
- val hostPort = executorIdToHostPort(executorId)
- if (hostPortsAlive.contains(hostPort)) {
- // DEBUG Code
- Utils.checkHostPort(hostPort)
+ val host = executorIdToHost(executorId)
- hostPortsAlive -= hostPort
- hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort)
- }
-
- val execs = executorsByHostPort.getOrElse(hostPort, new HashSet)
+ val execs = executorsByHost.getOrElse(host, new HashSet)
execs -= executorId
if (execs.isEmpty) {
- executorsByHostPort -= hostPort
+ executorsByHost -= host
}
- executorIdToHostPort -= executorId
- rootPool.executorLost(executorId, hostPort)
+ executorIdToHost -= executorId
+ rootPool.executorLost(executorId, host)
}
- def executorGained(execId: String, hostPort: String) {
- listener.executorGained(execId, hostPort)
+ def executorGained(execId: String, host: String) {
+ listener.executorGained(execId, host)
}
- def getExecutorsAliveOnHost(host: String): Option[Set[String]] = {
- Utils.checkHost(host)
-
- val retval = hostToAliveHostPorts.get(host)
- if (retval.isDefined) {
- return Some(retval.get.toSet)
- }
+ def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized {
+ executorsByHost.get(host).map(_.toSet)
+ }
- None
+ def hasExecutorsAliveOnHost(host: String): Boolean = synchronized {
+ executorsByHost.contains(host)
}
- def isExecutorAliveOnHostPort(hostPort: String): Boolean = {
- // Even if hostPort is a host, it does not matter - it is just a specific check.
- // But we do have to ensure that only hostPort get into hostPortsAlive !
- // So no check against Utils.checkHostPort
- hostPortsAlive.contains(hostPort)
+ def isExecutorAlive(execId: String): Boolean = synchronized {
+ activeExecutorIds.contains(execId)
}
// By default, rack is unknown
def getRackForHost(value: String): Option[String] = None
-
- // By default, (cached) hosts for rack is unknown
- def getCachedHostsForRack(rack: String): Option[Set[String]] = None
}
@@ -610,6 +476,7 @@ object ClusterScheduler {
// order keyList based on population of value in map
val keyList = _keyList.sortWith(
+ // TODO(matei): not sure why we're using getOrElse if keyList = map.keys... see if it matters
(left, right) => map.get(left).getOrElse(Set()).size > map.get(right).getOrElse(Set()).size
)
@@ -617,7 +484,7 @@ object ClusterScheduler {
var index = 0
var found = true
- while (found){
+ while (found) {
found = false
for (key <- keyList) {
val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
diff --git a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 7f855cd345..1947c516db 100644
--- a/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -29,49 +29,13 @@ import scala.math.min
import spark.{FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskState, Utils}
import spark.{ExceptionFailure, SparkException, TaskResultTooBigFailure}
import spark.TaskState.TaskState
-import spark.scheduler.{ShuffleMapTask, Task, TaskResult, TaskSet}
+import spark.scheduler._
+import scala.Some
+import spark.FetchFailed
+import spark.ExceptionFailure
+import spark.TaskResultTooBigFailure
-private[spark] object TaskLocality
- extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
-
- // process local is expected to be used ONLY within tasksetmanager for now.
- val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
-
- type TaskLocality = Value
-
- def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
-
- // Must not be the constraint.
- assert (constraint != TaskLocality.PROCESS_LOCAL)
-
- constraint match {
- case TaskLocality.NODE_LOCAL =>
- condition == TaskLocality.NODE_LOCAL
- case TaskLocality.RACK_LOCAL =>
- condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
- // For anything else, allow
- case _ => true
- }
- }
-
- def parse(str: String): TaskLocality = {
- // better way to do this ?
- try {
- val retval = TaskLocality.withName(str)
- // Must not specify PROCESS_LOCAL !
- assert (retval != TaskLocality.PROCESS_LOCAL)
- retval
- } catch {
- case nEx: NoSuchElementException => {
- logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL")
- // default to preserve earlier behavior
- NODE_LOCAL
- }
- }
- }
-}
-
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
@@ -113,28 +77,26 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
// Last time when we launched a preferred task (for delay scheduling)
var lastPreferredLaunchTime = System.currentTimeMillis
- // List of pending tasks for each node (process local to container).
- // These collections are actually
+ // Set of pending tasks for each executor. These collections are actually
// treated as stacks, in which new tasks are added to the end of the
// ArrayBuffer and removed from the end. This makes it faster to detect
// tasks that repeatedly fail because whenever a task failed, it is put
// back at the head of the stack. They are also only cleaned up lazily;
// when a task is launched, it remains in all the pending lists except
// the one that it was launched from, but gets removed from them later.
- private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
+ private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]]
- // List of pending tasks for each node.
- // Essentially, similar to pendingTasksForHostPort, except at host level
+ // Set of pending tasks for each host. Similar to pendingTasksForExecutor,
+ // but at host level.
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
- // List of pending tasks for each node based on rack locality.
- // Essentially, similar to pendingTasksForHost, except at rack level
- private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
+ // Set of pending tasks for each rack -- similar to the above.
+ private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
- // List containing pending tasks with no locality preferences
+ // Set containing pending tasks with no locality preferences.
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
- // List containing all pending tasks (also used as a stack, as above)
+ // Set containing all pending tasks (also used as a stack, as above).
val allPendingTasks = new ArrayBuffer[Int]
// Tasks that can be speculated. Since these will be a small fraction of total
@@ -144,13 +106,14 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
// Task index, start and finish time for each task attempt (indexed by task ID)
val taskInfos = new HashMap[Long, TaskInfo]
- // Did the job fail?
+ // Did the TaskSet fail?
var failed = false
var causeOfFailure = ""
// How frequently to reprint duplicate exceptions in full, in milliseconds
val EXCEPTION_PRINT_INTERVAL =
System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
+
// Map of recent exceptions (identified by string representation and
// top stack frame) to duplicate count (how many times the same
// exception has appeared) and time the full exception was
@@ -158,11 +121,11 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
// exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
- // Figure out the current map output tracker generation and set it on all tasks
- val generation = sched.mapOutputTracker.getGeneration
- logDebug("Generation for " + taskSet.id + ": " + generation)
+ // Figure out the current map output tracker epoch and set it on all tasks
+ val epoch = sched.mapOutputTracker.getEpoch
+ logDebug("Epoch for " + taskSet.id + ": " + epoch)
for (t <- tasks) {
- t.generation = generation
+ t.epoch = epoch
}
// Add all our tasks to the pending lists. We do this in reverse order
@@ -171,166 +134,74 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
addPendingTask(i)
}
- // Note that it follows the hierarchy.
- // if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and
- // if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL
- private def findPreferredLocations(
- _taskPreferredLocations: Seq[String],
- scheduler: ClusterScheduler,
- taskLocality: TaskLocality.TaskLocality): HashSet[String] =
- {
- if (TaskLocality.PROCESS_LOCAL == taskLocality) {
- // straight forward comparison ! Special case it.
- val retval = new HashSet[String]()
- scheduler.synchronized {
- for (location <- _taskPreferredLocations) {
- if (scheduler.isExecutorAliveOnHostPort(location)) {
- retval += location
- }
- }
+ /**
+ * Add a task to all the pending-task lists that it should be on. If readding is set, we are
+ * re-adding the task so only include it in each list if it's not already there.
+ */
+ private def addPendingTask(index: Int, readding: Boolean = false) {
+ // Utility method that adds `index` to a list only if readding=false or it's not already there
+ def addTo(list: ArrayBuffer[Int]) {
+ if (!readding || !list.contains(index)) {
+ list += index
}
-
- return retval
}
- val taskPreferredLocations = {
- if (TaskLocality.NODE_LOCAL == taskLocality) {
- _taskPreferredLocations
- } else {
- assert (TaskLocality.RACK_LOCAL == taskLocality)
- // Expand set to include all 'seen' rack local hosts.
- // This works since container allocation/management happens within master -
- // so any rack locality information is updated in msater.
- // Best case effort, and maybe sort of kludge for now ... rework it later ?
- val hosts = new HashSet[String]
- _taskPreferredLocations.foreach(h => {
- val rackOpt = scheduler.getRackForHost(h)
- if (rackOpt.isDefined) {
- val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
- if (hostsOpt.isDefined) {
- hosts ++= hostsOpt.get
- }
- }
-
- // Ensure that irrespective of what scheduler says, host is always added !
- hosts += h
- })
-
- hosts
+ var hadAliveLocations = false
+ for (loc <- tasks(index).preferredLocations) {
+ for (execId <- loc.executorId) {
+ if (sched.isExecutorAlive(execId)) {
+ addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+ hadAliveLocations = true
+ }
}
- }
-
- val retval = new HashSet[String]
- scheduler.synchronized {
- for (prefLocation <- taskPreferredLocations) {
- val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1)
- if (aliveLocationsOpt.isDefined) {
- retval ++= aliveLocationsOpt.get
+ if (sched.hasExecutorsAliveOnHost(loc.host)) {
+ addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+ for (rack <- sched.getRackForHost(loc.host)) {
+ addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
}
+ hadAliveLocations = true
}
}
- retval
- }
-
- // Add a task to all the pending-task lists that it should be on.
- private def addPendingTask(index: Int) {
- // We can infer hostLocalLocations from rackLocalLocations by joining it against
- // tasks(index).preferredLocations (with appropriate hostPort <-> host conversion).
- // But not doing it for simplicity sake. If this becomes a performance issue, modify it.
- val locs = tasks(index).preferredLocations
- val processLocalLocations = findPreferredLocations(locs, sched, TaskLocality.PROCESS_LOCAL)
- val hostLocalLocations = findPreferredLocations(locs, sched, TaskLocality.NODE_LOCAL)
- val rackLocalLocations = findPreferredLocations(locs, sched, TaskLocality.RACK_LOCAL)
-
- if (rackLocalLocations.size == 0) {
- // Current impl ensures this.
- assert (processLocalLocations.size == 0)
- assert (hostLocalLocations.size == 0)
- pendingTasksWithNoPrefs += index
- } else {
-
- // process local locality
- for (hostPort <- processLocalLocations) {
- // DEBUG Code
- Utils.checkHostPort(hostPort)
-
- val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
- hostPortList += index
- }
-
- // host locality (includes process local)
- for (hostPort <- hostLocalLocations) {
- // DEBUG Code
- Utils.checkHostPort(hostPort)
-
- val host = Utils.parseHostPort(hostPort)._1
- val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
- hostList += index
- }
-
- // rack locality (includes process local and host local)
- for (rackLocalHostPort <- rackLocalLocations) {
- // DEBUG Code
- Utils.checkHostPort(rackLocalHostPort)
-
- val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
- val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
- list += index
- }
+ if (!hadAliveLocations) {
+ // Even though the task might've had preferred locations, all of those hosts or executors
+ // are dead; put it in the no-prefs list so we can schedule it elsewhere right away.
+ addTo(pendingTasksWithNoPrefs)
}
- allPendingTasks += index
+ addTo(allPendingTasks)
}
- // Return the pending tasks list for a given host port (process local), or an empty list if
- // there is no map entry for that host
- private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
- // DEBUG Code
- Utils.checkHostPort(hostPort)
- pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
+ /**
+ * Return the pending tasks list for a given executor ID, or an empty list if
+ * there is no map entry for that host
+ */
+ private def getPendingTasksForExecutor(executorId: String): ArrayBuffer[Int] = {
+ pendingTasksForExecutor.getOrElse(executorId, ArrayBuffer())
}
- // Return the pending tasks list for a given host, or an empty list if
- // there is no map entry for that host
- private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
- val host = Utils.parseHostPort(hostPort)._1
+ /**
+ * Return the pending tasks list for a given host, or an empty list if
+ * there is no map entry for that host
+ */
+ private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
- // Return the pending tasks (rack level) list for a given host, or an empty list if
- // there is no map entry for that host
- private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
- val host = Utils.parseHostPort(hostPort)._1
- pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
- }
-
- // Number of pending tasks for a given host Port (which would be process local)
- override def numPendingTasksForHostPort(hostPort: String): Int = {
- getPendingTasksForHostPort(hostPort).count { index =>
- copiesRunning(index) == 0 && !finished(index)
- }
- }
-
- // Number of pending tasks for a given host (which would be data local)
- override def numPendingTasksForHost(hostPort: String): Int = {
- getPendingTasksForHost(hostPort).count { index =>
- copiesRunning(index) == 0 && !finished(index)
- }
- }
-
- // Number of pending rack local tasks for a given host
- override def numRackLocalPendingTasksForHost(hostPort: String): Int = {
- getRackLocalPendingTasksForHost(hostPort).count { index =>
- copiesRunning(index) == 0 && !finished(index)
- }
+ /**
+ * Return the pending rack-local task list for a given rack, or an empty list if
+ * there is no map entry for that rack
+ */
+ private def getPendingTasksForRack(rack: String): ArrayBuffer[Int] = {
+ pendingTasksForRack.getOrElse(rack, ArrayBuffer())
}
-
- // Dequeue a pending task from the given list and return its index.
- // Return None if the list is empty.
- // This method also cleans up any tasks in the list that have already
- // been launched, since we want that to happen lazily.
+ /**
+ * Dequeue a pending task from the given list and return its index.
+ * Return None if the list is empty.
+ * This method also cleans up any tasks in the list that have already
+ * been launched, since we want that to happen lazily.
+ */
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
@@ -342,176 +213,145 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
return None
}
- // Return a speculative task for a given host if any are available. The task should not have an
- // attempt running on this host, in case the host is slow. In addition, if locality is set, the
- // task must have a preference for this host/rack/no preferred locations at all.
- private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
+ /** Check whether a task is currently running an attempt on a given host */
+ private def hasAttemptOnHost(taskIndex: Int, host: String): Boolean = {
+ !taskAttempts(taskIndex).exists(_.host == host)
+ }
- assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
+ /**
+ * Return a speculative task for a given executor if any are available. The task should not have
+ * an attempt running on this host, in case the host is slow. In addition, the task should meet
+ * the given locality constraint.
+ */
+ private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+ : Option[(Int, TaskLocality.Value)] =
+ {
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
- if (speculatableTasks.size > 0) {
- val localTask = speculatableTasks.find { index =>
- val locations = findPreferredLocations(tasks(index).preferredLocations, sched,
- TaskLocality.NODE_LOCAL)
- val attemptLocs = taskAttempts(index).map(_.hostPort)
- (locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
+ if (!speculatableTasks.isEmpty) {
+ // Check for process-local or preference-less tasks; note that tasks can be process-local
+ // on multiple nodes when we replicate cached blocks, as in Spark Streaming
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val prefs = tasks(index).preferredLocations
+ val executors = prefs.flatMap(_.executorId)
+ if (prefs.size == 0 || executors.contains(execId)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.PROCESS_LOCAL))
+ }
}
- if (localTask != None) {
- speculatableTasks -= localTask.get
- return localTask
+ // Check for node-local tasks
+ if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val locations = tasks(index).preferredLocations.map(_.host)
+ if (locations.contains(host)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.NODE_LOCAL))
+ }
+ }
}
- // check for rack locality
+ // Check for rack-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
- val rackTask = speculatableTasks.find { index =>
- val locations = findPreferredLocations(tasks(index).preferredLocations, sched,
- TaskLocality.RACK_LOCAL)
- val attemptLocs = taskAttempts(index).map(_.hostPort)
- locations.contains(hostPort) && !attemptLocs.contains(hostPort)
- }
-
- if (rackTask != None) {
- speculatableTasks -= rackTask.get
- return rackTask
+ for (rack <- sched.getRackForHost(host)) {
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ val racks = tasks(index).preferredLocations.map(_.host).map(sched.getRackForHost)
+ if (racks.contains(rack)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.RACK_LOCAL))
+ }
+ }
}
}
- // Any task ...
+ // Check for non-local tasks
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- // Check for attemptLocs also ?
- val nonLocalTask = speculatableTasks.find { i =>
- !taskAttempts(i).map(_.hostPort).contains(hostPort)
- }
- if (nonLocalTask != None) {
- speculatableTasks -= nonLocalTask.get
- return nonLocalTask
+ for (index <- speculatableTasks if !hasAttemptOnHost(index, host)) {
+ speculatableTasks -= index
+ return Some((index, TaskLocality.ANY))
}
}
}
+
return None
}
- // Dequeue a pending task for a given node and return its index.
- // If localOnly is set to false, allow non-local tasks as well.
- private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
- val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
- if (processLocalTask != None) {
- return processLocalTask
+ /**
+ * Dequeue a pending task for a given node and return its index and locality level.
+ * Only search for tasks matching the given locality constraint.
+ */
+ private def findTask(execId: String, host: String, locality: TaskLocality.Value)
+ : Option[(Int, TaskLocality.Value)] =
+ {
+ for (index <- findTaskFromList(getPendingTasksForExecutor(execId))) {
+ return Some((index, TaskLocality.PROCESS_LOCAL))
}
- val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
- if (localTask != None) {
- return localTask
+ if (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL)) {
+ for (index <- findTaskFromList(getPendingTasksForHost(host))) {
+ return Some((index, TaskLocality.NODE_LOCAL))
+ }
}
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
- val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
- if (rackLocalTask != None) {
- return rackLocalTask
+ for {
+ rack <- sched.getRackForHost(host)
+ index <- findTaskFromList(getPendingTasksForRack(rack))
+ } {
+ return Some((index, TaskLocality.RACK_LOCAL))
}
}
- // Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to
- // failed tasks later rather than sooner.
- // TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
- val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
- if (noPrefTask != None) {
- return noPrefTask
+ // Look for no-pref tasks after rack-local tasks since they can run anywhere.
+ for (index <- findTaskFromList(pendingTasksWithNoPrefs)) {
+ return Some((index, TaskLocality.PROCESS_LOCAL))
}
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
- val nonLocalTask = findTaskFromList(allPendingTasks)
- if (nonLocalTask != None) {
- return nonLocalTask
+ for (index <- findTaskFromList(allPendingTasks)) {
+ return Some((index, TaskLocality.ANY))
}
}
// Finally, if all else has failed, find a speculative task
- return findSpeculativeTask(hostPort, locality)
- }
-
- private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
- Utils.checkHostPort(hostPort)
-
- val locs = task.preferredLocations
-
- locs.contains(hostPort)
- }
-
- private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
- val locs = task.preferredLocations
-
- // If no preference, consider it as host local
- if (locs.isEmpty) return true
-
- val host = Utils.parseHostPort(hostPort)._1
- locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
- }
-
- // Does a host count as a rack local preferred location for a task?
- // (assumes host is NOT preferred location).
- // This is true if either the task has preferred locations and this host is one, or it has
- // no preferred locations (in which we still count the launch as preferred).
- private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
-
- val locs = task.preferredLocations
-
- val preferredRacks = new HashSet[String]()
- for (preferredHost <- locs) {
- val rack = sched.getRackForHost(preferredHost)
- if (None != rack) preferredRacks += rack.get
- }
-
- if (preferredRacks.isEmpty) return false
-
- val hostRack = sched.getRackForHost(hostPort)
-
- return None != hostRack && preferredRacks.contains(hostRack.get)
+ return findSpeculativeTask(execId, host, locality)
}
- // Respond to an offer of a single slave from the scheduler by finding a task
- override def slaveOffer(
- execId: String,
- hostPort: String,
- availableCpus: Double,
- overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] =
+ /**
+ * Respond to an offer of a single slave from the scheduler by finding a task
+ */
+ override def resourceOffer(execId: String, host: String, availableCpus: Double)
+ : Option[TaskDescription] =
{
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
+ val curTime = System.currentTimeMillis
+
// If explicitly specified, use that
- val locality = if (overrideLocality != null) overrideLocality else {
+ val locality = {
// expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
- val time = System.currentTimeMillis
- if (time - lastPreferredLaunchTime < LOCALITY_WAIT) {
+ // TODO(matei): Multi-level delay scheduling
+ if (curTime - lastPreferredLaunchTime < LOCALITY_WAIT) {
TaskLocality.NODE_LOCAL
} else {
TaskLocality.ANY
}
}
- findTask(hostPort, locality) match {
- case Some(index) => {
- // Found a task; do some bookkeeping and return a Mesos task for it
+ findTask(execId, host, locality) match {
+ case Some((index, taskLocality)) => {
+ // Found a task; do some bookkeeping and return a task description
val task = tasks(index)
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
- val taskLocality =
- if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL
- else if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL
- else if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL
- else TaskLocality.ANY
- val prefStr = taskLocality.toString
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
- taskSet.id, index, taskId, execId, hostPort, prefStr))
+ taskSet.id, index, taskId, execId, host, taskLocality))
// Do various bookkeeping
copiesRunning(index) += 1
- val time = System.currentTimeMillis
- val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
+ val info = new TaskInfo(taskId, index, curTime, execId, host, taskLocality)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (taskLocality == TaskLocality.PROCESS_LOCAL || taskLocality == TaskLocality.NODE_LOCAL) {
- lastPreferredLaunchTime = time
+ lastPreferredLaunchTime = curTime
}
// Serialize and return the task
val startTime = System.currentTimeMillis
@@ -534,6 +374,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
return None
}
+ /** Called by cluster scheduler when one of our tasks changes state */
override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
SparkEnv.set(env)
state match {
@@ -566,7 +407,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
if (!finished(index)) {
tasksFinished += 1
logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
- tid, info.duration, info.hostPort, tasksFinished, numTasks))
+ tid, info.duration, info.host, tasksFinished, numTasks))
// Deserialize task result and pass it to the scheduler
try {
val result = ser.deserialize[TaskResult[_]](serializedData)
@@ -698,44 +539,33 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
}
}
- // TODO(xiajunluan): for now we just find Pool not TaskSetManager
- // we can extend this function in future if needed
override def getSchedulableByName(name: String): Schedulable = {
return null
}
- override def addSchedulable(schedulable:Schedulable) {
- // nothing
- }
+ override def addSchedulable(schedulable: Schedulable) {}
- override def removeSchedulable(schedulable:Schedulable) {
- // nothing
- }
+ override def removeSchedulable(schedulable: Schedulable) {}
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
- var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
+ var sortedTaskSetQueue = ArrayBuffer[TaskSetManager](this)
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
- override def executorLost(execId: String, hostPort: String) {
+ /** Called by cluster scheduler when an executor is lost so we can re-enqueue our tasks */
+ override def executorLost(execId: String, host: String) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
- // If some task has preferred locations only on hostname, and there are no more executors there,
- // put it in the no-prefs list to avoid the wait from delay scheduling
-
- // host local tasks - should we push this to rack local or no pref list ? For now, preserving
- // behavior and moving to no prefs list. Note, this was done due to impliations related to
- // 'waiting' for data local tasks, etc.
- // Note: NOT checking process local list - since host local list is super set of that. We need
- // to ad to no prefs only if there is no host local node for the task (not if there is no
- // process local node for the task)
- for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
- val newLocs = findPreferredLocations(
- tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
- if (newLocs.isEmpty) {
- pendingTasksWithNoPrefs += index
- }
+ // Re-enqueue pending tasks for this host based on the status of the cluster -- for example, a
+ // task that used to have locations on only this host might now go to the no-prefs list. Note
+ // that it's okay if we add a task to the same queue twice (if it had multiple preferred
+ // locations), because findTaskFromList will skip already-running tasks.
+ for (index <- getPendingTasksForExecutor(execId)) {
+ addPendingTask(index, readding=true)
+ }
+ for (index <- getPendingTasksForHost(host)) {
+ addPendingTask(index, readding=true)
}
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
@@ -789,7 +619,7 @@ private[spark] class ClusterTaskSetManager(sched: ClusterScheduler, val taskSet:
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
- taskSet.id, index, info.hostPort, threshold))
+ taskSet.id, index, info.host, threshold))
speculatableTasks += index
foundTasks = true
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
index 075a7cbf7e..3b49af1258 100644
--- a/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/StandaloneSchedulerBackend.scala
@@ -37,7 +37,9 @@ import spark.scheduler.cluster.StandaloneClusterMessages._
*/
private[spark]
class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: ActorSystem)
- extends SchedulerBackend with Logging {
+ extends SchedulerBackend with Logging
+{
+ // TODO(matei): periodically revive offers as in MesosScheduler
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)
@@ -45,7 +47,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
private val executorActor = new HashMap[String, ActorRef]
private val executorAddress = new HashMap[String, Address]
- private val executorHostPort = new HashMap[String, String]
+ private val executorHost = new HashMap[String, String]
private val freeCores = new HashMap[String, Int]
private val actorToExecutorId = new HashMap[ActorRef, String]
private val addressToExecutorId = new HashMap[Address, String]
@@ -65,7 +67,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
sender ! RegisteredExecutor(sparkProperties)
context.watch(sender)
executorActor(executorId) = sender
- executorHostPort(executorId) = hostPort
+ executorHost(executorId) = Utils.parseHostPort(hostPort)._1
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
actorToExecutorId(sender) = executorId
@@ -105,13 +107,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Make fake resource offers on all executors
def makeOffers() {
launchTasks(scheduler.resourceOffers(
- executorHostPort.toArray.map {case (id, hostPort) => new WorkerOffer(id, hostPort, freeCores(id))}))
+ executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
}
// Make fake resource offers on just one executor
def makeOffers(executorId: String) {
launchTasks(scheduler.resourceOffers(
- Seq(new WorkerOffer(executorId, executorHostPort(executorId), freeCores(executorId)))))
+ Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
}
// Launch tasks returned by a set of resource offers
@@ -130,9 +132,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
actorToExecutorId -= executorActor(executorId)
addressToExecutorId -= executorAddress(executorId)
executorActor -= executorId
- executorHostPort -= executorId
+ executorHost -= executorId
freeCores -= executorId
- executorHostPort -= executorId
totalCoreCount.addAndGet(-numCores)
scheduler.executorLost(executorId, SlaveLost(reason))
}
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index c693b722ac..c2c5522686 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -28,11 +28,9 @@ class TaskInfo(
val index: Int,
val launchTime: Long,
val executorId: String,
- val hostPort: String,
+ val host: String,
val taskLocality: TaskLocality.TaskLocality) {
- Utils.checkHostPort(hostPort, "Expected hostport")
-
var finishTime: Long = 0
var failed = false
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala b/core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala
new file mode 100644
index 0000000000..1c33e41f87
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskLocality.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.scheduler.cluster
+
+
+private[spark] object TaskLocality
+ extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY")
+{
+ // process local is expected to be used ONLY within tasksetmanager for now.
+ val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
+
+ type TaskLocality = Value
+
+ def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
+ condition <= constraint
+ }
+}
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 1a92a5ed6f..277654edc0 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -29,17 +29,8 @@ private[spark] trait TaskSetManager extends Schedulable {
def taskSet: TaskSet
- def slaveOffer(
- execId: String,
- hostPort: String,
- availableCpus: Double,
- overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription]
-
- def numPendingTasksForHostPort(hostPort: String): Int
-
- def numRackLocalPendingTasksForHost(hostPort: String): Int
-
- def numPendingTasksForHost(hostPort: String): Int
+ def resourceOffer(execId: String, hostPort: String, availableCpus: Double)
+ : Option[TaskDescription]
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
diff --git a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
index 06d1203f70..1d09bd9b03 100644
--- a/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/WorkerOffer.scala
@@ -21,5 +21,4 @@ package spark.scheduler.cluster
* Represents free resources available on an executor.
*/
private[spark]
-class WorkerOffer(val executorId: String, val hostPort: String, val cores: Int) {
-}
+class WorkerOffer(val executorId: String, val host: String, val cores: Int)
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 6c43928bc8..a4f5f46777 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -141,7 +141,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
for (manager <- sortedTaskSetQueue) {
do {
launchTask = false
- manager.slaveOffer(null, null, freeCpuCores) match {
+ // TODO(matei): don't pass null here?
+ manager.resourceOffer(null, null, freeCpuCores) match {
case Some(task) =>
tasks += task
taskIdToTaskSetId(task.taskId) = manager.taskSet.id
diff --git a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
index c38eeb9e11..698c777bec 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalTaskSetManager.scala
@@ -98,14 +98,11 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
return None
}
- override def slaveOffer(
- execId: String,
- hostPort: String,
- availableCpus: Double,
- overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] =
+ override def resourceOffer(execId: String, host: String, availableCpus: Double)
+ : Option[TaskDescription] =
{
SparkEnv.set(sched.env)
- logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(
+ logDebug("availableCpus:%d, numFinished:%d, numTasks:%d".format(
availableCpus.toInt, numFinished, numTasks))
if (availableCpus > 0 && numFinished < numTasks) {
findTask() match {
@@ -131,18 +128,6 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
return None
}
- override def numPendingTasksForHostPort(hostPort: String): Int = {
- return 0
- }
-
- override def numRackLocalPendingTasksForHost(hostPort :String): Int = {
- return 0
- }
-
- override def numPendingTasksForHost(hostPort: String): Int = {
- return 0
- }
-
override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
SparkEnv.set(env)
state match {
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 3a72474419..2a6ec2a55d 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -1004,43 +1004,43 @@ private[spark] object BlockManager extends Logging {
}
}
- def blockIdsToExecutorLocations(blockIds: Array[String], env: SparkEnv, blockManagerMaster: BlockManagerMaster = null): HashMap[String, List[String]] = {
+ def blockIdsToBlockManagers(
+ blockIds: Array[String],
+ env: SparkEnv,
+ blockManagerMaster: BlockManagerMaster = null)
+ : Map[String, Seq[BlockManagerId]] =
+ {
// env == null and blockManagerMaster != null is used in tests
assert (env != null || blockManagerMaster != null)
- val locationBlockIds: Seq[Seq[BlockManagerId]] =
- if (env != null) {
- env.blockManager.getLocationBlockIds(blockIds)
- } else {
- blockManagerMaster.getLocations(blockIds)
- }
+ val blockLocations: Seq[Seq[BlockManagerId]] = if (env != null) {
+ env.blockManager.getLocationBlockIds(blockIds)
+ } else {
+ blockManagerMaster.getLocations(blockIds)
+ }
- // Convert from block master locations to executor locations (we need that for task scheduling)
- val executorLocations = new HashMap[String, List[String]]()
+ val blockManagers = new HashMap[String, Seq[BlockManagerId]]
for (i <- 0 until blockIds.length) {
- val blockId = blockIds(i)
- val blockLocations = locationBlockIds(i)
-
- val executors = new HashSet[String]()
-
- if (env != null) {
- for (bkLocation <- blockLocations) {
- val executorHostPort = env.resolveExecutorIdToHostPort(bkLocation.executorId, bkLocation.host)
- executors += executorHostPort
- // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
- }
- } else {
- // Typically while testing, etc - revert to simply using host.
- for (bkLocation <- blockLocations) {
- executors += bkLocation.host
- // logInfo("bkLocation = " + bkLocation + ", executorHostPort = " + executorHostPort)
- }
- }
-
- executorLocations.put(blockId, executors.toSeq.toList)
+ blockManagers(blockIds(i)) = blockLocations(i)
}
+ blockManagers.toMap
+ }
- executorLocations
+ def blockIdsToExecutorIds(
+ blockIds: Array[String],
+ env: SparkEnv,
+ blockManagerMaster: BlockManagerMaster = null)
+ : Map[String, Seq[String]] =
+ {
+ blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.executorId))
}
+ def blockIdsToHosts(
+ blockIds: Array[String],
+ env: SparkEnv,
+ blockManagerMaster: BlockManagerMaster = null)
+ : Map[String, Seq[String]] =
+ {
+ blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
+ }
}
diff --git a/core/src/main/scala/spark/ui/jobs/StagePage.scala b/core/src/main/scala/spark/ui/jobs/StagePage.scala
index 797513f266..6948ea4dd9 100644
--- a/core/src/main/scala/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/spark/ui/jobs/StagePage.scala
@@ -156,7 +156,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
<td>{info.taskId}</td>
<td>{info.status}</td>
<td>{info.taskLocality}</td>
- <td>{info.hostPort}</td>
+ <td>{info.host}</td>
<td>{dateFmt.format(new Date(info.launchTime))}</td>
<td sorttable_customkey={duration.toString}>
{formatDuration}
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index ce6cec0451..c21f3331d0 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -112,22 +112,22 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
"akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
masterTracker.registerShuffle(10, 1)
- masterTracker.incrementGeneration()
- slaveTracker.updateGeneration(masterTracker.getGeneration)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
masterTracker.registerMapOutput(10, 0, new MapStatus(
BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
- masterTracker.incrementGeneration()
- slaveTracker.updateGeneration(masterTracker.getGeneration)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
- masterTracker.incrementGeneration()
- slaveTracker.updateGeneration(masterTracker.getGeneration)
+ masterTracker.incrementEpoch()
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
// failure should be cached
diff --git a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
index 05afcd6567..6327155157 100644
--- a/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/ClusterSchedulerSuite.scala
@@ -72,7 +72,9 @@ class DummyTaskSetManager(
override def executorLost(executorId: String, host: String): Unit = {
}
- override def slaveOffer(execId: String, host: String, avaiableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
+ override def resourceOffer(execId: String, host: String, availableCpus: Double)
+ : Option[TaskDescription] =
+ {
if (tasksFinished + runningTasks < numTasks) {
increaseRunningTasks(1)
return Some(new TaskDescription(0, execId, "task 0:0", null))
@@ -118,7 +120,7 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
}
for (taskSet <- taskSetQueue) {
- taskSet.slaveOffer("execId_1", "hostname_1", 1) match {
+ taskSet.resourceOffer("execId_1", "hostname_1", 1) match {
case Some(task) =>
return taskSet.stageId
case None => {}
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index caaf3209fd..3b4a0d52fc 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -59,7 +59,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def stop() = {}
override def submitTasks(taskSet: TaskSet) = {
// normally done by TaskSetManager
- taskSet.tasks.foreach(_.generation = mapOutputTracker.getGeneration)
+ taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
taskSets += taskSet
}
override def setListener(listener: TaskSchedulerListener) = {}
@@ -299,10 +299,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val reduceRdd = makeRdd(2, List(shuffleDep))
submit(reduceRdd, Array(0, 1))
// pretend we were told hostA went away
- val oldGeneration = mapOutputTracker.getGeneration
+ val oldEpoch = mapOutputTracker.getEpoch
runEvent(ExecutorLost("exec-hostA"))
- val newGeneration = mapOutputTracker.getGeneration
- assert(newGeneration > oldGeneration)
+ val newEpoch = mapOutputTracker.getEpoch
+ assert(newEpoch > oldEpoch)
val noAccum = Map[Long, Any]()
val taskSet = taskSets(0)
// should be ignored for being too old
@@ -311,8 +311,8 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null))
// should be ignored for being too old
runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
- // should work because it's a new generation
- taskSet.tasks(1).generation = newGeneration
+ // should work because it's a new epoch
+ taskSet.tasks(1).epoch = newEpoch
runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
@@ -401,12 +401,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
assert(results === Map(0 -> 42))
}
- /** Assert that the supplied TaskSet has exactly the given preferredLocations. Note, converts taskSet's locations to host only. */
- private def assertLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
- assert(locations.size === taskSet.tasks.size)
- for ((expectLocs, taskLocs) <-
- taskSet.tasks.map(_.preferredLocations).zip(locations)) {
- assert(expectLocs.map(loc => spark.Utils.parseHostPort(loc)._1) === taskLocs)
+ /**
+ * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
+ * Note that this checks only the host and not the executor ID.
+ */
+ private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) {
+ assert(hosts.size === taskSet.tasks.size)
+ for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) {
+ assert(taskLocs.map(_.host) === expectedLocs)
}
}