aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorjerryshao <saisai.shao@intel.com>2015-07-27 15:46:35 -0700
committerSandy Ryza <sandy@cloudera.com>2015-07-27 15:46:35 -0700
commitab625956616664c2b4861781a578311da75a9ae4 (patch)
tree314bc1f84b5a5794676c362b822414e47d190da2 /core
parent2104931d7d726eda2c098e0f403c7f1533df8746 (diff)
downloadspark-ab625956616664c2b4861781a578311da75a9ae4.tar.gz
spark-ab625956616664c2b4861781a578311da75a9ae4.tar.bz2
spark-ab625956616664c2b4861781a578311da75a9ae4.zip
[SPARK-4352] [YARN] [WIP] Incorporate locality preferences in dynamic allocation requests
Currently there's no locality preference for container request in YARN mode, this will affect the performance if fetching data remotely, so here proposed to add locality in Yarn dynamic allocation mode. Ping sryza, please help to review, thanks a lot. Author: jerryshao <saisai.shao@intel.com> Closes #6394 from jerryshao/SPARK-4352 and squashes the following commits: d45fecb [jerryshao] Add documents 6c3fe5c [jerryshao] Fix bug 8db6c0e [jerryshao] Further address the comments 2e2b2cb [jerryshao] Fix rebase compiling problem ce5f096 [jerryshao] Fix style issue 7f7df95 [jerryshao] Fix rebase issue 9ca9e07 [jerryshao] Code refactor according to comments d3e4236 [jerryshao] Further address the comments 5e7a593 [jerryshao] Fix bug introduced code rebase 9ca7783 [jerryshao] Style changes 08317f9 [jerryshao] code and comment refines 65b2423 [jerryshao] Further address the comments a27c587 [jerryshao] address the comment 27faabc [jerryshao] redundant code remove 9ce06a1 [jerryshao] refactor the code f5ba27b [jerryshao] Style fix 2c6cc8a [jerryshao] Fix bug and add unit tests 0757335 [jerryshao] Consider the distribution of existed containers to recalculate the new container requests 0ad66ff [jerryshao] Fix compile bugs 1c20381 [jerryshao] Minor fix 5ef2dc8 [jerryshao] Add docs and improve the code 3359814 [jerryshao] Fix rebase and test bugs 0398539 [jerryshao] reinitialize the new implementation 67596d6 [jerryshao] Still fix the code 654e1d2 [jerryshao] Fix some bugs 45b1c89 [jerryshao] Further polish the algorithm dea0152 [jerryshao] Enable node locality information in YarnAllocator 74bbcc6 [jerryshao] Support node locality for dynamic allocation initial commit
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala18
-rw-r--r--core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala62
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala25
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Stage.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala32
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala3
-rw-r--r--core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala55
-rw-r--r--core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala7
11 files changed, 224 insertions, 30 deletions
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
index 443830f8d0..842bfdbadc 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -24,11 +24,23 @@ package org.apache.spark
private[spark] trait ExecutorAllocationClient {
/**
- * Express a preference to the cluster manager for a given total number of executors.
- * This can result in canceling pending requests or filing additional requests.
+ * Update the cluster manager on our scheduling needs. Three bits of information are included
+ * to help it make decisions.
+ * @param numExecutors The total number of executors we'd like to have. The cluster manager
+ * shouldn't kill any running executor to reach this number, but,
+ * if all existing executors were to die, this is the number of executors
+ * we'd want to be allocated.
+ * @param localityAwareTasks The number of tasks in all active stages that have a locality
+ * preferences. This includes running, pending, and completed tasks.
+ * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages
+ * that would like to like to run on that host.
+ * This includes running, pending, and completed tasks.
* @return whether the request is acknowledged by the cluster manager.
*/
- private[spark] def requestTotalExecutors(numExecutors: Int): Boolean
+ private[spark] def requestTotalExecutors(
+ numExecutors: Int,
+ localityAwareTasks: Int,
+ hostToLocalTaskCount: Map[String, Int]): Boolean
/**
* Request an additional number of executors from the cluster manager.
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 648bcfe28c..1877aaf2ca 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -161,6 +161,12 @@ private[spark] class ExecutorAllocationManager(
// (2) an executor idle timeout has elapsed.
@volatile private var initializing: Boolean = true
+ // Number of locality aware tasks, used for executor placement.
+ private var localityAwareTasks = 0
+
+ // Host to possible task running on it, used for executor placement.
+ private var hostToLocalTaskCount: Map[String, Int] = Map.empty
+
/**
* Verify that the settings specified through the config are valid.
* If not, throw an appropriate exception.
@@ -295,7 +301,7 @@ private[spark] class ExecutorAllocationManager(
// If the new target has not changed, avoid sending a message to the cluster manager
if (numExecutorsTarget < oldNumExecutorsTarget) {
- client.requestTotalExecutors(numExecutorsTarget)
+ client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount)
logDebug(s"Lowering target number of executors to $numExecutorsTarget (previously " +
s"$oldNumExecutorsTarget) because not all requested executors are actually needed")
}
@@ -349,7 +355,8 @@ private[spark] class ExecutorAllocationManager(
return 0
}
- val addRequestAcknowledged = testing || client.requestTotalExecutors(numExecutorsTarget)
+ val addRequestAcknowledged = testing ||
+ client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount)
if (addRequestAcknowledged) {
val executorsString = "executor" + { if (delta > 1) "s" else "" }
logInfo(s"Requesting $delta new $executorsString because tasks are backlogged" +
@@ -519,6 +526,12 @@ private[spark] class ExecutorAllocationManager(
// Number of tasks currently running on the cluster. Should be 0 when no stages are active.
private var numRunningTasks: Int = _
+ // stageId to tuple (the number of task with locality preferences, a map where each pair is a
+ // node and the number of tasks that would like to be scheduled on that node) map,
+ // maintain the executor placement hints for each stage Id used by resource framework to better
+ // place the executors.
+ private val stageIdToExecutorPlacementHints = new mutable.HashMap[Int, (Int, Map[String, Int])]
+
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
initializing = false
val stageId = stageSubmitted.stageInfo.stageId
@@ -526,6 +539,24 @@ private[spark] class ExecutorAllocationManager(
allocationManager.synchronized {
stageIdToNumTasks(stageId) = numTasks
allocationManager.onSchedulerBacklogged()
+
+ // Compute the number of tasks requested by the stage on each host
+ var numTasksPending = 0
+ val hostToLocalTaskCountPerStage = new mutable.HashMap[String, Int]()
+ stageSubmitted.stageInfo.taskLocalityPreferences.foreach { locality =>
+ if (!locality.isEmpty) {
+ numTasksPending += 1
+ locality.foreach { location =>
+ val count = hostToLocalTaskCountPerStage.getOrElse(location.host, 0) + 1
+ hostToLocalTaskCountPerStage(location.host) = count
+ }
+ }
+ }
+ stageIdToExecutorPlacementHints.put(stageId,
+ (numTasksPending, hostToLocalTaskCountPerStage.toMap))
+
+ // Update the executor placement hints
+ updateExecutorPlacementHints()
}
}
@@ -534,6 +565,10 @@ private[spark] class ExecutorAllocationManager(
allocationManager.synchronized {
stageIdToNumTasks -= stageId
stageIdToTaskIndices -= stageId
+ stageIdToExecutorPlacementHints -= stageId
+
+ // Update the executor placement hints
+ updateExecutorPlacementHints()
// If this is the last stage with pending tasks, mark the scheduler queue as empty
// This is needed in case the stage is aborted for any reason
@@ -637,6 +672,29 @@ private[spark] class ExecutorAllocationManager(
def isExecutorIdle(executorId: String): Boolean = {
!executorIdToTaskIds.contains(executorId)
}
+
+ /**
+ * Update the Executor placement hints (the number of tasks with locality preferences,
+ * a map where each pair is a node and the number of tasks that would like to be scheduled
+ * on that node).
+ *
+ * These hints are updated when stages arrive and complete, so are not up-to-date at task
+ * granularity within stages.
+ */
+ def updateExecutorPlacementHints(): Unit = {
+ var localityAwareTasks = 0
+ val localityToCount = new mutable.HashMap[String, Int]()
+ stageIdToExecutorPlacementHints.values.foreach { case (numTasksPending, localities) =>
+ localityAwareTasks += numTasksPending
+ localities.foreach { case (hostname, count) =>
+ val updatedCount = localityToCount.getOrElse(hostname, 0) + count
+ localityToCount(hostname) = updatedCount
+ }
+ }
+
+ allocationManager.localityAwareTasks = localityAwareTasks
+ allocationManager.hostToLocalTaskCount = localityToCount.toMap
+ }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6a6b94a271..ac6ac6c216 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1382,16 +1382,29 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/**
- * Express a preference to the cluster manager for a given total number of executors.
- * This can result in canceling pending requests or filing additional requests.
- * This is currently only supported in YARN mode. Return whether the request is received.
- */
- private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = {
+ * Update the cluster manager on our scheduling needs. Three bits of information are included
+ * to help it make decisions.
+ * @param numExecutors The total number of executors we'd like to have. The cluster manager
+ * shouldn't kill any running executor to reach this number, but,
+ * if all existing executors were to die, this is the number of executors
+ * we'd want to be allocated.
+ * @param localityAwareTasks The number of tasks in all active stages that have a locality
+ * preferences. This includes running, pending, and completed tasks.
+ * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages
+ * that would like to like to run on that host.
+ * This includes running, pending, and completed tasks.
+ * @return whether the request is acknowledged by the cluster manager.
+ */
+ private[spark] override def requestTotalExecutors(
+ numExecutors: Int,
+ localityAwareTasks: Int,
+ hostToLocalTaskCount: scala.collection.immutable.Map[String, Int]
+ ): Boolean = {
assert(supportDynamicAllocation,
"Requesting executors is currently only supported in YARN and Mesos modes")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
- b.requestTotalExecutors(numExecutors)
+ b.requestTotalExecutors(numExecutors, localityAwareTasks, hostToLocalTaskCount)
case _ =>
logWarning("Requesting executors is only supported in coarse-grained mode")
false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b6a833bbb0..cdf6078421 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -790,8 +790,28 @@ class DAGScheduler(
// serializable. If tasks are not serializable, a SparkListenerStageCompleted event
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
- stage.makeNewStageAttempt(partitionsToCompute.size)
outputCommitCoordinator.stageStart(stage.id)
+ val taskIdToLocations = try {
+ stage match {
+ case s: ShuffleMapStage =>
+ partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
+ case s: ResultStage =>
+ val job = s.resultOfJob.get
+ partitionsToCompute.map { id =>
+ val p = job.partitions(id)
+ (id, getPreferredLocs(stage.rdd, p))
+ }.toMap
+ }
+ } catch {
+ case NonFatal(e) =>
+ stage.makeNewStageAttempt(partitionsToCompute.size)
+ listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
+ abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+ runningStages -= stage
+ return
+ }
+
+ stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
@@ -830,7 +850,7 @@ class DAGScheduler(
stage match {
case stage: ShuffleMapStage =>
partitionsToCompute.map { id =>
- val locs = getPreferredLocs(stage.rdd, id)
+ val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
}
@@ -840,7 +860,7 @@ class DAGScheduler(
partitionsToCompute.map { id =>
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
- val locs = getPreferredLocs(stage.rdd, p)
+ val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index b86724de2c..40a333a3e0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -77,8 +77,11 @@ private[spark] abstract class Stage(
private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId)
/** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
- def makeNewStageAttempt(numPartitionsToCompute: Int): Unit = {
- _latestInfo = StageInfo.fromStage(this, nextAttemptId, Some(numPartitionsToCompute))
+ def makeNewStageAttempt(
+ numPartitionsToCompute: Int,
+ taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = {
+ _latestInfo = StageInfo.fromStage(
+ this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences)
nextAttemptId += 1
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 5d2abbc67e..24796c1430 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -34,7 +34,8 @@ class StageInfo(
val numTasks: Int,
val rddInfos: Seq[RDDInfo],
val parentIds: Seq[Int],
- val details: String) {
+ val details: String,
+ private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) {
/** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
var submissionTime: Option[Long] = None
/** Time when all tasks in the stage completed or when the stage was cancelled. */
@@ -70,7 +71,12 @@ private[spark] object StageInfo {
* shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a
* sequence of narrow dependencies should also be associated with this Stage.
*/
- def fromStage(stage: Stage, attemptId: Int, numTasks: Option[Int] = None): StageInfo = {
+ def fromStage(
+ stage: Stage,
+ attemptId: Int,
+ numTasks: Option[Int] = None,
+ taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty
+ ): StageInfo = {
val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd)
val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos
new StageInfo(
@@ -80,6 +86,7 @@ private[spark] object StageInfo {
numTasks.getOrElse(stage.numTasks),
rddInfos,
stage.parents.map(_.id),
- stage.details)
+ stage.details,
+ taskLocalityPreferences)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 4be1eda2e9..06f5438433 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -86,7 +86,11 @@ private[spark] object CoarseGrainedClusterMessages {
// Request executors by specifying the new total number of executors desired
// This includes executors already pending or running
- case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage
+ case class RequestExecutors(
+ requestedTotal: Int,
+ localityAwareTasks: Int,
+ hostToLocalTaskCount: Map[String, Int])
+ extends CoarseGrainedClusterMessage
case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index c65b3e5177..660702f6e6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -66,6 +66,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// Executors we have requested the cluster manager to kill that have not died yet
private val executorsPendingToRemove = new HashSet[String]
+ // A map to store hostname with its possible task number running on it
+ protected var hostToLocalTaskCount: Map[String, Int] = Map.empty
+
+ // The number of pending tasks which is locality required
+ protected var localityAwareTasks = 0
+
class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
extends ThreadSafeRpcEndpoint with Logging {
@@ -339,6 +345,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager")
logDebug(s"Number of pending executors is now $numPendingExecutors")
+
numPendingExecutors += numAdditionalExecutors
// Account for executors pending to be added or removed
val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
@@ -346,16 +353,33 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
/**
- * Express a preference to the cluster manager for a given total number of executors. This can
- * result in canceling pending requests or filing additional requests.
- * @return whether the request is acknowledged.
+ * Update the cluster manager on our scheduling needs. Three bits of information are included
+ * to help it make decisions.
+ * @param numExecutors The total number of executors we'd like to have. The cluster manager
+ * shouldn't kill any running executor to reach this number, but,
+ * if all existing executors were to die, this is the number of executors
+ * we'd want to be allocated.
+ * @param localityAwareTasks The number of tasks in all active stages that have a locality
+ * preferences. This includes running, pending, and completed tasks.
+ * @param hostToLocalTaskCount A map of hosts to the number of tasks from all active stages
+ * that would like to like to run on that host.
+ * This includes running, pending, and completed tasks.
+ * @return whether the request is acknowledged by the cluster manager.
*/
- final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized {
+ final override def requestTotalExecutors(
+ numExecutors: Int,
+ localityAwareTasks: Int,
+ hostToLocalTaskCount: Map[String, Int]
+ ): Boolean = synchronized {
if (numExecutors < 0) {
throw new IllegalArgumentException(
"Attempted to request a negative number of executor(s) " +
s"$numExecutors from the cluster manager. Please specify a positive number!")
}
+
+ this.localityAwareTasks = localityAwareTasks
+ this.hostToLocalTaskCount = hostToLocalTaskCount
+
numPendingExecutors =
math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0)
doRequestTotalExecutors(numExecutors)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index bc67abb5df..074282d1be 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -53,7 +53,8 @@ private[spark] abstract class YarnSchedulerBackend(
* This includes executors already pending or running.
*/
override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
- yarnSchedulerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal))
+ yarnSchedulerEndpoint.askWithRetry[Boolean](
+ RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount))
}
/**
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index 803e1831bb..34caca8928 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -751,6 +751,42 @@ class ExecutorAllocationManagerSuite
assert(numExecutorsTarget(manager) === 2)
}
+ test("get pending task number and related locality preference") {
+ sc = createSparkContext(2, 5, 3)
+ val manager = sc.executorAllocationManager.get
+
+ val localityPreferences1 = Seq(
+ Seq(TaskLocation("host1"), TaskLocation("host2"), TaskLocation("host3")),
+ Seq(TaskLocation("host1"), TaskLocation("host2"), TaskLocation("host4")),
+ Seq(TaskLocation("host2"), TaskLocation("host3"), TaskLocation("host4")),
+ Seq.empty,
+ Seq.empty
+ )
+ val stageInfo1 = createStageInfo(1, 5, localityPreferences1)
+ sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo1))
+
+ assert(localityAwareTasks(manager) === 3)
+ assert(hostToLocalTaskCount(manager) ===
+ Map("host1" -> 2, "host2" -> 3, "host3" -> 2, "host4" -> 2))
+
+ val localityPreferences2 = Seq(
+ Seq(TaskLocation("host2"), TaskLocation("host3"), TaskLocation("host5")),
+ Seq(TaskLocation("host3"), TaskLocation("host4"), TaskLocation("host5")),
+ Seq.empty
+ )
+ val stageInfo2 = createStageInfo(2, 3, localityPreferences2)
+ sc.listenerBus.postToAll(SparkListenerStageSubmitted(stageInfo2))
+
+ assert(localityAwareTasks(manager) === 5)
+ assert(hostToLocalTaskCount(manager) ===
+ Map("host1" -> 2, "host2" -> 4, "host3" -> 4, "host4" -> 3, "host5" -> 2))
+
+ sc.listenerBus.postToAll(SparkListenerStageCompleted(stageInfo1))
+ assert(localityAwareTasks(manager) === 2)
+ assert(hostToLocalTaskCount(manager) ===
+ Map("host2" -> 1, "host3" -> 2, "host4" -> 1, "host5" -> 2))
+ }
+
private def createSparkContext(
minExecutors: Int = 1,
maxExecutors: Int = 5,
@@ -784,8 +820,13 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
private val sustainedSchedulerBacklogTimeout = 2L
private val executorIdleTimeout = 3L
- private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = {
- new StageInfo(stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details")
+ private def createStageInfo(
+ stageId: Int,
+ numTasks: Int,
+ taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty
+ ): StageInfo = {
+ new StageInfo(
+ stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", taskLocalityPreferences)
}
private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = {
@@ -815,6 +856,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
private val _onSchedulerQueueEmpty = PrivateMethod[Unit]('onSchedulerQueueEmpty)
private val _onExecutorIdle = PrivateMethod[Unit]('onExecutorIdle)
private val _onExecutorBusy = PrivateMethod[Unit]('onExecutorBusy)
+ private val _localityAwareTasks = PrivateMethod[Int]('localityAwareTasks)
+ private val _hostToLocalTaskCount = PrivateMethod[Map[String, Int]]('hostToLocalTaskCount)
private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = {
manager invokePrivate _numExecutorsToAdd()
@@ -885,4 +928,12 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
private def onExecutorBusy(manager: ExecutorAllocationManager, id: String): Unit = {
manager invokePrivate _onExecutorBusy(id)
}
+
+ private def localityAwareTasks(manager: ExecutorAllocationManager): Int = {
+ manager invokePrivate _localityAwareTasks()
+ }
+
+ private def hostToLocalTaskCount(manager: ExecutorAllocationManager): Map[String, Int] = {
+ manager invokePrivate _hostToLocalTaskCount()
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
index 5a2670e4d1..139b8dc25f 100644
--- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala
@@ -182,7 +182,7 @@ class HeartbeatReceiverSuite
// Adjust the target number of executors on the cluster manager side
assert(fakeClusterManager.getTargetNumExecutors === 0)
- sc.requestTotalExecutors(2)
+ sc.requestTotalExecutors(2, 0, Map.empty)
assert(fakeClusterManager.getTargetNumExecutors === 2)
assert(fakeClusterManager.getExecutorIdsToKill.isEmpty)
@@ -241,7 +241,8 @@ private class FakeSchedulerBackend(
extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) {
protected override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
- clusterManagerEndpoint.askWithRetry[Boolean](RequestExecutors(requestedTotal))
+ clusterManagerEndpoint.askWithRetry[Boolean](
+ RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount))
}
protected override def doKillExecutors(executorIds: Seq[String]): Boolean = {
@@ -260,7 +261,7 @@ private class FakeClusterManager(override val rpcEnv: RpcEnv) extends RpcEndpoin
def getExecutorIdsToKill: Set[String] = executorIdsToKill.toSet
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case RequestExecutors(requestedTotal) =>
+ case RequestExecutors(requestedTotal, _, _) =>
targetNumExecutors = requestedTotal
context.reply(true)
case KillExecutors(executorIds) =>