diff options
17 files changed, 367 insertions, 203 deletions
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 7b59a6f09e..63048d5df0 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -119,22 +119,23 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Returns the first parent RDD */ - private[spark] def firstParent[U: ClassManifest] = { + protected[spark] def firstParent[U: ClassManifest] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } /** Returns the `i` th parent RDD */ - private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] // Variables relating to checkpointing - val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing - var isCheckpointInProgress = false // set to true when checkpointing is in progress - var isCheckpointed = false // set to true after checkpointing is completed + protected val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD - var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed - var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file - var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD + protected var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing + protected var isCheckpointInProgress = false // set to true when checkpointing is in progress + protected[spark] var isCheckpointed = false // set to true after checkpointing is completed + + protected[spark] var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed + protected var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file + protected var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD // Methods available on all RDDs: @@ -176,6 +177,9 @@ abstract class RDD[T: ClassManifest]( if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) { // do nothing } else if (isCheckpointable) { + if (sc.checkpointDir == null) { + throw new Exception("Checkpoint directory has not been set in the SparkContext.") + } shouldCheckpoint = true } else { throw new Exception(this + " cannot be checkpointed") @@ -183,6 +187,16 @@ abstract class RDD[T: ClassManifest]( } } + def getCheckpointData(): Any = { + synchronized { + if (isCheckpointed) { + checkpointFile + } else { + null + } + } + } + /** * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job * using this RDD has completed (therefore the RDD has been materialized and diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 79ceab5f4f..d7326971a9 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -584,14 +584,15 @@ class SparkContext( * overwriting existing files may be overwritten). The directory will be deleted on exit * if indicated. */ - def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) { + def setCheckpointDir(dir: String, useExisting: Boolean = false) { val path = new Path(dir) val fs = path.getFileSystem(new Configuration()) - if (fs.exists(path)) { - throw new Exception("Checkpoint directory '" + path + "' already exists.") - } else { - fs.mkdirs(path) - if (deleteOnExit) fs.deleteOnExit(path) + if (!useExisting) { + if (fs.exists(path)) { + throw new Exception("Checkpoint directory '" + path + "' already exists.") + } else { + fs.mkdirs(path) + } } checkpointDir = dir } diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala index 83a43d15cb..cf04c7031e 100644 --- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala @@ -1,6 +1,6 @@ package spark.streaming -import spark.Utils +import spark.{Logging, Utils} import org.apache.hadoop.fs.{FileUtil, Path} import org.apache.hadoop.conf.Configuration @@ -8,13 +8,14 @@ import org.apache.hadoop.conf.Configuration import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream} -class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable { +class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) + extends Logging with Serializable { val master = ssc.sc.master val framework = ssc.sc.jobName val sparkHome = ssc.sc.sparkHome val jars = ssc.sc.jars val graph = ssc.graph - val checkpointFile = ssc.checkpointFile + val checkpointDir = ssc.checkpointDir val checkpointInterval = ssc.checkpointInterval validate() @@ -24,22 +25,25 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext assert(framework != null, "Checkpoint.framework is null") assert(graph != null, "Checkpoint.graph is null") assert(checkpointTime != null, "Checkpoint.checkpointTime is null") + logInfo("Checkpoint for time " + checkpointTime + " validated") } - def saveToFile(file: String = checkpointFile) { - val path = new Path(file) + def save(path: String) { + val file = new Path(path, "graph") val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (fs.exists(path)) { - val bkPath = new Path(path.getParent, path.getName + ".bk") - FileUtil.copy(fs, path, fs, bkPath, true, true, conf) - //logInfo("Moved existing checkpoint file to " + bkPath) + val fs = file.getFileSystem(conf) + logDebug("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") + if (fs.exists(file)) { + val bkFile = new Path(file.getParent, file.getName + ".bk") + FileUtil.copy(fs, file, fs, bkFile, true, true, conf) + logDebug("Moved existing checkpoint file to " + bkFile) } - val fos = fs.create(path) + val fos = fs.create(file) val oos = new ObjectOutputStream(fos) oos.writeObject(this) oos.close() fs.close() + logInfo("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'") } def toBytes(): Array[Byte] = { @@ -50,30 +54,41 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext object Checkpoint { - def loadFromFile(file: String): Checkpoint = { - try { - val path = new Path(file) - val conf = new Configuration() - val fs = path.getFileSystem(conf) - if (!fs.exists(path)) { - throw new Exception("Checkpoint file '" + file + "' does not exist") + def load(path: String): Checkpoint = { + + val fs = new Path(path).getFileSystem(new Configuration()) + val attempts = Seq(new Path(path), new Path(path, "graph"), new Path(path, "graph.bk")) + var lastException: Exception = null + var lastExceptionFile: String = null + + attempts.foreach(file => { + if (fs.exists(file)) { + try { + val fis = fs.open(file) + // ObjectInputStream uses the last defined user-defined class loader in the stack + // to find classes, which maybe the wrong class loader. Hence, a inherited version + // of ObjectInputStream is used to explicitly use the current thread's default class + // loader to find and load classes. This is a well know Java issue and has popped up + // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) + val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) + val cp = ois.readObject.asInstanceOf[Checkpoint] + ois.close() + fs.close() + cp.validate() + println("Checkpoint successfully loaded from file " + file) + return cp + } catch { + case e: Exception => + lastException = e + lastExceptionFile = file.toString + } } - val fis = fs.open(path) - // ObjectInputStream uses the last defined user-defined class loader in the stack - // to find classes, which maybe the wrong class loader. Hence, a inherited version - // of ObjectInputStream is used to explicitly use the current thread's default class - // loader to find and load classes. This is a well know Java issue and has popped up - // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627) - val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader) - val cp = ois.readObject.asInstanceOf[Checkpoint] - ois.close() - fs.close() - cp.validate() - cp - } catch { - case e: Exception => - e.printStackTrace() - throw new Exception("Could not load checkpoint file '" + file + "'", e) + }) + + if (lastException == null) { + throw new Exception("Could not load checkpoint from path '" + path + "'") + } else { + throw new Exception("Error loading checkpoint from path '" + lastExceptionFile + "'", lastException) } } diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala index a4921bb1a2..de51c5d34a 100644 --- a/streaming/src/main/scala/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/spark/streaming/DStream.scala @@ -13,6 +13,7 @@ import scala.collection.mutable.HashMap import java.util.concurrent.ArrayBlockingQueue import java.io.{ObjectInputStream, IOException, ObjectOutputStream} import scala.Some +import collection.mutable abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext) extends Serializable with Logging { @@ -41,53 +42,55 @@ extends Serializable with Logging { */ // RDDs generated, marked as protected[streaming] so that testsuites can access it - protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] () + protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () // Time zero for the DStream - protected var zeroTime: Time = null + protected[streaming] var zeroTime: Time = null // Duration for which the DStream will remember each RDD created - protected var rememberDuration: Time = null + protected[streaming] var rememberDuration: Time = null // Storage level of the RDDs in the stream - protected var storageLevel: StorageLevel = StorageLevel.NONE + protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE - // Checkpoint level and checkpoint interval - protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint - protected var checkpointInterval: Time = null + // Checkpoint details + protected[streaming] val mustCheckpoint = false + protected[streaming] var checkpointInterval: Time = null + protected[streaming] val checkpointData = new HashMap[Time, Any]() // Reference to whole DStream graph - protected var graph: DStreamGraph = null + protected[streaming] var graph: DStreamGraph = null def isInitialized = (zeroTime != null) // Duration for which the DStream requires its parent DStream to remember each RDD created def parentRememberDuration = rememberDuration - // Change this RDD's storage level - def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): DStream[T] = { - if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) { - // TODO: not sure this is necessary for DStreams + // Set caching level for the RDDs created by this DStream + def persist(level: StorageLevel): DStream[T] = { + if (this.isInitialized) { throw new UnsupportedOperationException( - "Cannot change storage level of an DStream after it was already assigned a level") + "Cannot change storage level of an DStream after streaming context has started") } - this.storageLevel = storageLevel - this.checkpointLevel = checkpointLevel - this.checkpointInterval = checkpointInterval + this.storageLevel = level this } - // Set caching level for the RDDs created by this DStream - def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null) - def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY) // Turn on the default caching level for this RDD def cache(): DStream[T] = persist() + def checkpoint(interval: Time): DStream[T] = { + if (isInitialized) { + throw new UnsupportedOperationException( + "Cannot change checkpoint interval of an DStream after streaming context has started") + } + persist() + checkpointInterval = interval + this + } + /** * This method initializes the DStream by setting the "zero" time, based on which * the validity of future times is calculated. This method also recursively initializes @@ -99,7 +102,67 @@ extends Serializable with Logging { + ", cannot initialize it again to " + time) } zeroTime = time + + // Set the checkpoint interval to be slideTime or 10 seconds, which ever is larger + if (mustCheckpoint && checkpointInterval == null) { + checkpointInterval = slideTime.max(Seconds(10)) + logInfo("Checkpoint interval automatically set to " + checkpointInterval) + } + + // Set the minimum value of the rememberDuration if not already set + var minRememberDuration = slideTime + if (checkpointInterval != null && minRememberDuration <= checkpointInterval) { + minRememberDuration = checkpointInterval + slideTime + } + if (rememberDuration == null || rememberDuration < minRememberDuration) { + rememberDuration = minRememberDuration + } + + // Initialize the dependencies dependencies.foreach(_.initialize(zeroTime)) + } + + protected[streaming] def validate() { + assert( + !mustCheckpoint || checkpointInterval != null, + "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " + + " Please use DStream.checkpoint() to set the interval." + ) + + assert( + checkpointInterval == null || checkpointInterval >= slideTime, + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointInterval + " which is lower than its slide time (" + slideTime + "). " + + "Please set it to at least " + slideTime + "." + ) + + assert( + checkpointInterval == null || checkpointInterval.isMultipleOf(slideTime), + "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + + checkpointInterval + " which not a multiple of its slide time (" + slideTime + "). " + + "Please set it to a multiple " + slideTime + "." + ) + + assert( + checkpointInterval == null || storageLevel != StorageLevel.NONE, + "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + + "level has not been set to enable persisting. Please use DStream.persist() to set the " + + "storage level to use memory for better checkpointing performance." + ) + + assert( + checkpointInterval == null || rememberDuration > checkpointInterval, + "The remember duration for " + this.getClass.getSimpleName + " has been set to " + + rememberDuration + " which is not more than the checkpoint interval (" + + checkpointInterval + "). Please set it to higher than " + checkpointInterval + "." + ) + + dependencies.foreach(_.validate()) + + logInfo("Slide time = " + slideTime) + logInfo("Storage level = " + storageLevel) + logInfo("Checkpoint interval = " + checkpointInterval) + logInfo("Remember duration = " + rememberDuration) logInfo("Initialized " + this) } @@ -120,17 +183,12 @@ extends Serializable with Logging { dependencies.foreach(_.setGraph(graph)) } - protected[streaming] def setRememberDuration(duration: Time = slideTime) { - if (duration == null) { - throw new Exception("Duration for remembering RDDs cannot be set to null for " + this) - } else if (rememberDuration != null && duration < rememberDuration) { - logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration - + " to " + duration + " for " + this) - } else { + protected[streaming] def setRememberDuration(duration: Time) { + if (duration != null && duration > rememberDuration) { rememberDuration = duration - dependencies.foreach(_.setRememberDuration(parentRememberDuration)) logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) } + dependencies.foreach(_.setRememberDuration(parentRememberDuration)) } /** This method checks whether the 'time' is valid wrt slideTime for generating RDD */ @@ -163,12 +221,13 @@ extends Serializable with Logging { if (isTimeValid(time)) { compute(time) match { case Some(newRDD) => - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { + if (storageLevel != StorageLevel.NONE) { newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) + logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time) + } + if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { + newRDD.checkpoint() + logInfo("Marking RDD for time " + time + " for checkpointing at time " + time) } generatedRDDs.put(time, newRDD) Some(newRDD) @@ -199,7 +258,7 @@ extends Serializable with Logging { } } - def forgetOldRDDs(time: Time) { + protected[streaming] def forgetOldRDDs(time: Time) { val keys = generatedRDDs.keys var numForgotten = 0 keys.foreach(t => { @@ -213,12 +272,35 @@ extends Serializable with Logging { dependencies.foreach(_.forgetOldRDDs(time)) } + protected[streaming] def updateCheckpointData() { + checkpointData.clear() + generatedRDDs.foreach { + case(time, rdd) => { + logDebug("Adding checkpointed RDD for time " + time) + val data = rdd.getCheckpointData() + if (data != null) { + checkpointData += ((time, data)) + } + } + } + } + + protected[streaming] def restoreCheckpointData() { + checkpointData.foreach { + case(time, data) => { + logInfo("Restoring checkpointed RDD for time " + time) + generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString))) + } + } + } + @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream) { logDebug(this.getClass().getSimpleName + ".writeObject used") if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { + updateCheckpointData() oos.defaultWriteObject() } else { val msg = "Object of " + this.getClass.getName + " is being serialized " + @@ -239,6 +321,8 @@ extends Serializable with Logging { private def readObject(ois: ObjectInputStream) { logDebug(this.getClass().getSimpleName + ".readObject used") ois.defaultReadObject() + generatedRDDs = new HashMap[Time, RDD[T]] () + restoreCheckpointData() } /** diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala index ac44d7a2a6..f8922ec790 100644 --- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala @@ -22,11 +22,8 @@ final class DStreamGraph extends Serializable with Logging { } zeroTime = time outputStreams.foreach(_.initialize(zeroTime)) - outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values - if (rememberDuration != null) { - // if custom rememberDuration has been provided, set the rememberDuration - outputStreams.foreach(_.setRememberDuration(rememberDuration)) - } + outputStreams.foreach(_.setRememberDuration(rememberDuration)) + outputStreams.foreach(_.validate) inputStreams.par.foreach(_.start()) } } diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala index 1c57d5f855..6df82c0df3 100644 --- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala @@ -21,15 +21,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( partitioner: Partitioner ) extends DStream[(K,V)](parent.ssc) { - if (!_windowTime.isMultipleOf(parent.slideTime)) - throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + assert(_windowTime.isMultipleOf(parent.slideTime), + "The window duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) - if (!_slideTime.isMultipleOf(parent.slideTime)) - throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + - "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")") + assert(_slideTime.isMultipleOf(parent.slideTime), + "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " + + "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")" + ) - @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner) + super.persist(StorageLevel.MEMORY_ONLY) + + val reducedStream = parent.reduceByKey(reduceFunc, partitioner) def windowTime: Time = _windowTime @@ -37,15 +41,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest]( override def slideTime: Time = _slideTime - //TODO: This is wrong. This should depend on the checkpointInterval + override val mustCheckpoint = true + override def parentRememberDuration: Time = rememberDuration + windowTime - override def persist( - storageLevel: StorageLevel, - checkpointLevel: StorageLevel, - checkpointInterval: Time): DStream[(K,V)] = { - super.persist(storageLevel, checkpointLevel, checkpointInterval) - reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval) + override def persist(storageLevel: StorageLevel): DStream[(K,V)] = { + super.persist(storageLevel) + reducedStream.persist(storageLevel) + this + } + + override def checkpoint(interval: Time): DStream[(K, V)] = { + super.checkpoint(interval) + reducedStream.checkpoint(interval) this } diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala index 086752ac55..0211df1343 100644 --- a/streaming/src/main/scala/spark/streaming/StateDStream.scala +++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala @@ -23,51 +23,14 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife rememberPartitioner: Boolean ) extends DStream[(K, S)](parent.ssc) { + super.persist(StorageLevel.MEMORY_ONLY) + override def dependencies = List(parent) override def slideTime = parent.slideTime - override def getOrCompute(time: Time): Option[RDD[(K, S)]] = { - generatedRDDs.get(time) match { - case Some(oldRDD) => { - if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) { - val r = oldRDD - val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index) - val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) { - override val partitioner = oldRDD.partitioner - } - generatedRDDs.update(time, checkpointedRDD) - logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id) - Some(checkpointedRDD) - } else { - Some(oldRDD) - } - } - case None => { - if (isTimeValid(time)) { - compute(time) match { - case Some(newRDD) => { - if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) { - newRDD.persist(checkpointLevel) - logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time) - } else if (storageLevel != StorageLevel.NONE) { - newRDD.persist(storageLevel) - logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time) - } - generatedRDDs.put(time, newRDD) - Some(newRDD) - } - case None => { - None - } - } - } else { - None - } - } - } - } - + override val mustCheckpoint = true + override def compute(validTime: Time): Option[RDD[(K, S)]] = { // Try to get the previous state RDD diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index b3148eaa97..3838e84113 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -15,6 +15,8 @@ import org.apache.hadoop.io.LongWritable import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat +import org.apache.hadoop.fs.Path +import java.util.UUID class StreamingContext ( sc_ : SparkContext, @@ -26,7 +28,7 @@ class StreamingContext ( def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) = this(new SparkContext(master, frameworkName, sparkHome, jars), null) - def this(file: String) = this(null, Checkpoint.loadFromFile(file)) + def this(path: String) = this(null, Checkpoint.load(path)) def this(cp_ : Checkpoint) = this(null, cp_) @@ -51,7 +53,6 @@ class StreamingContext ( val graph: DStreamGraph = { if (isCheckpointPresent) { - cp_.graph.setContext(this) cp_.graph } else { @@ -62,7 +63,15 @@ class StreamingContext ( val nextNetworkInputStreamId = new AtomicInteger(0) var networkInputTracker: NetworkInputTracker = null - private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null + private[streaming] var checkpointDir: String = { + if (isCheckpointPresent) { + sc.setCheckpointDir(cp_.checkpointDir, true) + cp_.checkpointDir + } else { + null + } + } + private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null private[streaming] var receiverJobThread: Thread = null private[streaming] var scheduler: Scheduler = null @@ -75,9 +84,15 @@ class StreamingContext ( graph.setRememberDuration(duration) } - def setCheckpointDetails(file: String, interval: Time) { - checkpointFile = file - checkpointInterval = interval + def checkpoint(dir: String, interval: Time) { + if (dir != null) { + sc.setCheckpointDir(new Path(dir, "rdds-" + UUID.randomUUID.toString).toString) + checkpointDir = dir + checkpointInterval = interval + } else { + checkpointDir = null + checkpointInterval = null + } } private[streaming] def getInitialCheckpoint(): Checkpoint = { @@ -170,16 +185,12 @@ class StreamingContext ( graph.addOutputStream(outputStream) } - def validate() { - assert(graph != null, "Graph is null") - graph.validate() - } - /** * This function starts the execution of the streams. */ def start() { - validate() + assert(graph != null, "Graph is null") + graph.validate() val networkInputStreams = graph.getInputStreams().filter(s => s match { case n: NetworkInputDStream[_] => true @@ -216,7 +227,8 @@ class StreamingContext ( } def doCheckpoint(currentTime: Time) { - new Checkpoint(this, currentTime).saveToFile(checkpointFile) + new Checkpoint(this, currentTime).save(checkpointDir) + } } diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala index 9ddb65249a..2ba6502971 100644 --- a/streaming/src/main/scala/spark/streaming/Time.scala +++ b/streaming/src/main/scala/spark/streaming/Time.scala @@ -25,6 +25,10 @@ case class Time(millis: Long) { def isMultipleOf(that: Time): Boolean = (this.millis % that.millis == 0) + def min(that: Time): Time = if (this < that) this else that + + def max(that: Time): Time = if (this > that) this else that + def isZero: Boolean = (this.millis == 0) override def toString: String = (millis.toString + " ms") diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala index df96a811da..21a83c0fde 100644 --- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala +++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala @@ -10,20 +10,20 @@ object FileStreamWithCheckpoint { def main(args: Array[String]) { if (args.size != 3) { - println("FileStreamWithCheckpoint <master> <directory> <checkpoint file>") - println("FileStreamWithCheckpoint restart <directory> <checkpoint file>") + println("FileStreamWithCheckpoint <master> <directory> <checkpoint dir>") + println("FileStreamWithCheckpoint restart <directory> <checkpoint dir>") System.exit(-1) } val directory = new Path(args(1)) - val checkpointFile = args(2) + val checkpointDir = args(2) val ssc: StreamingContext = { if (args(0) == "restart") { // Recreated streaming context from specified checkpoint file - new StreamingContext(checkpointFile) + new StreamingContext(checkpointDir) } else { @@ -34,7 +34,7 @@ object FileStreamWithCheckpoint { // Create new streaming context val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint") ssc_.setBatchDuration(Seconds(1)) - ssc_.setCheckpointDetails(checkpointFile, Seconds(1)) + ssc_.checkpoint(checkpointDir, Seconds(1)) // Setup the streaming computation val inputStream = ssc_.textFileStream(directory.toString) diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala index 57fd10f0a5..750cb7445f 100644 --- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala @@ -41,9 +41,8 @@ object TopKWordCountRaw { val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMs)) - //windowedCounts.print() // TODO: something else? + windowedCounts.persist().checkpoint(Milliseconds(chkptMs)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { val taken = new Array[(String, Long)](k) diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala index 0d2e62b955..865026033e 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala @@ -100,10 +100,9 @@ object WordCount2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, - StorageLevel.MEMORY_ONLY_2, - //new StorageLevel(false, true, true, 3), - Milliseconds(chkptMillis.toLong)) + + windowedCounts.persist().checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala index abfd12890f..d1ea9a9cd5 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala @@ -41,9 +41,9 @@ object WordCountRaw { val windowedCounts = union.mapPartitions(splitAndCountPartitions) .reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces) - windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMs)) - //windowedCounts.print() // TODO: something else? + windowedCounts.persist().checkpoint(chkptMs) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs)) + windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala index 9d44da2b11..6a9c8a9a69 100644 --- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala +++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala @@ -57,11 +57,13 @@ object WordMax2 { val windowedCounts = sentences .mapPartitions(splitAndCountPartitions) .reduceByKey(add _, reduceTasks.toInt) - .persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - Milliseconds(chkptMillis.toLong)) + .persist() + .checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) .reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt) - //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, - // Milliseconds(chkptMillis.toLong)) + .persist() + .checkpoint(Milliseconds(chkptMillis.toLong)) + //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong)) windowedCounts.foreachRDD(r => println("Element count: " + r.count())) ssc.start() diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala index 6dcedcf463..dfe31b5771 100644 --- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala @@ -2,52 +2,95 @@ package spark.streaming import spark.streaming.StreamingContext._ import java.io.File +import collection.mutable.ArrayBuffer +import runtime.RichInt +import org.scalatest.BeforeAndAfter +import org.apache.hadoop.fs.Path +import org.apache.commons.io.FileUtils -class CheckpointSuite extends TestSuiteBase { +class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { + + before { + FileUtils.deleteDirectory(new File(checkpointDir)) + } + + after { + FileUtils.deleteDirectory(new File(checkpointDir)) + } override def framework() = "CheckpointSuite" - override def checkpointFile() = "checkpoint" + override def batchDuration() = Seconds(1) + + override def checkpointDir() = "checkpoint" + + override def checkpointInterval() = batchDuration def testCheckpointedOperation[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], operation: DStream[U] => DStream[V], expectedOutput: Seq[Seq[V]], - useSet: Boolean = false + initialNumBatches: Int ) { // Current code assumes that: // number of inputs = number of outputs = number of batches to be run - val totalNumBatches = input.size - val initialNumBatches = input.size / 2 val nextNumBatches = totalNumBatches - initialNumBatches val initialNumExpectedOutputs = initialNumBatches + val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs // Do half the computation (half the number of batches), create checkpoint file and quit val ssc = setupStreams[U, V](input, operation) val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs) - verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet) + verifyOutput[V](output, expectedOutput.take(initialNumBatches), true) Thread.sleep(1000) // Restart and complete the computation from checkpoint file - val sscNew = new StreamingContext(checkpointFile) - sscNew.setCheckpointDetails(null, null) - val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size) - verifyOutput[V](outputNew, expectedOutput, useSet) - - new File(checkpointFile).delete() - new File(checkpointFile + ".bk").delete() - new File("." + checkpointFile + ".crc").delete() - new File("." + checkpointFile + ".bk.crc").delete() + val sscNew = new StreamingContext(checkpointDir) + //sscNew.checkpoint(null, null) + val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs) + verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true) } - test("simple per-batch operation") { + + test("map and reduceByKey") { testCheckpointedOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ), (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _), Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ), - true + 3 ) } + + test("reduceByKeyAndWindowInv") { + val n = 10 + val w = 4 + val input = (1 to n).map(x => Seq("a")).toSeq + val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4))) + val operation = (st: DStream[String]) => { + st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, Seconds(w), Seconds(1)) + } + for (i <- Seq(3, 5, 7)) { + testCheckpointedOperation(input, operation, output, i) + } + } + + test("updateStateByKey") { + val input = (1 to 10).map(_ => Seq("a")).toSeq + val output = (1 to 10).map(x => Seq(("a", x))).toSeq + val operation = (st: DStream[String]) => { + val updateFunc = (values: Seq[Int], state: Option[RichInt]) => { + Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0))) + } + st.map(x => (x, 1)) + .updateStateByKey[RichInt](updateFunc) + .checkpoint(Seconds(5)) + .map(t => (t._1, t._2.self)) + } + for (i <- Seq(3, 5, 7)) { + testCheckpointedOperation(input, operation, output, i) + } + } + }
\ No newline at end of file diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index c9bc454f91..e441feea19 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -5,10 +5,16 @@ import util.ManualClock import collection.mutable.ArrayBuffer import org.scalatest.FunSuite import collection.mutable.SynchronizedBuffer +import java.io.{ObjectInputStream, IOException} + +/** + * This is a input stream just for the testsuites. This is equivalent to a checkpointable, + * replayable, reliable message queue like Kafka. It requires a sequence as input, and + * returns the i_th element at the i_th batch unde manual clock. + */ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int) extends InputDStream[T](ssc_) { - var currentIndex = 0 def start() {} @@ -23,17 +29,32 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ ssc.sc.makeRDD(Seq[T](), numPartitions) } logInfo("Created RDD " + rdd.id) - //currentIndex += 1 Some(rdd) } } +/** + * This is a output stream just for the testsuites. All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + */ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected - }) + }) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + output.clear() + } +} +/** + * This is the base trait for Spark Streaming testsuites. This provides basic functionality + * to run user-defined set of input on user-defined stream operations, and verify the output. + */ trait TestSuiteBase extends FunSuite with Logging { System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock") @@ -44,7 +65,7 @@ trait TestSuiteBase extends FunSuite with Logging { def batchDuration() = Seconds(1) - def checkpointFile() = null.asInstanceOf[String] + def checkpointDir() = null.asInstanceOf[String] def checkpointInterval() = batchDuration @@ -60,8 +81,8 @@ trait TestSuiteBase extends FunSuite with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir, checkpointInterval()) } // Setup the stream computation @@ -82,8 +103,8 @@ trait TestSuiteBase extends FunSuite with Logging { // Create StreamingContext val ssc = new StreamingContext(master, framework) ssc.setBatchDuration(batchDuration) - if (checkpointFile != null) { - ssc.setCheckpointDetails(checkpointFile, checkpointInterval()) + if (checkpointDir != null) { + ssc.checkpoint(checkpointDir, checkpointInterval()) } // Setup the stream computation diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala index d7d8d5bd36..e282f0fdd5 100644 --- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala @@ -283,7 +283,9 @@ class WindowOperationsSuite extends TestSuiteBase { test("reduceByKeyAndWindowInv - " + name) { val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt val operation = (s: DStream[(String, Int)]) => { - s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist() + s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime) + .persist() + .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing } testOperation(input, operation, expectedOutput, numBatches, true) } |