diff options
18 files changed, 579 insertions, 254 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7b59a6f09e..63048d5df0 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -119,22 +119,23 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Returns the first parent RDD */ - private[spark] def firstParent[U: ClassManifest] = { + protected[spark] def firstParent[U: ClassManifest] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } /** Returns the `i` th parent RDD */ - private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] // Variables relating to checkpointing - val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing - var isCheckpointInProgress = false // set to true when checkpointing is in progress - var isCheckpointed = false // set to true after checkpointing is completed + protected val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed - var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file - var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD + protected var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing + protected var isCheckpointInProgress = false // set to true when checkpointing is in progress + protected[spark] var isCheckpointed = false // set to true after checkpointing is completed + + protected[spark] var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed + protected var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file + protected var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD // Methods available on all RDDs: @@ -176,6 +177,9 @@ abstract class RDD[T: ClassManifest]( if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { // do nothing } else if (isCheckpointable) { + if (sc.checkpointDir == null) { + throw new Exception("Checkpoint directory has not been set in the SparkContext.") + } shouldCheckpoint = true } else { throw new Exception(this + " cannot be checkpointed") @@ -183,6 +187,16 @@ abstract class RDD[T: ClassManifest]( } } + def getCheckpointData(): Any = { + synchronized { + if (isCheckpointed) { + checkpointFile + } else { + null + } + } + } + /** * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job * using this RDD has completed (therefore the RDD has been materialized and diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 79ceab5f4f..d7b46bee38 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -474,8 +474,10 @@ class SparkContext( /** Shut down the SparkContext. */ def stop() { - dagScheduler.stop() - dagScheduler = null + if (dagScheduler != null) { + dagScheduler.stop() + dagScheduler = null + } taskScheduler = null // TODO: Cache.stop()? env.stop() @@ -584,14 +586,15 @@ class SparkContext( * overwriting existing files may be overwritten). The directory will be deleted on exit * if indicated. */ - def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) { + def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) val fs = path.getFileSystem(new Configuration()) - if (fs.exists(path)) { - throw new Exception("Checkpoint directory '" + path + "' already exists.") - } else { - fs.mkdirs(path) - if (deleteOnExit) fs.deleteOnExit(path) + if (!useExisting) { + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + } } checkpointDir = dir } diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 83a43d15cb..1643f45ffb 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -1,6 +1,6 @@ package spark.streaming -import spark.Utils +import spark.{Logging, Utils} import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration @@ -8,13 +8,14 @@ import org.apache.hadoop.conf.Configuration import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable { +class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) + extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.jobName val sparkHome = ssc.sc.sparkHome val jars = ssc.sc.jars val graph = ssc.graph - val checkpointFile = ssc.checkpointFile + val checkpointDir = ssc.checkpointDir val checkpointInterval = ssc.checkpointInterval validate() @@ -24,22 +25,25 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext assert(framework != null, "Checkpoint.framework is null") assert(graph != null, "Checkpoint.graph is null") assert(checkpointTime != null, "Checkpoint.checkpointTime is null") + logInfo("Checkpoint for time " + checkpointTime + " validated") } - def saveToFile(file: String = checkpointFile) { - val path = new Path(file) + def save(path: String) { + val file = new Path(path, "graph") val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (fs.exists(path)) { - val bkPath = new Path(path.getParent, path.getName + ".bk") - FileUtil.copy(fs, path, fs, bkPath, true, true, conf) - //logInfo("Moved existing checkpoint file to " + bkPath) + val fs = file.getFileSystem(conf) + logDebug("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") + if (fs.exists(file)) { + val bkFile = new Path(file.getParent, file.getName + ".bk") + FileUtil.copy(fs, file, fs, bkFile, true, true, conf) + logDebug("Moved existing checkpoint file to " + bkFile) } - val fos = fs.create(path) + val fos = fs.create(file) val oos = new ObjectOutputStream(fos) oos.writeObject(this) oos.close() fs.close() + logInfo("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") } def toBytes(): Array[Byte] = { @@ -48,33 +52,41 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext } } -object Checkpoint { +object Checkpoint extends Logging { - def loadFromFile(file: String): Checkpoint = { - try { - val path = new Path(file) - val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (!fs.exists(path)) { - throw new Exception("Checkpoint file '" + file + "' does not exist") + def load(path: String): Checkpoint = { + + val fs = new Path(path).getFileSystem(new Configuration()) + val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) + var detailedLog: String = "" + + attempts.foreach(file => { + if (fs.exists(file)) { + logInfo("Attempting to load checkpoint from file '" + file + "'") + try { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + logInfo("Checkpoint successfully loaded from file '" + file + "'") + return cp + } catch { + case e: Exception => + logError("Error loading checkpoint from file '" + file + "'", e) + } + } else { + logWarning("Could not load checkpoint from file '" + file + "' as it does not exist") } - val fis = fs.open(path) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp.validate() - cp - } catch { - case e: Exception => - e.printStackTrace() - throw new Exception("Could not load checkpoint file '" + file + "'", e) - } + + }) + throw new Exception("Could not load checkpoint from path '" + path + "'") } def fromBytes(bytes: Array[Byte]): Checkpoint = { diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index a4921bb1a2..922ff5088d 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -1,6 +1,7 @@ package spark.streaming -import spark.streaming.StreamingContext._ +import StreamingContext._ +import Time._ import spark._ import spark.SparkContext._ @@ -12,7 +13,9 @@ import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} -import scala.Some + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.conf.Configuration abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -41,53 +44,56 @@ extends Serializable with Logging { */ // RDDs generated, marked as protected[streaming] so that testsuites can access it - protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] () + @transient + protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream - protected var zeroTime: Time = null + protected[streaming] var zeroTime: Time = null // Duration for which the DStream will remember each RDD created - protected var rememberDuration: Time = null + protected[streaming] var rememberDuration: Time = null // Storage level of the RDDs in the stream - protected var storageLevel: StorageLevel = StorageLevel.NONE + protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE - // Checkpoint level and checkpoint interval - protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint - protected var checkpointInterval: Time = null + // Checkpoint details + protected[streaming] val mustCheckpoint = false + protected[streaming] var checkpointInterval: Time = null + protected[streaming] val checkpointData = new HashMap[Time, Any]() // Reference to whole DStream graph - protected var graph: DStreamGraph = null + protected[streaming] var graph: DStreamGraph = null def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created def parentRememberDuration = rememberDuration - // Change this RDD's storage level - def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): DStream[T] = { - if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { - // TODO: not sure this is necessary for DStreams + // Set caching level for the RDDs created by this DStream + def persist(level: StorageLevel): DStream[T] = { + if (this.isInitialized) { throw new UnsupportedOperationException( - "Cannot change storage level of an DStream after it was already assigned a level") + "Cannot change storage level of an DStream after streaming context has started") } - this.storageLevel = storageLevel - this.checkpointLevel = checkpointLevel - this.checkpointInterval = checkpointInterval + this.storageLevel = level this } - // Set caching level for the RDDs created by this DStream - def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null) - def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY) // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() + def checkpoint(interval: Time): DStream[T] = { + if (isInitialized) { + throw new UnsupportedOperationException( + "Cannot change checkpoint interval of an DStream after streaming context has started") + } + persist() + checkpointInterval = interval + this + } + /** * This method initializes the DStream by setting the "zero" time, based on which * the validity of future times is calculated. This method also recursively initializes @@ -99,7 +105,67 @@ extends Serializable with Logging { + ", cannot initialize it again to " + time) } zeroTime = time + + // Set the checkpoint interval to be slideTime or 10 seconds, which ever is larger + if (mustCheckpoint && checkpointInterval == null) { + checkpointInterval = slideTime.max(Seconds(10)) + logInfo("Checkpoint interval automatically set to " + checkpointInterval) + } + + // Set the minimum value of the rememberDuration if not already set + var minRememberDuration = slideTime + if (checkpointInterval != null && minRememberDuration <= checkpointInterval) { + minRememberDuration = checkpointInterval * 2 // times 2 just to be sure that the latest checkpoint is not forgetten + } + if (rememberDuration == null || rememberDuration < minRememberDuration) { + rememberDuration = minRememberDuration + } + + // Initialize the dependencies dependencies.foreach(_.initialize(zeroTime)) + } + + protected[streaming] def validate() { + assert( + !mustCheckpoint || checkpointInterval != null, + "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + + " Please use DStream.checkpoint() to set the interval." + ) + + assert( + checkpointInterval == null || checkpointInterval >= slideTime, + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointInterval + " which is lower than its slide time (" + slideTime + "). " + + "Please set it to at least " + slideTime + "." + ) + + assert( + checkpointInterval == null || checkpointInterval.isMultipleOf(slideTime), + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointInterval + " which not a multiple of its slide time (" + slideTime + "). " + + "Please set it to a multiple " + slideTime + "." + ) + + assert( + checkpointInterval == null || storageLevel != StorageLevel.NONE, + "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + + "level has not been set to enable persisting. Please use DStream.persist() to set the " + + "storage level to use memory for better checkpointing performance." + ) + + assert( + checkpointInterval == null || rememberDuration > checkpointInterval, + "The remember duration for " + this.getClass.getSimpleName + " has been set to " + + rememberDuration + " which is not more than the checkpoint interval (" + + checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." + ) + + dependencies.foreach(_.validate()) + + logInfo("Slide time = " + slideTime) + logInfo("Storage level = " + storageLevel) + logInfo("Checkpoint interval = " + checkpointInterval) + logInfo("Remember duration = " + rememberDuration) logInfo("Initialized " + this) } @@ -120,17 +186,12 @@ extends Serializable with Logging { dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def setRememberDuration(duration: Time = slideTime) { - if (duration == null) { - throw new Exception("Duration for remembering RDDs cannot be set to null for " + this) - } else if (rememberDuration != null && duration < rememberDuration) { - logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration - + " to " + duration + " for " + this) - } else { + protected[streaming] def setRememberDuration(duration: Time) { + if (duration != null && duration > rememberDuration) { rememberDuration = duration - dependencies.foreach(_.setRememberDuration(parentRememberDuration)) logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) } + dependencies.foreach(_.setRememberDuration(parentRememberDuration)) } /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ @@ -145,10 +206,10 @@ extends Serializable with Logging { } /** - * This method either retrieves a precomputed RDD of this DStream, - * or computes the RDD (if the time is valid) + * Retrieves a precomputed RDD of this DStream, or computes the RDD. This is an internal + * method that should not be called directly. */ - def getOrCompute(time: Time): Option[RDD[T]] = { + protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = { // If this DStream was not initialized (i.e., zeroTime not set), then do it // If RDD was already generated, then retrieve it from HashMap generatedRDDs.get(time) match { @@ -163,12 +224,13 @@ extends Serializable with Logging { if (isTimeValid(time)) { compute(time) match { case Some(newRDD) => - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { + if (storageLevel != StorageLevel.NONE) { newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time) + } + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.checkpoint() + logInfo("Marking RDD for time " + time + " for checkpointing at time " + time) } generatedRDDs.put(time, newRDD) Some(newRDD) @@ -183,10 +245,12 @@ extends Serializable with Logging { } /** - * This method generates a SparkStreaming job for the given time - * and may required to be overriden by subclasses + * Generates a SparkStreaming job for the given time. This is an internal method that + * should not be called directly. This default implementation creates a job + * that materializes the corresponding RDD. Subclasses of DStream may override this + * (eg. PerRDDForEachDStream). */ - def generateJob(time: Time): Option[Job] = { + protected[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { case Some(rdd) => { val jobFunc = () => { @@ -199,20 +263,75 @@ extends Serializable with Logging { } } - def forgetOldRDDs(time: Time) { + /** + * Dereferences RDDs that are older than rememberDuration. + */ + protected[streaming] def forgetOldRDDs(time: Time) { val keys = generatedRDDs.keys var numForgotten = 0 keys.foreach(t => { if (t <= (time - rememberDuration)) { generatedRDDs.remove(t) numForgotten += 1 - //logInfo("Forgot RDD of time " + t + " from " + this) + logInfo("Forgot RDD of time " + t + " from " + this) } }) logInfo("Forgot " + numForgotten + " RDDs from " + this) dependencies.foreach(_.forgetOldRDDs(time)) } + /** + * Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of + * this stream. This is an internal method that should not be called directly. This is + * a default implementation that saves only the file names of the checkpointed RDDs to + * checkpointData. Subclasses of DStream (especially those of InputDStream) may override + * this method to save custom checkpoint data. + */ + protected[streaming] def updateCheckpointData(currentTime: Time) { + val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null) + .map(x => (x._1, x._2.getCheckpointData())) + val oldCheckpointData = checkpointData.clone() + if (newCheckpointData.size > 0) { + checkpointData.clear() + checkpointData ++= newCheckpointData + } + + dependencies.foreach(_.updateCheckpointData(currentTime)) + + newCheckpointData.foreach { + case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") } + } + + if (newCheckpointData.size > 0) { + (oldCheckpointData -- newCheckpointData.keySet).foreach { + case (time, data) => { + val path = new Path(data.toString) + val fs = path.getFileSystem(new Configuration()) + fs.delete(path, true) + logInfo("Deleted checkpoint file '" + path + "' for time " + time) + } + } + } + logInfo("Updated checkpoint data") + } + + /** + * Restores the RDDs in generatedRDDs from the checkpointData. This is an internal method + * that should not be called directly. This is a default implementation that recreates RDDs + * from the checkpoint file names stored in checkpointData. Subclasses of DStream that + * override the updateCheckpointData() method would also need to override this method. + */ + protected[streaming] def restoreCheckpointData() { + logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs") + checkpointData.foreach { + case(time, data) => { + logInfo("Restoring checkpointed RDD for time " + time + " from file") + generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) + } + } + dependencies.foreach(_.restoreCheckpointData()) + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { logDebug(this.getClass().getSimpleName + ".writeObject used") @@ -239,6 +358,7 @@ extends Serializable with Logging { private def readObject(ois: ObjectInputStream) { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() + generatedRDDs = new HashMap[Time, RDD[T]] () } /** diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index ac44d7a2a6..246522838a 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -4,7 +4,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import collection.mutable.ArrayBuffer import spark.Logging -final class DStreamGraph extends Serializable with Logging { +final private[streaming] class DStreamGraph extends Serializable with Logging { initLogging() private val inputStreams = new ArrayBuffer[InputDStream[_]]() @@ -15,23 +15,20 @@ final class DStreamGraph extends Serializable with Logging { private[streaming] var rememberDuration: Time = null private[streaming] var checkpointInProgress = false - def start(time: Time) { + private[streaming] def start(time: Time) { this.synchronized { if (zeroTime != null) { throw new Exception("DStream graph computation already started") } zeroTime = time outputStreams.foreach(_.initialize(zeroTime)) - outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values - if (rememberDuration != null) { - // if custom rememberDuration has been provided, set the rememberDuration - outputStreams.foreach(_.setRememberDuration(rememberDuration)) - } + outputStreams.foreach(_.setRememberDuration(rememberDuration)) + outputStreams.foreach(_.validate) inputStreams.par.foreach(_.start()) } } - def stop() { + private[streaming] def stop() { this.synchronized { inputStreams.par.foreach(_.stop()) } @@ -43,7 +40,7 @@ final class DStreamGraph extends Serializable with Logging { } } - def setBatchDuration(duration: Time) { + private[streaming] def setBatchDuration(duration: Time) { this.synchronized { if (batchDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + @@ -53,7 +50,7 @@ final class DStreamGraph extends Serializable with Logging { batchDuration = duration } - def setRememberDuration(duration: Time) { + private[streaming] def setRememberDuration(duration: Time) { this.synchronized { if (rememberDuration != null) { throw new Exception("Batch duration already set as " + batchDuration + @@ -63,37 +60,49 @@ final class DStreamGraph extends Serializable with Logging { rememberDuration = duration } - def addInputStream(inputStream: InputDStream[_]) { + private[streaming] def addInputStream(inputStream: InputDStream[_]) { this.synchronized { inputStream.setGraph(this) inputStreams += inputStream } } - def addOutputStream(outputStream: DStream[_]) { + private[streaming] def addOutputStream(outputStream: DStream[_]) { this.synchronized { outputStream.setGraph(this) outputStreams += outputStream } } - def getInputStreams() = inputStreams.toArray + private[streaming] def getInputStreams() = this.synchronized { inputStreams.toArray } - def getOutputStreams() = outputStreams.toArray + private[streaming] def getOutputStreams() = this.synchronized { outputStreams.toArray } - def generateRDDs(time: Time): Seq[Job] = { + private[streaming] def generateRDDs(time: Time): Seq[Job] = { this.synchronized { outputStreams.flatMap(outputStream => outputStream.generateJob(time)) } } - def forgetOldRDDs(time: Time) { + private[streaming] def forgetOldRDDs(time: Time) { this.synchronized { outputStreams.foreach(_.forgetOldRDDs(time)) } } - def validate() { + private[streaming] def updateCheckpointData(time: Time) { + this.synchronized { + outputStreams.foreach(_.updateCheckpointData(time)) + } + } + + private[streaming] def restoreCheckpointData() { + this.synchronized { + outputStreams.foreach(_.restoreCheckpointData()) + } + } + + private[streaming] def validate() { this.synchronized { assert(batchDuration != null, "Batch duration has not been set") assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low") diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 1c57d5f855..6df82c0df3 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -21,15 +21,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( partitioner: Partitioner ) extends DStream[(K,V)](parent.ssc) { - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + assert(_windowTime.isMultipleOf(parent.slideTime), + "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + assert(_slideTime.isMultipleOf(parent.slideTime), + "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) - @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + super.persist(StorageLevel.MEMORY_ONLY) + + val reducedStream = parent.reduceByKey(reduceFunc, partitioner) def windowTime: Time = _windowTime @@ -37,15 +41,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( override def slideTime: Time = _slideTime - //TODO: This is wrong. This should depend on the checkpointInterval + override val mustCheckpoint = true + override def parentRememberDuration: Time = rememberDuration + windowTime - override def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): DStream[(K,V)] = { - super.persist(storageLevel, checkpointLevel, checkpointInterval) - reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval) + override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + super.persist(storageLevel) + reducedStream.persist(storageLevel) + this + } + + override def checkpoint(interval: Time): DStream[(K, V)] = { + super.checkpoint(interval) + reducedStream.checkpoint(interval) this } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index 7d52e2eddf..2b3f5a4829 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -58,7 +58,6 @@ extends Logging { graph.forgetOldRDDs(time) if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { ssc.doCheckpoint(time) - logInfo("Checkpointed at time " + time) } } diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index 086752ac55..0211df1343 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -23,51 +23,14 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { + super.persist(StorageLevel.MEMORY_ONLY) + override def dependencies = List(parent) override def slideTime = parent.slideTime - override def getOrCompute(time: Time): Option[RDD[(K, S)]] = { - generatedRDDs.get(time) match { - case Some(oldRDD) => { - if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { - val r = oldRDD - val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index) - val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) { - override val partitioner = oldRDD.partitioner - } - generatedRDDs.update(time, checkpointedRDD) - logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id) - Some(checkpointedRDD) - } else { - Some(oldRDD) - } - } - case None => { - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => { - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) - } - generatedRDDs.put(time, newRDD) - Some(newRDD) - } - case None => { - None - } - } - } else { - None - } - } - } - } - + override val mustCheckpoint = true + override def compute(validTime: Time): Option[RDD[(K, S)]] = { // Try to get the previous state RDD diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 4a78090597..05c83d6c08 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -15,6 +15,8 @@ import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.fs.Path +import java.util.UUID class StreamingContext ( sc_ : SparkContext, @@ -26,7 +28,7 @@ class StreamingContext ( def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = this(new SparkContext(master, frameworkName, sparkHome, jars), null) - def this(file: String) = this(null, Checkpoint.loadFromFile(file)) + def this(path: String) = this(null, Checkpoint.load(path)) def this(cp_ : Checkpoint) = this(null, cp_) @@ -51,8 +53,8 @@ class StreamingContext ( val graph: DStreamGraph = { if (isCheckpointPresent) { - cp_.graph.setContext(this) + cp_.graph.restoreCheckpointData() cp_.graph } else { new DStreamGraph() @@ -62,7 +64,15 @@ class StreamingContext ( val nextNetworkInputStreamId = new AtomicInteger(0) var networkInputTracker: NetworkInputTracker = null - private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + private[streaming] var checkpointDir: String = { + if (isCheckpointPresent) { + sc.setCheckpointDir(cp_.checkpointDir, true) + cp_.checkpointDir + } else { + null + } + } + private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null private[streaming] var receiverJobThread: Thread = null private[streaming] var scheduler: Scheduler = null @@ -75,9 +85,15 @@ class StreamingContext ( graph.setRememberDuration(duration) } - def setCheckpointDetails(file: String, interval: Time) { - checkpointFile = file - checkpointInterval = interval + def checkpoint(dir: String, interval: Time) { + if (dir != null) { + sc.setCheckpointDir(new Path(dir, "rdds-" + UUID.randomUUID.toString).toString) + checkpointDir = dir + checkpointInterval = interval + } else { + checkpointDir = null + checkpointInterval = null + } } private[streaming] def getInitialCheckpoint(): Checkpoint = { @@ -181,16 +197,12 @@ class StreamingContext ( graph.addOutputStream(outputStream) } - def validate() { - assert(graph != null, "Graph is null") - graph.validate() - } - /** * This function starts the execution of the streams. */ def start() { - validate() + assert(graph != null, "Graph is null") + graph.validate() val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true @@ -218,16 +230,16 @@ class StreamingContext ( if (scheduler != null) scheduler.stop() if (networkInputTracker != null) networkInputTracker.stop() if (receiverJobThread != null) receiverJobThread.interrupt() - sc.stop() + sc.stop() + logInfo("StreamingContext stopped successfully") } catch { case e: Exception => logWarning("Error while stopping", e) } - - logInfo("StreamingContext stopped") } def doCheckpoint(currentTime: Time) { - new Checkpoint(this, currentTime).saveToFile(checkpointFile) + graph.updateCheckpointData(currentTime) + new Checkpoint(this, currentTime).save(checkpointDir) } } diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 9ddb65249a..480d292d7c 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -25,6 +25,10 @@ case class Time(millis: Long) { def isMultipleOf(that: Time): Boolean = (this.millis % that.millis == 0) + def min(that: Time): Time = if (this < that) this else that + + def max(that: Time): Time = if (this > that) this else that + def isZero: Boolean = (this.millis == 0) override def toString: String = (millis.toString + " ms") @@ -39,7 +43,7 @@ object Time { implicit def toTime(long: Long) = Time(long) - implicit def toLong(time: Time) = time.milliseconds + implicit def toLong(time: Time) = time.milliseconds } object Milliseconds { diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala index df96a811da..21a83c0fde 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -10,20 +10,20 @@ object FileStreamWithCheckpoint { def main(args: Array[String]) { if (args.size != 3) { - println("FileStreamWithCheckpoint <master> <directory> <checkpoint file>") - println("FileStreamWithCheckpoint restart <directory> <checkpoint file>") + println("FileStreamWithCheckpoint <master> <directory> <checkpoint dir>") + println("FileStreamWithCheckpoint restart <directory> <checkpoint dir>") System.exit(-1) } val directory = new Path(args(1)) - val checkpointFile = args(2) + val checkpointDir = args(2) val ssc: StreamingContext = { if (args(0) == "restart") { // Recreated streaming context from specified checkpoint file - new StreamingContext(checkpointFile) + new StreamingContext(checkpointDir) } else { @@ -34,7 +34,7 @@ object FileStreamWithCheckpoint { // Create new streaming context val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") ssc_.setBatchDuration(Seconds(1)) - ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) + ssc_.checkpoint(checkpointDir, Seconds(1)) // Setup the streaming computation val inputStream = ssc_.textFileStream(directory.toString) diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 57fd10f0a5..750cb7445f 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -41,9 +41,8 @@ object TopKWordCountRaw { val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMs)) - //windowedCounts.print() // TODO: something else? + windowedCounts.persist().checkpoint(Milliseconds(chkptMs)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { val taken = new Array[(String, Long)](k) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 0d2e62b955..865026033e 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -100,10 +100,9 @@ object WordCount2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, - StorageLevel.MEMORY_ONLY_2, - //new StorageLevel(false, true, true, 3), - Milliseconds(chkptMillis.toLong)) + + windowedCounts.persist().checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index abfd12890f..d1ea9a9cd5 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -41,9 +41,9 @@ object WordCountRaw { val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMs)) - //windowedCounts.print() // TODO: something else? + windowedCounts.persist().checkpoint(chkptMs) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala index 9d44da2b11..6a9c8a9a69 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -57,11 +57,13 @@ object WordMax2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKey(add _, reduceTasks.toInt) - .persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMillis.toLong)) + .persist() + .checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) .reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt) - //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - // Milliseconds(chkptMillis.toLong)) + .persist() + .checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 6dcedcf463..9fdfd50be2 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -2,52 +2,195 @@ package spark.streaming import spark.streaming.StreamingContext._ import java.io.File +import runtime.RichInt +import org.scalatest.BeforeAndAfter +import org.apache.commons.io.FileUtils +import collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import util.{Clock, ManualClock} -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { + + before { + FileUtils.deleteDirectory(new File(checkpointDir)) + } + + after { + FileUtils.deleteDirectory(new File(checkpointDir)) + } + + override def framework = "CheckpointSuite" + + override def batchDuration = Milliseconds(500) + + override def checkpointDir = "checkpoint" + + override def checkpointInterval = batchDuration + + override def actuallyWait = true + + test("basic stream+rdd recovery") { + + assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") + assert(checkpointInterval === batchDuration, "checkpointInterval for this test much be same as batchDuration") + + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + + val stateStreamCheckpointInterval = Seconds(2) + + // this ensure checkpointing occurs at least once + val firstNumBatches = (stateStreamCheckpointInterval.millis / batchDuration.millis) * 2 + val secondNumBatches = firstNumBatches + + // Setup the streams + val input = (1 to 10).map(_ => Seq("a")).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(stateStreamCheckpointInterval) + .map(t => (t._1, t._2.self)) + } + val ssc = setupStreams(input, operation) + val stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head + + // Run till a time such that at least one RDD in the stream should have been checkpointed + ssc.start() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + advanceClock(clock, firstNumBatches) + + // Check whether some RDD has been checkpointed or not + logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]") + assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before first failure") + stateStream.checkpointData.foreach { + case (time, data) => { + val file = new File(data.toString) + assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") + } + } + + // Run till a further time such that previous checkpoint files in the stream would be deleted + // and check whether the earlier checkpoint files are deleted + val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString)) + advanceClock(clock, secondNumBatches) + checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted")) + + // Restart stream computation using the checkpoint file and check whether + // checkpointed RDDs have been restored or not + ssc.stop() + val sscNew = new StreamingContext(checkpointDir) + val stateStreamNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head + logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]") + assert(!stateStreamNew.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from first failure") + + + // Run one batch to generate a new checkpoint file + sscNew.start() + val clockNew = sscNew.scheduler.clock.asInstanceOf[ManualClock] + advanceClock(clockNew, 1) + + // Check whether some RDD is present in the checkpoint data or not + assert(!stateStreamNew.checkpointData.isEmpty, "No checkpointed RDDs in state stream before second failure") + stateStream.checkpointData.foreach { + case (time, data) => { + val file = new File(data.toString) + assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") + } + } + + // Restart stream computation from the new checkpoint file to see whether that file has + // correct checkpoint data + sscNew.stop() + val sscNewNew = new StreamingContext(checkpointDir) + val stateStreamNewNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head + logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]") + assert(!stateStreamNewNew.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from second failure") + sscNewNew.start() + advanceClock(sscNewNew.scheduler.clock.asInstanceOf[ManualClock], 1) + sscNewNew.stop() + } + + test("map and reduceByKey") { + testCheckpointedOperation( + Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), + (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), + Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), + 3 + ) + } + + test("reduceByKeyAndWindowInv") { + val n = 10 + val w = 4 + val input = (1 to n).map(x => Seq("a")).toSeq + val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val operation = (st: DStream[String]) => { + st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, batchDuration * 4, batchDuration) + } + for (i <- Seq(2, 3, 4)) { + testCheckpointedOperation(input, operation, output, i) + } + } + + test("updateStateByKey") { + val input = (1 to 10).map(_ => Seq("a")).toSeq + val output = (1 to 10).map(x => Seq(("a", x))).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(Seconds(2)) + .map(t => (t._1, t._2.self)) + } + for (i <- Seq(2, 3, 4)) { + testCheckpointedOperation(input, operation, output, i) + } + } - override def framework() = "CheckpointSuite" - override def checkpointFile() = "checkpoint" def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( - input: Seq[Seq[U]], - operation: DStream[U] => DStream[V], - expectedOutput: Seq[Seq[V]], - useSet: Boolean = false - ) { + input: Seq[Seq[U]], + operation: DStream[U] => DStream[V], + expectedOutput: Seq[Seq[V]], + initialNumBatches: Int + ) { // Current code assumes that: // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val initialNumBatches = input.size / 2 val nextNumBatches = totalNumBatches - initialNumBatches val initialNumExpectedOutputs = initialNumBatches + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs // Do half the computation (half the number of batches), create checkpoint file and quit + val ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) - verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet) + verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) Thread.sleep(1000) // Restart and complete the computation from checkpoint file - val sscNew = new StreamingContext(checkpointFile) - sscNew.setCheckpointDetails(null, null) - val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size) - verifyOutput[V](outputNew, expectedOutput, useSet) - - new File(checkpointFile).delete() - new File(checkpointFile + ".bk").delete() - new File("." + checkpointFile + ".crc").delete() - new File("." + checkpointFile + ".bk.crc").delete() + logInfo( + "\n-------------------------------------------\n" + + " Restarting stream computation " + + "\n-------------------------------------------\n" + ) + val sscNew = new StreamingContext(checkpointDir) + val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) + verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) } - test("simple per-batch operation") { - testCheckpointedOperation( - Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), - Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), - true - ) + def advanceClock(clock: ManualClock, numBatches: Long) { + logInfo("Manual clock before advancing = " + clock.time) + for (i <- 1 to numBatches.toInt) { + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + logInfo("Manual clock after advancing = " + clock.time) + Thread.sleep(batchDuration.milliseconds) } }
\ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index c9bc454f91..b8c7f99603 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -5,10 +5,16 @@ import util.ManualClock import collection.mutable.ArrayBuffer import org.scalatest.FunSuite import collection.mutable.SynchronizedBuffer +import java.io.{ObjectInputStream, IOException} + +/** + * This is a input stream just for the testsuites. This is equivalent to a checkpointable, + * replayable, reliable message queue like Kafka. It requires a sequence as input, and + * returns the i_th element at the i_th batch unde manual clock. + */ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) extends InputDStream[T](ssc_) { - var currentIndex = 0 def start() {} @@ -23,34 +29,49 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ ssc.sc.makeRDD(Seq[T](), numPartitions) } logInfo("Created RDD " + rdd.id) - //currentIndex += 1 Some(rdd) } } +/** + * This is a output stream just for the testsuites. All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + */ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected - }) + }) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + output.clear() + } +} +/** + * This is the base trait for Spark Streaming testsuites. This provides basic functionality + * to run user-defined set of input on user-defined stream operations, and verify the output. + */ trait TestSuiteBase extends FunSuite with Logging { - System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + def framework = "TestSuiteBase" - def framework() = "TestSuiteBase" + def master = "local[2]" - def master() = "local[2]" + def batchDuration = Seconds(1) - def batchDuration() = Seconds(1) + def checkpointDir = null.asInstanceOf[String] - def checkpointFile() = null.asInstanceOf[String] + def checkpointInterval = batchDuration - def checkpointInterval() = batchDuration + def numInputPartitions = 2 - def numInputPartitions() = 2 + def maxWaitTimeMillis = 10000 - def maxWaitTimeMillis() = 10000 + def actuallyWait = false def setupStreams[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], @@ -60,8 +81,8 @@ trait TestSuiteBase extends FunSuite with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir, checkpointInterval) } // Setup the stream computation @@ -82,8 +103,8 @@ trait TestSuiteBase extends FunSuite with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir, checkpointInterval) } // Setup the stream computation @@ -97,12 +118,19 @@ trait TestSuiteBase extends FunSuite with Logging { ssc } + /** + * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and + * returns the collected output. It will wait until `numExpectedOutput` number of + * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. + */ def runStreams[V: ClassManifest]( ssc: StreamingContext, numBatches: Int, numExpectedOutput: Int ): Seq[Seq[V]] = { + System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") + assert(numBatches > 0, "Number of batches to run stream computation is zero") assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) @@ -118,7 +146,15 @@ trait TestSuiteBase extends FunSuite with Logging { // Advance manual clock val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] logInfo("Manual clock before advancing = " + clock.time) - clock.addToTime(numBatches * batchDuration.milliseconds) + if (actuallyWait) { + for (i <- 1 to numBatches) { + logInfo("Actually waiting for " + batchDuration) + clock.addToTime(batchDuration.milliseconds) + Thread.sleep(batchDuration.milliseconds) + } + } else { + clock.addToTime(numBatches * batchDuration.milliseconds) + } logInfo("Manual clock after advancing = " + clock.time) // Wait until expected number of output items have been generated diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index d7d8d5bd36..3e20e16708 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -5,11 +5,11 @@ import collection.mutable.ArrayBuffer class WindowOperationsSuite extends TestSuiteBase { - override def framework() = "WindowOperationsSuite" + override def framework = "WindowOperationsSuite" - override def maxWaitTimeMillis() = 20000 + override def maxWaitTimeMillis = 20000 - override def batchDuration() = Seconds(1) + override def batchDuration = Seconds(1) val largerSlideInput = Seq( Seq(("a", 1)), @@ -283,7 +283,9 @@ class WindowOperationsSuite extends TestSuiteBase { test("reduceByKeyAndWindowInv - " + name) { val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist() + s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime) + .persist() + .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing } testOperation(input, operation, expectedOutput, numBatches, true) } |