author    Kay Ousterhout <kayousterhout@gmail.com>  2016-06-14 17:26:33 -0700
committer Kay Ousterhout <kayousterhout@gmail.com>  2016-06-14 17:27:01 -0700
commit    5d50d4f0f9db3e6cc7c51e35cdb2d12daa4fd108 (patch)
tree      346f79d8713d1922a979e5d9decf857b94ffcf9f /core/src
parent    dae4d5db21368faaa46fa8d1a256c27428694c2c (diff)
[SPARK-15927] Eliminate redundant DAGScheduler code.
To eliminate redundant code for traversing the RDD dependency graph, this PR creates a new function, getShuffleDependencies, that returns the shuffle dependencies that are immediate parents of a given RDD. This new function is used by getParentStages and getAncestorShuffleDependencies.

Author: Kay Ousterhout <kayousterhout@gmail.com>

Closes #13646 from kayousterhout/SPARK-15927.
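For orientation before the diff itself: the heart of the patch is a single iterative traversal that stops at the first shuffle boundary on every path from the starting RDD. Below is a minimal, self-contained Scala sketch of that pattern. The Rdd, Dep, NarrowDep, and ShuffleDep names are toy stand-ins invented for illustration, not Spark's real org.apache.spark classes, and the code is a sketch of the idea rather than the actual DAGScheduler implementation.

import scala.collection.mutable.{HashSet, Stack}

// Toy stand-ins for Spark's RDD and Dependency types (hypothetical names,
// for illustration only).
sealed trait Dep { def rdd: Rdd }
case class NarrowDep(rdd: Rdd) extends Dep
case class ShuffleDep(rdd: Rdd) extends Dep
case class Rdd(name: String, dependencies: Seq[Dep])

// The shared helper, in miniature: an iterative DFS with an explicit stack
// (avoiding StackOverflowError on deep lineages) that records each shuffle
// dependency it meets and does not descend past it, so only the immediate
// shuffle parents of the starting RDD are returned.
def getShuffleDependencies(rdd: Rdd): HashSet[ShuffleDep] = {
  val parents = new HashSet[ShuffleDep]
  val visited = new HashSet[Rdd]
  val waitingForVisit = new Stack[Rdd]
  waitingForVisit.push(rdd)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      toVisit.dependencies.foreach {
        case shuffleDep: ShuffleDep => parents += shuffleDep
        case dep => waitingForVisit.push(dep.rdd)
      }
    }
  }
  parents
}

// The A <-- B <-- C example from the new doc comment, checked with the toys:
val a = Rdd("A", Nil)
val b = Rdd("B", Seq(ShuffleDep(a)))
val c = Rdd("C", Seq(ShuffleDep(b)))
assert(getShuffleDependencies(c).map(_.rdd.name) == Set("B")) // B only, not A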
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala      | 82
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 31
2 files changed, 74 insertions(+), 39 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index a2eadbcbd6..4e1250a14d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -378,59 +378,63 @@ class DAGScheduler(
* the provided firstJobId.
*/
private def getParentStages(rdd: RDD[_], firstJobId: Int): List[Stage] = {
- val parents = new HashSet[Stage]
- val visited = new HashSet[RDD[_]]
- // We are manually maintaining a stack here to prevent StackOverflowError
- // caused by recursively visiting
- val waitingForVisit = new Stack[RDD[_]]
- def visit(r: RDD[_]) {
- if (!visited(r)) {
- visited += r
- // Kind of ugly: need to register RDDs with the cache here since
- // we can't do it in its constructor because # of partitions is unknown
- for (dep <- r.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_, _, _] =>
- parents += getShuffleMapStage(shufDep, firstJobId)
- case _ =>
- waitingForVisit.push(dep.rdd)
- }
- }
- }
- }
- waitingForVisit.push(rdd)
- while (waitingForVisit.nonEmpty) {
- visit(waitingForVisit.pop())
- }
- parents.toList
+ getShuffleDependencies(rdd).map { shuffleDep =>
+ getShuffleMapStage(shuffleDep, firstJobId)
+ }.toList
}
/** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */
private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
- val parents = new Stack[ShuffleDependency[_, _, _]]
+ val ancestors = new Stack[ShuffleDependency[_, _, _]]
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
val waitingForVisit = new Stack[RDD[_]]
- def visit(r: RDD[_]) {
- if (!visited(r)) {
- visited += r
- for (dep <- r.dependencies) {
- dep match {
- case shufDep: ShuffleDependency[_, _, _] =>
- if (!shuffleToMapStage.contains(shufDep.shuffleId)) {
- parents.push(shufDep)
- }
- case _ =>
- }
- waitingForVisit.push(dep.rdd)
+ waitingForVisit.push(rdd)
+ while (waitingForVisit.nonEmpty) {
+ val toVisit = waitingForVisit.pop()
+ if (!visited(toVisit)) {
+ visited += toVisit
+ getShuffleDependencies(toVisit).foreach { shuffleDep =>
+ if (!shuffleToMapStage.contains(shuffleDep.shuffleId)) {
+ ancestors.push(shuffleDep)
+ waitingForVisit.push(shuffleDep.rdd)
+ } // Otherwise, the dependency and its ancestors have already been registered.
}
}
}
+ ancestors
+ }
+ /**
+ * Returns shuffle dependencies that are immediate parents of the given RDD.
+ *
+ * This function will not return more distant ancestors. For example, if C has a shuffle
+ * dependency on B which has a shuffle dependency on A:
+ *
+ * A <-- B <-- C
+ *
+ * calling this function with rdd C will only return the B <-- C dependency.
+ *
+ * This function is scheduler-visible for the purpose of unit testing.
+ */
+ private[scheduler] def getShuffleDependencies(
+ rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
+ val parents = new HashSet[ShuffleDependency[_, _, _]]
+ val visited = new HashSet[RDD[_]]
+ val waitingForVisit = new Stack[RDD[_]]
waitingForVisit.push(rdd)
while (waitingForVisit.nonEmpty) {
- visit(waitingForVisit.pop())
+ val toVisit = waitingForVisit.pop()
+ if (!visited(toVisit)) {
+ visited += toVisit
+ toVisit.dependencies.foreach {
+ case shuffleDep: ShuffleDependency[_, _, _] =>
+ parents += shuffleDep
+ case dependency =>
+ waitingForVisit.push(dependency.rdd)
+ }
+ }
}
parents
}
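Reusing the toy types and the getShuffleDependencies sketch from above, the rewritten getAncestorShuffleDependencies loop can be modeled as follows. Here `registered` is a hypothetical stand-in for shuffleToMapStage's key set; the rest mirrors the structure of the patched method as a sketch, not the real code.

// Walk outward one shuffle boundary at a time, collecting boundaries that
// are not registered yet. A registered boundary is not crossed, because its
// ancestors must already have been registered along with it.
def getAncestorShuffleDependencies(
    rdd: Rdd, registered: Set[ShuffleDep]): Stack[ShuffleDep] = {
  val ancestors = new Stack[ShuffleDep]
  val visited = new HashSet[Rdd]
  val waitingForVisit = new Stack[Rdd]
  waitingForVisit.push(rdd)
  while (waitingForVisit.nonEmpty) {
    val toVisit = waitingForVisit.pop()
    if (!visited(toVisit)) {
      visited += toVisit
      getShuffleDependencies(toVisit).foreach { shuffleDep =>
        if (!registered(shuffleDep)) {
          ancestors.push(shuffleDep)
          waitingForVisit.push(shuffleDep.rdd)
        }
      }
    }
  }
  ancestors
}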
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 3c30ec8ee8..ab8e95314f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -2024,6 +2024,37 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
}
/**
+ * Checks the DAGScheduler's internal logic for traversing an RDD DAG by making sure that
+ * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular
+ * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s
+ * denotes a shuffle dependency):
+ *
+ * A <------------s---------,
+ *                           \
+ * B <--s-- C <--s-- D <--n---`-- E
+ *
+ * Here, the direct shuffle dependency of C is just the shuffle dependency on B. The direct
+ * shuffle dependencies of E are the shuffle dependency on A and the shuffle dependency on C.
+ */
+ test("getShuffleDependencies correctly returns only direct shuffle parents") {
+ val rddA = new MyRDD(sc, 2, Nil)
+ val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(1))
+ val rddB = new MyRDD(sc, 2, Nil)
+ val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(1))
+ val rddC = new MyRDD(sc, 1, List(shuffleDepB))
+ val shuffleDepC = new ShuffleDependency(rddC, new HashPartitioner(1))
+ val rddD = new MyRDD(sc, 1, List(shuffleDepC))
+ val narrowDepD = new OneToOneDependency(rddD)
+ val rddE = new MyRDD(sc, 1, List(shuffleDepA, narrowDepD), tracker = mapOutputTracker)
+
+ assert(scheduler.getShuffleDependencies(rddA) === Set())
+ assert(scheduler.getShuffleDependencies(rddB) === Set())
+ assert(scheduler.getShuffleDependencies(rddC) === Set(shuffleDepB))
+ assert(scheduler.getShuffleDependencies(rddD) === Set(shuffleDepC))
+ assert(scheduler.getShuffleDependencies(rddE) === Set(shuffleDepA, shuffleDepC))
+ }
+
+ /**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
*/
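As a cross-check, the suite's five-node graph can be rebuilt with the toy types from the first sketch and gives the same answers (illustrative only; the real test above exercises the actual scheduler through MyRDD and ShuffleDependency):

// Rebuild: B <--s-- C <--s-- D <--n-- E, and E <--s-- A.
val rddA = Rdd("A", Nil)
val shuffleDepA = ShuffleDep(rddA)
val rddB = Rdd("B", Nil)
val shuffleDepB = ShuffleDep(rddB)
val rddC = Rdd("C", Seq(shuffleDepB))
val shuffleDepC = ShuffleDep(rddC)
val rddD = Rdd("D", Seq(shuffleDepC))
val rddE = Rdd("E", Seq(shuffleDepA, NarrowDep(rddD)))

assert(getShuffleDependencies(rddA) == Set())  // no shuffle parents at all
assert(getShuffleDependencies(rddC) == Set(shuffleDepB))
assert(getShuffleDependencies(rddD) == Set(shuffleDepC))
assert(getShuffleDependencies(rddE) == Set(shuffleDepA, shuffleDepC))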