-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala      | 11 ++++++++---
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala | 40 ++++++++++++++++++++++++++++++++++++++--
2 files changed, 46 insertions(+), 5 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 20a4bd12f9..d9d53faf84 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -690,8 +690,7 @@ private[spark] class TaskSetManager(
       handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure)
     }
     // recalculate valid locality levels and waits when executor is lost
-    myLocalityLevels = computeValidLocalityLevels()
-    localityWaits = myLocalityLevels.map(getLocalityWait)
+    recomputeLocality()
   }
 
   /**
@@ -775,9 +774,15 @@ private[spark] class TaskSetManager(
     levels.toArray
   }
 
-  def executorAdded() {
+  def recomputeLocality() {
+    val previousLocalityLevel = myLocalityLevels(currentLocalityIndex)
     myLocalityLevels = computeValidLocalityLevels()
     localityWaits = myLocalityLevels.map(getLocalityWait)
+    currentLocalityIndex = getLocalityIndex(previousLocalityLevel)
+  }
+
+  def executorAdded() {
+    recomputeLocality()
   }
 }
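
For context on the change above: recomputeLocality() first snapshots the locality level the TaskSetManager is currently waiting at, then rebuilds the array of valid levels (which can shrink when executors are lost), and finally re-derives currentLocalityIndex against the new array. That last step is what keeps a stale index from pointing past the end of a shorter array (the ArrayIndexOutOfBoundsException reported in SPARK-2931). The following is a minimal, self-contained sketch of that pattern, assuming a simplified set of locality levels; LocalityTracker and its members are invented for illustration and are not part of Spark's API.

object LocalitySketch {
  // Simplified locality levels, ordered from most to least strict.
  object TaskLocality extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
  }
  import TaskLocality._

  class LocalityTracker(private var validLevels: Array[TaskLocality.Value]) {
    private var currentIndex = 0

    // Index of `level` in validLevels, or of the closest less-strict valid level
    // (ANY is always present, so the walk always terminates in bounds).
    private def indexOf(level: TaskLocality.Value): Int = {
      var i = 0
      while (i < validLevels.length - 1 && validLevels(i) < level) { i += 1 }
      i
    }

    // Mirrors recomputeLocality(): remember the level in effect, swap in the new
    // levels, then re-anchor the index so it is valid for the new (possibly shorter) array.
    def recompute(newLevels: Array[TaskLocality.Value]): Unit = {
      val previous = validLevels(currentIndex)
      validLevels = newLevels
      currentIndex = indexOf(previous)
    }

    def current: TaskLocality.Value = validLevels(currentIndex)
  }

  def main(args: Array[String]): Unit = {
    // Start with only ANY valid, as in the regression test below.
    val tracker = new LocalityTracker(Array(ANY))
    // Executors appear: more levels become valid, and the index is re-anchored at ANY.
    tracker.recompute(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))
    // Executors are lost again: the array shrinks back to Array(ANY). Without the
    // re-anchoring step, currentIndex (2) would now be out of bounds.
    tracker.recompute(Array(ANY))
    println(tracker.current)  // prints ANY
  }
}
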
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ffd23380a8..93e8ddacf8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -154,6 +154,11 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
   val LOCALITY_WAIT = conf.getLong("spark.locality.wait", 3000)
   val MAX_TASK_FAILURES = 4
 
+  override def beforeEach() {
+    super.beforeEach()
+    FakeRackUtil.cleanUp()
+  }
+
   test("TaskSet with no preferences") {
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
@@ -471,7 +476,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("new executors get added and lost") {
     // Assign host2 to rack2
-    FakeRackUtil.cleanUp()
     FakeRackUtil.assignHostToRack("host2", "rack2")
     sc = new SparkContext("local", "test")
     val sched = new FakeTaskScheduler(sc)
@@ -504,7 +508,6 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
   }
 
   test("test RACK_LOCAL tasks") {
-    FakeRackUtil.cleanUp()
     // Assign host1 to rack1
     FakeRackUtil.assignHostToRack("host1", "rack1")
     // Assign host2 to rack1
@@ -607,6 +610,39 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     assert(manager.resourceOffer("execA", "host3", NO_PREF).get.index === 2)
   }
 
+  test("Ensure TaskSetManager is usable after addition of levels") {
+    // Regression test for SPARK-2931
+    sc = new SparkContext("local", "test")
+    val sched = new FakeTaskScheduler(sc)
+    val taskSet = FakeTask.createTaskSet(2,
+      Seq(TaskLocation("host1", "execA")),
+      Seq(TaskLocation("host2", "execB.1")))
+    val clock = new FakeClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+    // Only ANY is valid
+    assert(manager.myLocalityLevels.sameElements(Array(ANY)))
+    // Add a new executor
+    sched.addExecutor("execA", "host1")
+    sched.addExecutor("execB.2", "host2")
+    manager.executorAdded()
+    assert(manager.pendingTasksWithNoPrefs.size === 0)
+    // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY
+    assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+    assert(manager.resourceOffer("execA", "host1", ANY) !== None)
+    clock.advance(LOCALITY_WAIT * 4)
+    assert(manager.resourceOffer("execB.2", "host2", ANY) !== None)
+    sched.removeExecutor("execA")
+    sched.removeExecutor("execB.2")
+    manager.executorLost("execA", "host1")
+    manager.executorLost("execB.2", "host2")
+    clock.advance(LOCALITY_WAIT * 4)
+    sched.addExecutor("execC", "host3")
+    manager.executorAdded()
+    // Prior to the fix, this line resulted in an ArrayIndexOutOfBoundsException:
+    assert(manager.resourceOffer("execC", "host3", ANY) !== None)
+  }
+
+
   def createTaskResult(id: Int): DirectTaskResult[Int] = {
     val valueSer = SparkEnv.get.serializer.newInstance()
     new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics)