From f4346f612b6798517153a786f9172cf41618d34d Mon Sep 17 00:00:00 2001 From: jhu-chang Date: Thu, 17 Dec 2015 17:53:15 -0800 Subject: [SPARK-11749][STREAMING] Duplicate creating the RDD in file stream when recovering from checkpoint data Add a transient flag `DStream.restoredFromCheckpointData` to control the restore processing in DStream to avoid duplicate works: check this flag first in `DStream.restoreCheckpointData`, only when `false`, the restore process will be executed. Author: jhu-chang Closes #9765 from jhu-chang/SPARK-11749. --- .../apache/spark/streaming/dstream/DStream.scala | 15 ++++-- .../apache/spark/streaming/CheckpointSuite.scala | 56 ++++++++++++++++++++-- 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 1a6edf9473..91a43e14a8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -97,6 +97,8 @@ abstract class DStream[T: ClassTag] ( private[streaming] val mustCheckpoint = false private[streaming] var checkpointDuration: Duration = null private[streaming] val checkpointData = new DStreamCheckpointData(this) + @transient + private var restoredFromCheckpointData = false // Reference to whole DStream graph private[streaming] var graph: DStreamGraph = null @@ -507,11 +509,14 @@ abstract class DStream[T: ClassTag] ( * override the updateCheckpointData() method would also need to override this method. */ private[streaming] def restoreCheckpointData() { - // Create RDDs from the checkpoint data - logInfo("Restoring checkpoint data") - checkpointData.restore() - dependencies.foreach(_.restoreCheckpointData()) - logInfo("Restored checkpoint data") + if (!restoredFromCheckpointData) { + // Create RDDs from the checkpoint data + logInfo("Restoring checkpoint data") + checkpointData.restore() + dependencies.foreach(_.restoreCheckpointData()) + restoredFromCheckpointData = true + logInfo("Restored checkpoint data") + } } @throws(classOf[IOException]) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index cd28d3cf40..f5f446f14a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag @@ -34,9 +34,30 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} -import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} +import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, Utils} + +/** + * A input stream that records the times of restore() invoked + */ +private[streaming] +class CheckpointInputDStream(ssc_ : StreamingContext) extends InputDStream[Int](ssc_) { + protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData + override def start(): Unit = { } + override def stop(): Unit = { } + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1))) + private[streaming] + class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { + @transient + var restoredTimes = 0 + override def restore() { + restoredTimes += 1 + super.restore() + } + } +} /** * A trait of that can be mixed in to get methods for testing DStream operations under @@ -110,7 +131,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) } - private def generateOutput[V: ClassTag]( + protected def generateOutput[V: ClassTag]( ssc: StreamingContext, targetBatchTime: Time, checkpointDir: String, @@ -715,6 +736,33 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { } } + test("DStreamCheckpointData.restore invoking times") { + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.checkpoint(checkpointDir) + val inputDStream = new CheckpointInputDStream(ssc) + val checkpointData = inputDStream.checkpointData + val mappedDStream = inputDStream.map(_ + 100) + val outputStream = new TestOutputStreamWithPartitions(mappedDStream) + outputStream.register() + // do two more times output + mappedDStream.foreachRDD(rdd => rdd.count()) + mappedDStream.foreachRDD(rdd => rdd.count()) + assert(checkpointData.restoredTimes === 0) + val batchDurationMillis = ssc.progressListener.batchDuration + generateOutput(ssc, Time(batchDurationMillis * 3), checkpointDir, stopSparkContext = true) + assert(checkpointData.restoredTimes === 0) + } + logInfo("*********** RESTARTING ************") + withStreamingContext(new StreamingContext(checkpointDir)) { ssc => + val checkpointData = + ssc.graph.getInputStreams().head.asInstanceOf[CheckpointInputDStream].checkpointData + assert(checkpointData.restoredTimes === 1) + ssc.start() + ssc.stop() + assert(checkpointData.restoredTimes === 1) + } + } + // This tests whether spark can deserialize array object // refer to SPARK-5569 test("recovery from checkpoint contains array object") { -- cgit v1.2.3