aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2014-08-01 12:12:30 -0700
committerMatei Zaharia <matei@databricks.com>2014-08-01 12:12:30 -0700
commitbaf9ce1a4ecb7acf5accf7a7029f29604b4360c2 (patch)
treeefd7a6c96a7eb2dcea37a330e519069d3d15dad9 /core
parenteb5bdcaf6c7834558cb76b7132f68b8d94230356 (diff)
downloadspark-baf9ce1a4ecb7acf5accf7a7029f29604b4360c2.tar.gz
spark-baf9ce1a4ecb7acf5accf7a7029f29604b4360c2.tar.bz2
spark-baf9ce1a4ecb7acf5accf7a7029f29604b4360c2.zip
[SPARK-2490] Change recursive visiting on RDD dependencies to iterative approach
When performing some transformations on RDDs after many iterations, the dependencies of RDDs could be very long. It can easily cause StackOverflowError when recursively visiting these dependencies in Spark core. For example: var rdd = sc.makeRDD(Array(1)) for (i <- 1 to 1000) { rdd = rdd.coalesce(1).cache() rdd.collect() } This PR changes recursive visiting on RDD dependencies to an iterative approach to avoid StackOverflowError. In addition to the recursive visiting, the Java serializer has a known [bug](http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4152790) that also causes StackOverflowError when serializing/deserializing a large graph of objects, so applying this PR only solves part of the problem. Using KryoSerializer to replace the Java serializer might be helpful. However, since KryoSerializer is not supported for `spark.closure.serializer` now, I cannot test whether KryoSerializer can solve the Java serializer's problem completely. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #1418 from viirya/remove_recursive_visit and squashes the following commits: 6b2c615 [Liang-Chi Hsieh] change function name; comply with code style. 5f072a7 [Liang-Chi Hsieh] add comments to explain Stack usage. 8742dbb [Liang-Chi Hsieh] comply with code style. 900538b [Liang-Chi Hsieh] change recursive visiting on rdd's dependencies to iterative approach to avoid stackoverflowerror.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala83
1 file changed, 75 insertions, 8 deletions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 5110785de3..d87c304898 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -21,7 +21,7 @@ import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.language.postfixOps
@@ -211,11 +211,15 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
+ // We are going to register ancestor shuffle dependencies
+ registerShuffleDependencies(shuffleDep, jobId)
+ // Then register current shuffleDep
val stage =
newOrUsedStage(
shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId,
shuffleDep.rdd.creationSite)
shuffleToMapStage(shuffleDep.shuffleId) = stage
+
stage
}
}
@@ -280,6 +284,9 @@ class DAGScheduler(
private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = {
val parents = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
+ // We are manually maintaining a stack here to prevent StackOverflowError
+ // caused by recursively visiting
+ val waitingForVisit = new Stack[RDD[_]]
def visit(r: RDD[_]) {
if (!visited(r)) {
visited += r
@@ -290,18 +297,69 @@ class DAGScheduler(
case shufDep: ShuffleDependency[_, _, _] =>
parents += getShuffleMapStage(shufDep, jobId)
case _ =>
- visit(dep.rdd)
+ waitingForVisit.push(dep.rdd)
}
}
}
}
- visit(rdd)
+ waitingForVisit.push(rdd)
+ while (!waitingForVisit.isEmpty) {
+ visit(waitingForVisit.pop())
+ }
parents.toList
}
+ // Find ancestor missing shuffle dependencies and register into shuffleToMapStage
+ private def registerShuffleDependencies(shuffleDep: ShuffleDependency[_, _, _], jobId: Int) = {
+ val parentsWithNoMapStage = getAncestorShuffleDependencies(shuffleDep.rdd)
+ while (!parentsWithNoMapStage.isEmpty) {
+ val currentShufDep = parentsWithNoMapStage.pop()
+ val stage =
+ newOrUsedStage(
+ currentShufDep.rdd, currentShufDep.rdd.partitions.size, currentShufDep, jobId,
+ currentShufDep.rdd.creationSite)
+ shuffleToMapStage(currentShufDep.shuffleId) = stage
+ }
+ }
+
+ // Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet
+ private def getAncestorShuffleDependencies(rdd: RDD[_]): Stack[ShuffleDependency[_, _, _]] = {
+ val parents = new Stack[ShuffleDependency[_, _, _]]
+ val visited = new HashSet[RDD[_]]
+ // We are manually maintaining a stack here to prevent StackOverflowError
+ // caused by recursively visiting
+ val waitingForVisit = new Stack[RDD[_]]
+ def visit(r: RDD[_]) {
+ if (!visited(r)) {
+ visited += r
+ for (dep <- r.dependencies) {
+ dep match {
+ case shufDep: ShuffleDependency[_, _, _] =>
+ if (!shuffleToMapStage.contains(shufDep.shuffleId)) {
+ parents.push(shufDep)
+ }
+
+ waitingForVisit.push(shufDep.rdd)
+ case _ =>
+ waitingForVisit.push(dep.rdd)
+ }
+ }
+ }
+ }
+
+ waitingForVisit.push(rdd)
+ while (!waitingForVisit.isEmpty) {
+ visit(waitingForVisit.pop())
+ }
+ parents
+ }
+
private def getMissingParentStages(stage: Stage): List[Stage] = {
val missing = new HashSet[Stage]
val visited = new HashSet[RDD[_]]
+ // We are manually maintaining a stack here to prevent StackOverflowError
+ // caused by recursively visiting
+ val waitingForVisit = new Stack[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
@@ -314,13 +372,16 @@ class DAGScheduler(
missing += mapStage
}
case narrowDep: NarrowDependency[_] =>
- visit(narrowDep.rdd)
+ waitingForVisit.push(narrowDep.rdd)
}
}
}
}
}
- visit(stage.rdd)
+ waitingForVisit.push(stage.rdd)
+ while (!waitingForVisit.isEmpty) {
+ visit(waitingForVisit.pop())
+ }
missing.toList
}
@@ -1119,6 +1180,9 @@ class DAGScheduler(
}
val visitedRdds = new HashSet[RDD[_]]
val visitedStages = new HashSet[Stage]
+ // We are manually maintaining a stack here to prevent StackOverflowError
+ // caused by recursively visiting
+ val waitingForVisit = new Stack[RDD[_]]
def visit(rdd: RDD[_]) {
if (!visitedRdds(rdd)) {
visitedRdds += rdd
@@ -1128,15 +1192,18 @@ class DAGScheduler(
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
visitedStages += mapStage
- visit(mapStage.rdd)
+ waitingForVisit.push(mapStage.rdd)
} // Otherwise there's no need to follow the dependency back
case narrowDep: NarrowDependency[_] =>
- visit(narrowDep.rdd)
+ waitingForVisit.push(narrowDep.rdd)
}
}
}
}
- visit(stage.rdd)
+ waitingForVisit.push(stage.rdd)
+ while (!waitingForVisit.isEmpty) {
+ visit(waitingForVisit.pop())
+ }
visitedRdds.contains(target.rdd)
}