diff options
author | Kay Ousterhout <kayousterhout@gmail.com> | 2016-06-14 17:26:33 -0700 |
---|---|---|
committer | Kay Ousterhout <kayousterhout@gmail.com> | 2016-06-14 17:27:01 -0700 |
commit | 5d50d4f0f9db3e6cc7c51e35cdb2d12daa4fd108 (patch) | |
tree | 346f79d8713d1922a979e5d9decf857b94ffcf9f | |
parent | dae4d5db21368faaa46fa8d1a256c27428694c2c (diff) | |
download | spark-5d50d4f0f9db3e6cc7c51e35cdb2d12daa4fd108.tar.gz spark-5d50d4f0f9db3e6cc7c51e35cdb2d12daa4fd108.tar.bz2 spark-5d50d4f0f9db3e6cc7c51e35cdb2d12daa4fd108.zip |
[SPARK-15927] Eliminate redundant DAGScheduler code.
To try to eliminate redundant code to traverse the RDD dependency graph,
this PR creates a new function getShuffleDependencies that returns
shuffle dependencies that are immediate parents of a given RDD. This
new function is used by getParentStages and
getAncestorShuffleDependencies.
Author: Kay Ousterhout <kayousterhout@gmail.com>
Closes #13646 from kayousterhout/SPARK-15927.
-rw-r--r-- | core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 82 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 31 |
2 files changed, 74 insertions, 39 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a2eadbcbd6..4e1250a14d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -378,59 +378,63 @@ class DAGScheduler( * the provided firstJobId. */ private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = { - val parents = new HashSet[Stage] - val visited = new HashSet[RDD[_]] - // We are manually maintaining a stack here to prevent StackOverflowError - // caused by recursively visiting - val waitingForVisit = new Stack[RDD[_]] - def visit(r: RDD[_]) { - if (!visited(r)) { - visited += r - // Kind of ugly: need to register RDDs with the cache here since - // we can't do it in its constructor because # of partitions is unknown - for (dep <- r.dependencies) { - dep match { - case shufDep: ShuffleDependency[_, _, _] => - parents += getShuffleMapStage(shufDep, firstJobId) - case _ => - waitingForVisit.push(dep.rdd) - } - } - } - } - waitingForVisit.push(rdd) - while (waitingForVisit.nonEmpty) { - visit(waitingForVisit.pop()) - } - parents.toList + getShuffleDependencies(rdd).map { shuffleDep => + getShuffleMapStage(shuffleDep, firstJobId) + }.toList } /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = { - val parents = new Stack[ShuffleDependency[_, _, _]] + val ancestors = new Stack[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError // caused by recursively visiting val waitingForVisit = new Stack[RDD[_]] - def visit(r: RDD[_]) { - if (!visited(r)) { - visited += r - for (dep <- r.dependencies) { - dep match { - case shufDep: ShuffleDependency[_, _, _] => - if (!shuffleToMapStage.contains(shufDep.shuffleId)) { - parents.push(shufDep) - } - case _ => - } - waitingForVisit.push(dep.rdd) + waitingForVisit.push(rdd) + while (waitingForVisit.nonEmpty) { + val toVisit = waitingForVisit.pop() + if (!visited(toVisit)) { + visited += toVisit + getShuffleDependencies(toVisit).foreach { shuffleDep => + if (!shuffleToMapStage.contains(shuffleDep.shuffleId)) { + ancestors.push(shuffleDep) + waitingForVisit.push(shuffleDep.rdd) + } // Otherwise, the dependency and its ancestors have already been registered. } } } + ancestors + } + /** + * Returns shuffle dependencies that are immediate parents of the given RDD. + * + * This function will not return more distant ancestors. For example, if C has a shuffle + * dependency on B which has a shuffle dependency on A: + * + * A <-- B <-- C + * + * calling this function with rdd C will only return the B <-- C dependency. + * + * This function is scheduler-visible for the purpose of unit testing. + */ + private[scheduler] def getShuffleDependencies( + rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = { + val parents = new HashSet[ShuffleDependency[_, _, _]] + val visited = new HashSet[RDD[_]] + val waitingForVisit = new Stack[RDD[_]] waitingForVisit.push(rdd) while (waitingForVisit.nonEmpty) { - visit(waitingForVisit.pop()) + val toVisit = waitingForVisit.pop() + if (!visited(toVisit)) { + visited += toVisit + toVisit.dependencies.foreach { + case shuffleDep: ShuffleDependency[_, _, _] => + parents += shuffleDep + case dependency => + waitingForVisit.push(dependency.rdd) + } + } } parents } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3c30ec8ee8..ab8e95314f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2024,6 +2024,37 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } /** + * Checks the DAGScheduler's internal logic for traversing a RDD DAG by making sure that + * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular + * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s + * denotes a shuffle dependency): + * + * A <------------s---------, + * \ + * B <--s-- C <--s-- D <--n---`-- E + * + * Here, the direct shuffle dependency of C is just the shuffle dependency on B. The direct + * shuffle dependencies of E are the shuffle dependency on A and the shuffle dependency on C. + */ + test("getShuffleDependencies correctly returns only direct shuffle parents") { + val rddA = new MyRDD(sc, 2, Nil) + val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1)) + val rddB = new MyRDD(sc, 2, Nil) + val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1)) + val rddC = new MyRDD(sc, 1, List(shuffleDepB)) + val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(1)) + val rddD = new MyRDD(sc, 1, List(shuffleDepC)) + val narrowDepD = new OneToOneDependency(rddD) + val rddE = new MyRDD(sc, 1, List(shuffleDepA, narrowDepD), tracker = mapOutputTracker) + + assert(scheduler.getShuffleDependencies(rddA) === Set()) + assert(scheduler.getShuffleDependencies(rddB) === Set()) + assert(scheduler.getShuffleDependencies(rddC) === Set(shuffleDepB)) + assert(scheduler.getShuffleDependencies(rddD) === Set(shuffleDepC)) + assert(scheduler.getShuffleDependencies(rddE) === Set(shuffleDepA, shuffleDepC)) + } + + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. */ |