diff options
author | Shixiong Zhu <shixiong@databricks.com> | 2016-02-01 11:02:17 -0800 |
---|---|---|
committer | Andrew Or <andrew@databricks.com> | 2016-02-01 11:02:17 -0800 |
commit | 6075573a93176ee8c071888e4525043d9e73b061 (patch) | |
tree | 45cdc80c2f00b52ac5b5f4aaabb04e3e822557fe /streaming/src/test/scala/org | |
parent | c1da4d421ab78772ffa52ad46e5bdfb4e5268f47 (diff) | |
download | spark-6075573a93176ee8c071888e4525043d9e73b061.tar.gz spark-6075573a93176ee8c071888e4525043d9e73b061.tar.bz2 spark-6075573a93176ee8c071888e4525043d9e73b061.zip |
[SPARK-6847][CORE][STREAMING] Fix stack overflow issue when updateStateByKey is followed by a checkpointed dstream
Add a local property to indicate if checkpointing all RDDs that are marked with the checkpoint flag, and enable it in Streaming
Author: Shixiong Zhu <shixiong@databricks.com>
Closes #10934 from zsxwing/recursive-checkpoint.
Diffstat (limited to 'streaming/src/test/scala/org')
-rw-r--r-- | streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 4a6b91fbc7..786703eb9a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -821,6 +821,75 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester checkpointWriter.stop() } + test("SPARK-6847: stack overflow when updateStateByKey is followed by a checkpointed dstream") { + // In this test, there are two updateStateByKey operators. The RDD DAG is as follows: + // + // batch 1 batch 2 batch 3 ... + // + // 1) input rdd input rdd input rdd + // | | | + // v v v + // 2) cogroup rdd ---> cogroup rdd ---> cogroup rdd ... + // | / | / | + // v / v / v + // 3) map rdd --- map rdd --- map rdd ... + // | | | + // v v v + // 4) cogroup rdd ---> cogroup rdd ---> cogroup rdd ... + // | / | / | + // v / v / v + // 5) map rdd --- map rdd --- map rdd ... + // + // Every batch depends on its previous batch, so "updateStateByKey" needs to do checkpoint to + // break the RDD chain. However, before SPARK-6847, when the state RDD (layer 5) of the second + // "updateStateByKey" does checkpoint, it won't checkpoint the state RDD (layer 3) of the first + // "updateStateByKey" (Note: "updateStateByKey" has already marked that its state RDD (layer 3) + // should be checkpointed). Hence, the connections between layer 2 and layer 3 won't be broken + // and the RDD chain will grow infinitely and cause StackOverflow. + // + // Therefore SPARK-6847 introduces "spark.checkpoint.checkpointAllMarked" to force checkpointing + // all marked RDDs in the DAG to resolve this issue. (For the previous example, it will break + // connections between layer 2 and layer 3) + ssc = new StreamingContext(master, framework, batchDuration) + val batchCounter = new BatchCounter(ssc) + ssc.checkpoint(checkpointDir) + val inputDStream = new CheckpointInputDStream(ssc) + val updateFunc = (values: Seq[Int], state: Option[Int]) => { + Some(values.sum + state.getOrElse(0)) + } + @volatile var shouldCheckpointAllMarkedRDDs = false + @volatile var rddsCheckpointed = false + inputDStream.map(i => (i, i)) + .updateStateByKey(updateFunc).checkpoint(batchDuration) + .updateStateByKey(updateFunc).checkpoint(batchDuration) + .foreachRDD { rdd => + /** + * Find all RDDs that are marked for checkpointing in the specified RDD and its ancestors. + */ + def findAllMarkedRDDs(rdd: RDD[_]): List[RDD[_]] = { + val markedRDDs = rdd.dependencies.flatMap(dep => findAllMarkedRDDs(dep.rdd)).toList + if (rdd.checkpointData.isDefined) { + rdd :: markedRDDs + } else { + markedRDDs + } + } + + shouldCheckpointAllMarkedRDDs = + Option(rdd.sparkContext.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)). + map(_.toBoolean).getOrElse(false) + + val stateRDDs = findAllMarkedRDDs(rdd) + rdd.count() + // Check the two state RDDs are both checkpointed + rddsCheckpointed = stateRDDs.size == 2 && stateRDDs.forall(_.isCheckpointed) + } + ssc.start() + batchCounter.waitUntilBatchesCompleted(1, 10000) + assert(shouldCheckpointAllMarkedRDDs === true) + assert(rddsCheckpointed === true) + } + /** * Advances the manual clock on the streaming scheduler by given number of batches. * It also waits for the expected amount of time for each batch. |