author    jhu-chang <gt.hu.chang@gmail.com>  2015-12-17 17:53:15 -0800
committer Shixiong Zhu <shixiong@databricks.com>  2015-12-17 17:53:15 -0800
commit    f4346f612b6798517153a786f9172cf41618d34d (patch)
tree      1d79709b1d11b14b4df9955f99f7e1dc5eda826f /streaming
parent    658f66e6208a52367e3b43a6fee9c90f33fb6226 (diff)
[SPARK-11749][STREAMING] Fix duplicate creation of the RDD in file stream when recovering from checkpoint data
Add a transient flag `DStream.restoredFromCheckpointData` to guard the restore processing in DStream and avoid duplicate work: `DStream.restoreCheckpointData` checks the flag first and runs the restore only when it is `false`, setting it to `true` afterwards. Author: jhu-chang <gt.hu.chang@gmail.com> Closes #9765 from jhu-chang/SPARK-11749.
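For context, the duplication happens because restoreCheckpointData() walks the dependency graph recursively, so a DStream reachable from several downstream streams (e.g. one input feeding two outputs) is restored once per path; for a file stream that means recreating the same RDDs repeatedly. A minimal sketch of the guard pattern the patch introduces, written against a hypothetical Node class rather than Spark's API:

class Node(val name: String, val dependencies: Seq[Node]) {
  private var restoredFromCheckpointData = false

  def restoreCheckpointData(): Unit = {
    if (!restoredFromCheckpointData) {
      println(s"restoring $name")  // stands in for checkpointData.restore()
      dependencies.foreach(_.restoreCheckpointData())
      restoredFromCheckpointData = true
    }
  }
}

// Two outputs sharing one input: without the flag, "restoring input"
// would print twice; with it, exactly once.
val input = new Node("input", Nil)
Seq(new Node("out1", Seq(input)), new Node("out2", Seq(input)))
  .foreach(_.restoreCheckpointData())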
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala    15
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala    56
2 files changed, 62 insertions(+), 9 deletions(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
index 1a6edf9473..91a43e14a8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala
@@ -97,6 +97,8 @@ abstract class DStream[T: ClassTag] (
private[streaming] val mustCheckpoint = false
private[streaming] var checkpointDuration: Duration = null
private[streaming] val checkpointData = new DStreamCheckpointData(this)
+ @transient
+ private var restoredFromCheckpointData = false
// Reference to whole DStream graph
private[streaming] var graph: DStreamGraph = null
@@ -507,11 +509,14 @@ abstract class DStream[T: ClassTag] (
* override the updateCheckpointData() method would also need to override this method.
*/
private[streaming] def restoreCheckpointData() {
- // Create RDDs from the checkpoint data
- logInfo("Restoring checkpoint data")
- checkpointData.restore()
- dependencies.foreach(_.restoreCheckpointData())
- logInfo("Restored checkpoint data")
+ if (!restoredFromCheckpointData) {
+ // Create RDDs from the checkpoint data
+ logInfo("Restoring checkpoint data")
+ checkpointData.restore()
+ dependencies.foreach(_.restoreCheckpointData())
+ restoredFromCheckpointData = true
+ logInfo("Restored checkpoint data")
+ }
}
@throws(classOf[IOException])
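A note on the @transient marker above: the DStream graph is recovered from a checkpoint via Java deserialization, and a transient Boolean deserializes to its default value (false), so a freshly recovered graph still performs its single restore pass; the flag only suppresses repeat calls within the same in-memory graph. A minimal, self-contained sketch of that behavior (hypothetical Guarded class, not Spark code):

import java.io._

class Guarded extends Serializable {
  @transient private var restored = false
  def restoreOnce(): Boolean =
    if (restored) false               // duplicate call: skipped
    else { restored = true; true }    // first call: does the work
}

object TransientDemo extends App {
  val g = new Guarded
  assert(g.restoreOnce())   // first call runs the restore
  assert(!g.restoreOnce())  // second call is a no-op

  // Serialize and deserialize, as checkpoint recovery does.
  val buf = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(buf)
  oos.writeObject(g)
  oos.close()
  val g2 = new ObjectInputStream(new ByteArrayInputStream(buf.toByteArray))
    .readObject().asInstanceOf[Guarded]
  assert(g2.restoreOnce())  // transient flag reset to false: restore runs again
}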
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index cd28d3cf40..f5f446f14a 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming
-import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream}
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.reflect.ClassTag
@@ -34,9 +34,30 @@ import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
-import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.scheduler._
-import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
+import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, Utils}
+
+/**
+ * An input stream that records how many times restore() is invoked
+ */
+private[streaming]
+class CheckpointInputDStream(ssc_ : StreamingContext) extends InputDStream[Int](ssc_) {
+ protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData
+ override def start(): Unit = { }
+ override def stop(): Unit = { }
+ override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1)))
+ private[streaming]
+ class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) {
+ @transient
+ var restoredTimes = 0
+ override def restore() {
+ restoredTimes += 1
+ super.restore()
+ }
+ }
+}
/**
* A trait that can be mixed in to get methods for testing DStream operations under
@@ -110,7 +131,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
}
- private def generateOutput[V: ClassTag](
+ protected def generateOutput[V: ClassTag](
ssc: StreamingContext,
targetBatchTime: Time,
checkpointDir: String,
@@ -715,6 +736,33 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {
}
}
+ test("DStreamCheckpointData.restore invoking times") {
+ withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+ ssc.checkpoint(checkpointDir)
+ val inputDStream = new CheckpointInputDStream(ssc)
+ val checkpointData = inputDStream.checkpointData
+ val mappedDStream = inputDStream.map(_ + 100)
+ val outputStream = new TestOutputStreamWithPartitions(mappedDStream)
+ outputStream.register()
+ // register two more output actions
+ mappedDStream.foreachRDD(rdd => rdd.count())
+ mappedDStream.foreachRDD(rdd => rdd.count())
+ assert(checkpointData.restoredTimes === 0)
+ val batchDurationMillis = ssc.progressListener.batchDuration
+ generateOutput(ssc, Time(batchDurationMillis * 3), checkpointDir, stopSparkContext = true)
+ assert(checkpointData.restoredTimes === 0)
+ }
+ logInfo("*********** RESTARTING ************")
+ withStreamingContext(new StreamingContext(checkpointDir)) { ssc =>
+ val checkpointData =
+ ssc.graph.getInputStreams().head.asInstanceOf[CheckpointInputDStream].checkpointData
+ assert(checkpointData.restoredTimes === 1)
+ ssc.start()
+ ssc.stop()
+ assert(checkpointData.restoredTimes === 1)
+ }
+ }
+
// This tests whether Spark can deserialize array objects
// refer to SPARK-5569
test("recovery from checkpoint contains array object") {