-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala       15
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala  35
2 files changed, 42 insertions, 8 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a083be2448..a2299e907c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -193,9 +193,15 @@ class DAGScheduler(
   def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized {
     // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times
     if (!cacheLocs.contains(rdd.id)) {
-      val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
-      val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms =>
-        bms.map(bm => TaskLocation(bm.host, bm.executorId))
+      // Note: if the storage level is NONE, we don't need to get locations from block manager.
+      val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) {
+        Seq.fill(rdd.partitions.size)(Nil)
+      } else {
+        val blockIds =
+          rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
+        blockManagerMaster.getLocations(blockIds).map { bms =>
+          bms.map(bm => TaskLocation(bm.host, bm.executorId))
+        }
       }
       cacheLocs(rdd.id) = locs
     }
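
The fast path above can be shown in isolation. Below is a minimal, self-contained sketch of the same idea, with hypothetical stand-ins (SimpleRdd for RDD, plain host strings for TaskLocation, lookupBlockLocations for blockManagerMaster.getLocations); it illustrates the pattern rather than reproducing the patched method: an unpersisted RDD gets empty location lists immediately, and only persisted RDDs pay for the driver round trip.

import scala.collection.mutable

object CacheLocsSketch {
  // Hypothetical stand-in for RDD: an id, a partition count, and a flag
  // standing in for "storage level != NONE".
  final case class SimpleRdd(id: Int, numPartitions: Int, isPersisted: Boolean)

  private val cacheLocs = new mutable.HashMap[Int, Seq[Seq[String]]]

  // Hypothetical stand-in for blockManagerMaster.getLocations; in the real
  // scheduler this is a round trip to the block manager master.
  private def lookupBlockLocations(rdd: SimpleRdd): Seq[Seq[String]] =
    Seq.fill(rdd.numPartitions)(Seq("hostA"))

  def getCacheLocs(rdd: SimpleRdd): Seq[Seq[String]] = cacheLocs.synchronized {
    if (!cacheLocs.contains(rdd.id)) {
      val locs: Seq[Seq[String]] =
        if (!rdd.isPersisted) {
          // Storage level NONE: no partition can have a cached block, so skip
          // the lookup and record an empty location list per partition.
          Seq.fill(rdd.numPartitions)(Nil)
        } else {
          lookupBlockLocations(rdd)
        }
      cacheLocs(rdd.id) = locs
    }
    cacheLocs(rdd.id)
  }
}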
@@ -382,7 +388,8 @@ class DAGScheduler(
     def visit(rdd: RDD[_]) {
       if (!visited(rdd)) {
         visited += rdd
-        if (getCacheLocs(rdd).contains(Nil)) {
+        val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
+        if (rddHasUncachedPartitions) {
           for (dep <- rdd.dependencies) {
             dep match {
               case shufDep: ShuffleDependency[_, _, _] =>
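
The second hunk is behavior-preserving: it only gives a name to the condition getMissingParentStages uses to decide whether to keep walking an RDD's lineage. The sketch below, with a hypothetical Node type in place of RDD, shows the pruning pattern the new test exercises, assuming the same convention that Nil marks a partition with no cached copy; a fully cached node stops the walk, so its ancestors are never registered as missing parent stages.

object PruneTraversalSketch {
  // Hypothetical stand-in for an RDD in the lineage graph: an id, its parent
  // nodes, and per-partition cache locations (Nil = partition not cached).
  final case class Node(id: Int, deps: Seq[Node], cacheLocs: Seq[Seq[String]])

  // Returns the ids actually visited when walking the lineage from `root`.
  def visitedIds(root: Node): Set[Int] = {
    val visited = scala.collection.mutable.Set[Int]()
    def visit(node: Node): Unit = {
      if (!visited(node.id)) {
        visited += node.id
        // Same test as the renamed condition in the hunk above: recurse only
        // while at least one partition of this node is uncached.
        val nodeHasUncachedPartitions = node.cacheLocs.contains(Nil)
        if (nodeHasUncachedPartitions) {
          node.deps.foreach(visit)
        }
      }
    }
    visit(root)
    visited.toSet
  }

  def main(args: Array[String]): Unit = {
    // A <- B <- C <- D, with C fully cached: the walk stops at C.
    val a = Node(0, Nil, Seq(Nil))
    val b = Node(1, Seq(a), Seq(Nil))
    val c = Node(2, Seq(b), Seq(Seq("hostA", "hostB")))
    val d = Node(3, Seq(c), Seq(Nil))
    assert(visitedIds(d) == Set(3, 2)) // A (0) and B (1) are never visited
  }
}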
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 6a8ae29aae..46642236e4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -318,7 +318,7 @@ class DAGSchedulerSuite
   }
 
   test("cache location preferences w/ dependency") {
-    val baseRdd = new MyRDD(sc, 1, Nil)
+    val baseRdd = new MyRDD(sc, 1, Nil).cache()
     val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd)))
     cacheLocations(baseRdd.id -> 0) =
       Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
@@ -331,7 +331,7 @@ class DAGSchedulerSuite
   }
 
   test("regression test for getCacheLocs") {
-    val rdd = new MyRDD(sc, 3, Nil)
+    val rdd = new MyRDD(sc, 3, Nil).cache()
     cacheLocations(rdd.id -> 0) =
       Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
     cacheLocations(rdd.id -> 1) =
@@ -342,6 +342,33 @@ class DAGSchedulerSuite
     assert(locs === Seq(Seq("hostA", "hostB"), Seq("hostB", "hostC"), Seq("hostC", "hostD")))
   }
 
+  /**
+   * This test ensures that if a particular RDD is cached, RDDs earlier in the dependency chain
+   * are not computed. It constructs the following chain of dependencies:
+   * +---+ shuffle +---+    +---+    +---+
+   * | A |<--------| B |<---| C |<---| D |
+   * +---+         +---+    +---+    +---+
+   * Here, B is derived from A by performing a shuffle, C has a one-to-one dependency on B,
+   * and D similarly has a one-to-one dependency on C. If none of the RDDs were cached, this
+   * set of RDDs would result in a two stage job: one ShuffleMapStage, and a ResultStage that
+   * reads the shuffled data from RDD A. This test ensures that if C is cached, the scheduler
+   * doesn't perform a shuffle, and instead computes the result using a single ResultStage
+   * that reads C's cached data.
+   */
+  test("getMissingParentStages should consider all ancestor RDDs' cache statuses") {
+    val rddA = new MyRDD(sc, 1, Nil)
+    val rddB = new MyRDD(sc, 1, List(new ShuffleDependency(rddA, null)))
+    val rddC = new MyRDD(sc, 1, List(new OneToOneDependency(rddB))).cache()
+    val rddD = new MyRDD(sc, 1, List(new OneToOneDependency(rddC)))
+    cacheLocations(rddC.id -> 0) =
+      Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
+    submit(rddD, Array(0))
+    assert(scheduler.runningStages.size === 1)
+    // Make sure that the scheduler is running the final result stage.
+    // Because C is cached, the shuffle map stage to compute A does not need to be run.
+    assert(scheduler.runningStages.head.isInstanceOf[ResultStage])
+  }
+
   test("avoid exponential blowup when getting preferred locs list") {
     // Build up a complex dependency graph with repeated zip operations, without preferred locations
     var rdd: RDD[_] = new MyRDD(sc, 1, Nil)
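
The chain in the new test's diagram corresponds to an ordinary Spark job. Below is a hypothetical end-to-end example of the user-visible effect, using only the standard RDD API (names and numbers are illustrative, not part of this patch): once the post-shuffle RDD c is cached and materialized, a later action on d can be served by a single ResultStage that reads c's blocks instead of re-running the shuffle.

import org.apache.spark.{SparkConf, SparkContext}

object CachedLineageExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("cached-lineage").setMaster("local[2]"))
    val a = sc.parallelize(1 to 100)                    // A
    val b = a.map(x => (x % 10, x)).reduceByKey(_ + _)  // B: derived from A by a shuffle
    val c = b.mapValues(_ * 2).cache()                  // C: one-to-one on B, cached
    c.count()                                           // materialize C's cached blocks
    val d = c.mapValues(_ + 1)                          // D: one-to-one on C
    d.count() // with C cached, this action needs only a ResultStage over C
    sc.stop()
  }
}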
@@ -678,9 +705,9 @@ class DAGSchedulerSuite
   }
 
   test("cached post-shuffle") {
-    val shuffleOneRdd = new MyRDD(sc, 2, Nil)
+    val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache()
     val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
-    val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne))
+    val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache()
     val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
     val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo))
     submit(finalRdd, Array(0))
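
The .cache() calls added throughout the suite follow from the getCacheLocs change: locations seeded for an unpersisted RDD are now ignored, because the NONE fast path never consults the block manager. A small check against the hypothetical CacheLocsSketch defined earlier makes the point:

object WhyTheTestsNowCache {
  def main(args: Array[String]): Unit = {
    val uncached = CacheLocsSketch.SimpleRdd(id = 1, numPartitions = 2, isPersisted = false)
    // Whatever locations a test might have registered for this id, the fast
    // path returns empty lists, so fixtures must mark the RDD as persisted
    // for seeded cache locations to be observed.
    assert(CacheLocsSketch.getCacheLocs(uncached) == Seq(Nil, Nil))
  }
}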