aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorShixiong Zhu <shixiong@databricks.com>2016-02-01 11:02:17 -0800
committerAndrew Or <andrew@databricks.com>2016-02-01 11:02:17 -0800
commit6075573a93176ee8c071888e4525043d9e73b061 (patch)
tree45cdc80c2f00b52ac5b5f4aaabb04e3e822557fe
parentc1da4d421ab78772ffa52ad46e5bdfb4e5268f47 (diff)
downloadspark-6075573a93176ee8c071888e4525043d9e73b061.tar.gz
spark-6075573a93176ee8c071888e4525043d9e73b061.tar.bz2
spark-6075573a93176ee8c071888e4525043d9e73b061.zip
[SPARK-6847][CORE][STREAMING] Fix stack overflow issue when updateStateByKey is followed by a checkpointed dstream
Add a local property to indicate if checkpointing all RDDs that are marked with the checkpoint flag, and enable it in Streaming Author: Shixiong Zhu <shixiong@databricks.com> Closes #10934 from zsxwing/recursive-checkpoint.
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala19
-rw-r--r--core/src/test/scala/org/apache/spark/CheckpointSuite.scala21
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala5
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala7
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala69
5 files changed, 119 insertions, 2 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index be47172581..e8157cf4eb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1542,6 +1542,15 @@ abstract class RDD[T: ClassTag](
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
+ // Whether to checkpoint all ancestor RDDs that are marked for checkpointing. By default,
+ // we stop as soon as we find the first such RDD, an optimization that allows us to write
+ // less data but is not safe for all workloads. E.g. in streaming we may checkpoint both
+ // an RDD and its parent in every batch, in which case the parent may never be checkpointed
+ // and its lineage never truncated, leading to OOMs in the long run (SPARK-6847).
+ private val checkpointAllMarkedAncestors =
+ Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS))
+ .map(_.toBoolean).getOrElse(false)
+
/** Returns the first parent RDD */
protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
dependencies.head.rdd.asInstanceOf[RDD[U]]
@@ -1585,6 +1594,13 @@ abstract class RDD[T: ClassTag](
if (!doCheckpointCalled) {
doCheckpointCalled = true
if (checkpointData.isDefined) {
+ if (checkpointAllMarkedAncestors) {
+ // TODO We can collect all the RDDs that needs to be checkpointed, and then checkpoint
+ // them in parallel.
+ // Checkpoint parents first because our lineage will be truncated after we
+ // checkpoint ourselves
+ dependencies.foreach(_.rdd.doCheckpoint())
+ }
checkpointData.get.checkpoint()
} else {
dependencies.foreach(_.rdd.doCheckpoint())
@@ -1704,6 +1720,9 @@ abstract class RDD[T: ClassTag](
*/
object RDD {
+ private[spark] val CHECKPOINT_ALL_MARKED_ANCESTORS =
+ "spark.checkpoint.checkpointAllMarkedAncestors"
+
// The following implicit functions were in SparkContext before 1.3 and users had to
// `import SparkContext._` to enable them. Now we move them here to make the compiler find
// them automatically. However, we still keep the old functions in SparkContext for backward
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 390764ba24..ce35856dce 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -512,6 +512,27 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
assert(rdd.isCheckpointedAndMaterialized === true)
assert(rdd.partitions.size === 0)
}
+
+ runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean =>
+ testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true)
+ testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false)
+ }
+
+ private def testCheckpointAllMarkedAncestors(
+ reliableCheckpoint: Boolean, checkpointAllMarkedAncestors: Boolean): Unit = {
+ sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, checkpointAllMarkedAncestors.toString)
+ try {
+ val rdd1 = sc.parallelize(1 to 10)
+ checkpoint(rdd1, reliableCheckpoint)
+ val rdd2 = rdd1.map(_ + 1)
+ checkpoint(rdd2, reliableCheckpoint)
+ rdd2.count()
+ assert(rdd1.isCheckpointed === checkpointAllMarkedAncestors)
+ assert(rdd2.isCheckpointed === true)
+ } finally {
+ sc.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, null)
+ }
+ }
}
/** RDD partition that has large serialized size. */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index a5a01e7763..a3ad5eaa40 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -20,6 +20,7 @@ package org.apache.spark.streaming.scheduler
import scala.util.{Failure, Success, Try}
import org.apache.spark.{Logging, SparkEnv}
+import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time}
import org.apache.spark.streaming.util.RecurringTimer
import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils}
@@ -243,6 +244,10 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
// Example: BlockRDDs are created in this thread, and it needs to access BlockManager
// Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed.
SparkEnv.set(ssc.env)
+
+ // Checkpoint all RDDs marked for checkpointing to ensure their lineages are
+ // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
+ ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
Try {
jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
graph.generateJobs(time) // generate jobs using allocated block
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 9535c8e5b7..3fed3d8835 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -23,10 +23,10 @@ import scala.collection.JavaConverters._
import scala.util.Failure
import org.apache.spark.Logging
-import org.apache.spark.rdd.PairRDDFunctions
+import org.apache.spark.rdd.{PairRDDFunctions, RDD}
import org.apache.spark.streaming._
import org.apache.spark.streaming.ui.UIUtils
-import org.apache.spark.util.{EventLoop, ThreadUtils, Utils}
+import org.apache.spark.util.{EventLoop, ThreadUtils}
private[scheduler] sealed trait JobSchedulerEvent
@@ -210,6 +210,9 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
s"""Streaming job from <a href="$batchUrl">$batchLinkText</a>""")
ssc.sc.setLocalProperty(BATCH_TIME_PROPERTY_KEY, job.time.milliseconds.toString)
ssc.sc.setLocalProperty(OUTPUT_OP_ID_PROPERTY_KEY, job.outputOpId.toString)
+ // Checkpoint all RDDs marked for checkpointing to ensure their lineages are
+ // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847).
+ ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true")
// We need to assign `eventLoop` to a temp variable. Otherwise, because
// `JobScheduler.stop(false)` may set `eventLoop` to null when this method is running, then
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 4a6b91fbc7..786703eb9a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -821,6 +821,75 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
checkpointWriter.stop()
}
+ test("SPARK-6847: stack overflow when updateStateByKey is followed by a checkpointed dstream") {
+ // In this test, there are two updateStateByKey operators. The RDD DAG is as follows:
+ //
+ // batch 1 batch 2 batch 3 ...
+ //
+ // 1) input rdd input rdd input rdd
+ // | | |
+ // v v v
+ // 2) cogroup rdd ---> cogroup rdd ---> cogroup rdd ...
+ // | / | / |
+ // v / v / v
+ // 3) map rdd --- map rdd --- map rdd ...
+ // | | |
+ // v v v
+ // 4) cogroup rdd ---> cogroup rdd ---> cogroup rdd ...
+ // | / | / |
+ // v / v / v
+ // 5) map rdd --- map rdd --- map rdd ...
+ //
+ // Every batch depends on its previous batch, so "updateStateByKey" needs to do checkpoint to
+ // break the RDD chain. However, before SPARK-6847, when the state RDD (layer 5) of the second
+ // "updateStateByKey" does checkpoint, it won't checkpoint the state RDD (layer 3) of the first
+ // "updateStateByKey" (Note: "updateStateByKey" has already marked that its state RDD (layer 3)
+ // should be checkpointed). Hence, the connections between layer 2 and layer 3 won't be broken
+ // and the RDD chain will grow infinitely and cause StackOverflow.
+ //
+ // Therefore SPARK-6847 introduces "spark.checkpoint.checkpointAllMarked" to force checkpointing
+ // all marked RDDs in the DAG to resolve this issue. (For the previous example, it will break
+ // connections between layer 2 and layer 3)
+ ssc = new StreamingContext(master, framework, batchDuration)
+ val batchCounter = new BatchCounter(ssc)
+ ssc.checkpoint(checkpointDir)
+ val inputDStream = new CheckpointInputDStream(ssc)
+ val updateFunc = (values: Seq[Int], state: Option[Int]) => {
+ Some(values.sum + state.getOrElse(0))
+ }
+ @volatile var shouldCheckpointAllMarkedRDDs = false
+ @volatile var rddsCheckpointed = false
+ inputDStream.map(i => (i, i))
+ .updateStateByKey(updateFunc).checkpoint(batchDuration)
+ .updateStateByKey(updateFunc).checkpoint(batchDuration)
+ .foreachRDD { rdd =>
+ /**
+ * Find all RDDs that are marked for checkpointing in the specified RDD and its ancestors.
+ */
+ def findAllMarkedRDDs(rdd: RDD[_]): List[RDD[_]] = {
+ val markedRDDs = rdd.dependencies.flatMap(dep => findAllMarkedRDDs(dep.rdd)).toList
+ if (rdd.checkpointData.isDefined) {
+ rdd :: markedRDDs
+ } else {
+ markedRDDs
+ }
+ }
+
+ shouldCheckpointAllMarkedRDDs =
+ Option(rdd.sparkContext.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).
+ map(_.toBoolean).getOrElse(false)
+
+ val stateRDDs = findAllMarkedRDDs(rdd)
+ rdd.count()
+ // Check the two state RDDs are both checkpointed
+ rddsCheckpointed = stateRDDs.size == 2 && stateRDDs.forall(_.isCheckpointed)
+ }
+ ssc.start()
+ batchCounter.waitUntilBatchesCompleted(1, 10000)
+ assert(shouldCheckpointAllMarkedRDDs === true)
+ assert(rddsCheckpointed === true)
+ }
+
/**
* Advances the manual clock on the streaming scheduler by given number of batches.
* It also waits for the expected amount of time for each batch.