author     Denny <dennybritz@gmail.com>  2012-11-06 09:41:45 -0800
committer  Denny <dennybritz@gmail.com>  2012-11-06 09:41:45 -0800
commit     485803d740307e03beee056390b0ecb0a76fbbb1 (patch)
tree       bf353711817d3b9e338a1f8e5044044af952ec36 /streaming/src
parent     0c1de43fc7a9fea8629907d5b331e466f18be418 (diff)
parent     395167f2b2a1906cde23b1f3ddc2808514bce47b (diff)
Merge branch 'dev' of github.com:radlab/spark into kafka
Diffstat (limited to 'streaming/src')
-rw-r--r--  streaming/src/main/scala/spark/streaming/Checkpoint.scala                          84
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala                            210
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStreamGraph.scala                        43
-rw-r--r--  streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala              36
-rw-r--r--  streaming/src/main/scala/spark/streaming/Scheduler.scala                            1
-rw-r--r--  streaming/src/main/scala/spark/streaming/StateDStream.scala                        45
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala                    44
-rw-r--r--  streaming/src/main/scala/spark/streaming/Time.scala                                 6
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala   10
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala            5
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCount2.scala                  7
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala                6
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordMax2.scala                   10
-rw-r--r--  streaming/src/test/scala/spark/streaming/CheckpointSuite.scala                    197
-rw-r--r--  streaming/src/test/scala/spark/streaming/TestSuiteBase.scala                       68
-rw-r--r--  streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala               10
16 files changed, 545 insertions, 237 deletions
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index 83a43d15cb..1643f45ffb 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -1,6 +1,6 @@
package spark.streaming
-import spark.Utils
+import spark.{Logging, Utils}
import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
@@ -8,13 +8,14 @@ import org.apache.hadoop.conf.Configuration
import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream}
-class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable {
+class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
+ extends Logging with Serializable {
val master = ssc.sc.master
val framework = ssc.sc.jobName
val sparkHome = ssc.sc.sparkHome
val jars = ssc.sc.jars
val graph = ssc.graph
- val checkpointFile = ssc.checkpointFile
+ val checkpointDir = ssc.checkpointDir
val checkpointInterval = ssc.checkpointInterval
validate()
@@ -24,22 +25,25 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext
assert(framework != null, "Checkpoint.framework is null")
assert(graph != null, "Checkpoint.graph is null")
assert(checkpointTime != null, "Checkpoint.checkpointTime is null")
+ logInfo("Checkpoint for time " + checkpointTime + " validated")
}
- def saveToFile(file: String = checkpointFile) {
- val path = new Path(file)
+ def save(path: String) {
+ val file = new Path(path, "graph")
val conf = new Configuration()
- val fs = path.getFileSystem(conf)
- if (fs.exists(path)) {
- val bkPath = new Path(path.getParent, path.getName + ".bk")
- FileUtil.copy(fs, path, fs, bkPath, true, true, conf)
- //logInfo("Moved existing checkpoint file to " + bkPath)
+ val fs = file.getFileSystem(conf)
+ logDebug("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'")
+ if (fs.exists(file)) {
+ val bkFile = new Path(file.getParent, file.getName + ".bk")
+ FileUtil.copy(fs, file, fs, bkFile, true, true, conf)
+ logDebug("Moved existing checkpoint file to " + bkFile)
}
- val fos = fs.create(path)
+ val fos = fs.create(file)
val oos = new ObjectOutputStream(fos)
oos.writeObject(this)
oos.close()
fs.close()
+ logInfo("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'")
}
def toBytes(): Array[Byte] = {
@@ -48,33 +52,41 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext
}
}
-object Checkpoint {
+object Checkpoint extends Logging {
- def loadFromFile(file: String): Checkpoint = {
- try {
- val path = new Path(file)
- val conf = new Configuration()
- val fs = path.getFileSystem(conf)
- if (!fs.exists(path)) {
- throw new Exception("Checkpoint file '" + file + "' does not exist")
+ def load(path: String): Checkpoint = {
+
+ val fs = new Path(path).getFileSystem(new Configuration())
+ val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk"))
+ var detailedLog: String = ""
+
+ attempts.foreach(file => {
+ if (fs.exists(file)) {
+ logInfo("Attempting to load checkpoint from file '" + file + "'")
+ try {
+ val fis = fs.open(file)
+ // ObjectInputStream uses the last defined user-defined class loader in the stack
+ // to find classes, which may be the wrong class loader. Hence, an inherited version
+ // of ObjectInputStream is used to explicitly use the current thread's default class
+ // loader to find and load classes. This is a well-known Java issue and has popped up
+ // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+ val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
+ val cp = ois.readObject.asInstanceOf[Checkpoint]
+ ois.close()
+ fs.close()
+ cp.validate()
+ logInfo("Checkpoint successfully loaded from file '" + file + "'")
+ return cp
+ } catch {
+ case e: Exception =>
+ logError("Error loading checkpoint from file '" + file + "'", e)
+ }
+ } else {
+ logWarning("Could not load checkpoint from file '" + file + "' as it does not exist")
}
- val fis = fs.open(path)
- // ObjectInputStream uses the last defined user-defined class loader in the stack
- // to find classes, which maybe the wrong class loader. Hence, a inherited version
- // of ObjectInputStream is used to explicitly use the current thread's default class
- // loader to find and load classes. This is a well know Java issue and has popped up
- // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
- val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
- val cp = ois.readObject.asInstanceOf[Checkpoint]
- ois.close()
- fs.close()
- cp.validate()
- cp
- } catch {
- case e: Exception =>
- e.printStackTrace()
- throw new Exception("Could not load checkpoint file '" + file + "'", e)
- }
+
+ })
+ throw new Exception("Could not load checkpoint from path '" + path + "'")
}
def fromBytes(bytes: Array[Byte]): Checkpoint = {
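
For reference: the comment in Checkpoint.load above refers to ObjectInputStreamWithLoader, which is defined elsewhere in this file. Below is a minimal sketch of what such a loader-aware ObjectInputStream typically looks like; the class actually shipped with this commit may differ in detail.

import java.io.{InputStream, ObjectInputStream, ObjectStreamClass}

// Sketch only: resolve classes through the supplied (context) class loader first,
// falling back to ObjectInputStream's default resolution.
class ObjectInputStreamWithLoader(in: InputStream, loader: ClassLoader)
  extends ObjectInputStream(in) {

  override def resolveClass(desc: ObjectStreamClass): Class[_] = {
    try {
      return Class.forName(desc.getName(), false, loader)
    } catch {
      case e: Exception =>   // fall through to the default lookup below
    }
    super.resolveClass(desc)
  }
}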
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index a4921bb1a2..922ff5088d 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -1,6 +1,7 @@
package spark.streaming
-import spark.streaming.StreamingContext._
+import StreamingContext._
+import Time._
import spark._
import spark.SparkContext._
@@ -12,7 +13,9 @@ import scala.collection.mutable.HashMap
import java.util.concurrent.ArrayBlockingQueue
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
-import scala.Some
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext)
extends Serializable with Logging {
@@ -41,53 +44,56 @@ extends Serializable with Logging {
*/
// RDDs generated, marked as protected[streaming] so that testsuites can access it
- protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] ()
+ @transient
+ protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
// Time zero for the DStream
- protected var zeroTime: Time = null
+ protected[streaming] var zeroTime: Time = null
// Duration for which the DStream will remember each RDD created
- protected var rememberDuration: Time = null
+ protected[streaming] var rememberDuration: Time = null
// Storage level of the RDDs in the stream
- protected var storageLevel: StorageLevel = StorageLevel.NONE
+ protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE
- // Checkpoint level and checkpoint interval
- protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint
- protected var checkpointInterval: Time = null
+ // Checkpoint details
+ protected[streaming] val mustCheckpoint = false
+ protected[streaming] var checkpointInterval: Time = null
+ protected[streaming] val checkpointData = new HashMap[Time, Any]()
// Reference to whole DStream graph
- protected var graph: DStreamGraph = null
+ protected[streaming] var graph: DStreamGraph = null
def isInitialized = (zeroTime != null)
// Duration for which the DStream requires its parent DStream to remember each RDD created
def parentRememberDuration = rememberDuration
- // Change this RDD's storage level
- def persist(
- storageLevel: StorageLevel,
- checkpointLevel: StorageLevel,
- checkpointInterval: Time): DStream[T] = {
- if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) {
- // TODO: not sure this is necessary for DStreams
+ // Set caching level for the RDDs created by this DStream
+ def persist(level: StorageLevel): DStream[T] = {
+ if (this.isInitialized) {
throw new UnsupportedOperationException(
- "Cannot change storage level of an DStream after it was already assigned a level")
+ "Cannot change storage level of an DStream after streaming context has started")
}
- this.storageLevel = storageLevel
- this.checkpointLevel = checkpointLevel
- this.checkpointInterval = checkpointInterval
+ this.storageLevel = level
this
}
- // Set caching level for the RDDs created by this DStream
- def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null)
-
def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY)
// Turn on the default caching level for this RDD
def cache(): DStream[T] = persist()
+ def checkpoint(interval: Time): DStream[T] = {
+ if (isInitialized) {
+ throw new UnsupportedOperationException(
+ "Cannot change checkpoint interval of an DStream after streaming context has started")
+ }
+ persist()
+ checkpointInterval = interval
+ this
+ }
+
/**
* This method initializes the DStream by setting the "zero" time, based on which
* the validity of future times is calculated. This method also recursively initializes
@@ -99,7 +105,67 @@ extends Serializable with Logging {
+ ", cannot initialize it again to " + time)
}
zeroTime = time
+
+ // Set the checkpoint interval to be slideTime or 10 seconds, whichever is larger
+ if (mustCheckpoint && checkpointInterval == null) {
+ checkpointInterval = slideTime.max(Seconds(10))
+ logInfo("Checkpoint interval automatically set to " + checkpointInterval)
+ }
+
+ // Set the minimum value of the rememberDuration if not already set
+ var minRememberDuration = slideTime
+ if (checkpointInterval != null && minRememberDuration <= checkpointInterval) {
+ minRememberDuration = checkpointInterval * 2 // times 2 just to be sure that the latest checkpoint is not forgotten
+ }
+ if (rememberDuration == null || rememberDuration < minRememberDuration) {
+ rememberDuration = minRememberDuration
+ }
+
+ // Initialize the dependencies
dependencies.foreach(_.initialize(zeroTime))
+ }
+
+ protected[streaming] def validate() {
+ assert(
+ !mustCheckpoint || checkpointInterval != null,
+ "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " +
+ " Please use DStream.checkpoint() to set the interval."
+ )
+
+ assert(
+ checkpointInterval == null || checkpointInterval >= slideTime,
+ "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " +
+ checkpointInterval + " which is lower than its slide time (" + slideTime + "). " +
+ "Please set it to at least " + slideTime + "."
+ )
+
+ assert(
+ checkpointInterval == null || checkpointInterval.isMultipleOf(slideTime),
+ "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " +
+ checkpointInterval + " which not a multiple of its slide time (" + slideTime + "). " +
+ "Please set it to a multiple " + slideTime + "."
+ )
+
+ assert(
+ checkpointInterval == null || storageLevel != StorageLevel.NONE,
+ "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " +
+ "level has not been set to enable persisting. Please use DStream.persist() to set the " +
+ "storage level to use memory for better checkpointing performance."
+ )
+
+ assert(
+ checkpointInterval == null || rememberDuration > checkpointInterval,
+ "The remember duration for " + this.getClass.getSimpleName + " has been set to " +
+ rememberDuration + " which is not more than the checkpoint interval (" +
+ checkpointInterval + "). Please set it to higher than " + checkpointInterval + "."
+ )
+
+ dependencies.foreach(_.validate())
+
+ logInfo("Slide time = " + slideTime)
+ logInfo("Storage level = " + storageLevel)
+ logInfo("Checkpoint interval = " + checkpointInterval)
+ logInfo("Remember duration = " + rememberDuration)
logInfo("Initialized " + this)
}
@@ -120,17 +186,12 @@ extends Serializable with Logging {
dependencies.foreach(_.setGraph(graph))
}
- protected[streaming] def setRememberDuration(duration: Time = slideTime) {
- if (duration == null) {
- throw new Exception("Duration for remembering RDDs cannot be set to null for " + this)
- } else if (rememberDuration != null && duration < rememberDuration) {
- logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration
- + " to " + duration + " for " + this)
- } else {
+ protected[streaming] def setRememberDuration(duration: Time) {
+ if (duration != null && duration > rememberDuration) {
rememberDuration = duration
- dependencies.foreach(_.setRememberDuration(parentRememberDuration))
logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this)
}
+ dependencies.foreach(_.setRememberDuration(parentRememberDuration))
}
/** This method checks whether the 'time' is valid wrt slideTime for generating RDD */
@@ -145,10 +206,10 @@ extends Serializable with Logging {
}
/**
- * This method either retrieves a precomputed RDD of this DStream,
- * or computes the RDD (if the time is valid)
+ * Retrieves a precomputed RDD of this DStream, or computes the RDD. This is an internal
+ * method that should not be called directly.
*/
- def getOrCompute(time: Time): Option[RDD[T]] = {
+ protected[streaming] def getOrCompute(time: Time): Option[RDD[T]] = {
// If this DStream was not initialized (i.e., zeroTime not set), then do it
// If RDD was already generated, then retrieve it from HashMap
generatedRDDs.get(time) match {
@@ -163,12 +224,13 @@ extends Serializable with Logging {
if (isTimeValid(time)) {
compute(time) match {
case Some(newRDD) =>
- if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
- newRDD.persist(checkpointLevel)
- logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time)
- } else if (storageLevel != StorageLevel.NONE) {
+ if (storageLevel != StorageLevel.NONE) {
newRDD.persist(storageLevel)
- logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time)
+ logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time)
+ }
+ if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
+ newRDD.checkpoint()
+ logInfo("Marking RDD for time " + time + " for checkpointing at time " + time)
}
generatedRDDs.put(time, newRDD)
Some(newRDD)
@@ -183,10 +245,12 @@ extends Serializable with Logging {
}
/**
- * This method generates a SparkStreaming job for the given time
- * and may required to be overriden by subclasses
+ * Generates a SparkStreaming job for the given time. This is an internal method that
+ * should not be called directly. This default implementation creates a job
+ * that materializes the corresponding RDD. Subclasses of DStream may override this
+ * (eg. PerRDDForEachDStream).
*/
- def generateJob(time: Time): Option[Job] = {
+ protected[streaming] def generateJob(time: Time): Option[Job] = {
getOrCompute(time) match {
case Some(rdd) => {
val jobFunc = () => {
@@ -199,20 +263,75 @@ extends Serializable with Logging {
}
}
- def forgetOldRDDs(time: Time) {
+ /**
+ * Dereferences RDDs that are older than rememberDuration.
+ */
+ protected[streaming] def forgetOldRDDs(time: Time) {
val keys = generatedRDDs.keys
var numForgotten = 0
keys.foreach(t => {
if (t <= (time - rememberDuration)) {
generatedRDDs.remove(t)
numForgotten += 1
- //logInfo("Forgot RDD of time " + t + " from " + this)
+ logInfo("Forgot RDD of time " + t + " from " + this)
}
})
logInfo("Forgot " + numForgotten + " RDDs from " + this)
dependencies.foreach(_.forgetOldRDDs(time))
}
+ /**
+ * Refreshes the list of checkpointed RDDs that will be saved along with checkpoint of
+ * this stream. This is an internal method that should not be called directly. This is
+ * a default implementation that saves only the file names of the checkpointed RDDs to
+ * checkpointData. Subclasses of DStream (especially those of InputDStream) may override
+ * this method to save custom checkpoint data.
+ */
+ protected[streaming] def updateCheckpointData(currentTime: Time) {
+ val newCheckpointData = generatedRDDs.filter(_._2.getCheckpointData() != null)
+ .map(x => (x._1, x._2.getCheckpointData()))
+ val oldCheckpointData = checkpointData.clone()
+ if (newCheckpointData.size > 0) {
+ checkpointData.clear()
+ checkpointData ++= newCheckpointData
+ }
+
+ dependencies.foreach(_.updateCheckpointData(currentTime))
+
+ newCheckpointData.foreach {
+ case (time, data) => { logInfo("Added checkpointed RDD for time " + time + " to stream checkpoint") }
+ }
+
+ if (newCheckpointData.size > 0) {
+ (oldCheckpointData -- newCheckpointData.keySet).foreach {
+ case (time, data) => {
+ val path = new Path(data.toString)
+ val fs = path.getFileSystem(new Configuration())
+ fs.delete(path, true)
+ logInfo("Deleted checkpoint file '" + path + "' for time " + time)
+ }
+ }
+ }
+ logInfo("Updated checkpoint data")
+ }
+
+ /**
+ * Restores the RDDs in generatedRDDs from the checkpointData. This is an internal method
+ * that should not be called directly. This is a default implementation that recreates RDDs
+ * from the checkpoint file names stored in checkpointData. Subclasses of DStream that
+ * override the updateCheckpointData() method would also need to override this method.
+ */
+ protected[streaming] def restoreCheckpointData() {
+ logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs")
+ checkpointData.foreach {
+ case(time, data) => {
+ logInfo("Restoring checkpointed RDD for time " + time + " from file")
+ generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString)))
+ }
+ }
+ dependencies.foreach(_.restoreCheckpointData())
+ }
+
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
logDebug(this.getClass().getSimpleName + ".writeObject used")
@@ -239,6 +358,7 @@ extends Serializable with Logging {
private def readObject(ois: ObjectInputStream) {
logDebug(this.getClass().getSimpleName + ".readObject used")
ois.defaultReadObject()
+ generatedRDDs = new HashMap[Time, RDD[T]] ()
}
/**
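
Taken together, the DStream changes above replace the old three-argument persist(storageLevel, checkpointLevel, checkpointInterval) with a plain persist(level) plus a separate checkpoint(interval). A hypothetical end-to-end sketch of the new API, based only on calls visible in this diff (master, paths and intervals are placeholders):

import spark.streaming._
import spark.streaming.StreamingContext._

// Hypothetical usage sketch; counts lines rather than words to keep it minimal.
object CheckpointedLineCount {
  def main(args: Array[String]) {
    val ssc = new StreamingContext("local[2]", "CheckpointedLineCount")
    ssc.setBatchDuration(Seconds(1))
    ssc.checkpoint("checkpoint-dir", Seconds(10))   // graph checkpoint directory + interval

    val lines = ssc.textFileStream("input-dir")
    val counts = lines.map(l => (l, 1)).reduceByKey(_ + _)
    counts.persist()                                // storage level only; no checkpoint level
    counts.checkpoint(Seconds(10))                  // per-DStream RDD checkpoint interval
    counts.foreachRDD(r => println("Element count: " + r.count()))

    ssc.start()
  }
}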
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index ac44d7a2a6..246522838a 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -4,7 +4,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import collection.mutable.ArrayBuffer
import spark.Logging
-final class DStreamGraph extends Serializable with Logging {
+final private[streaming] class DStreamGraph extends Serializable with Logging {
initLogging()
private val inputStreams = new ArrayBuffer[InputDStream[_]]()
@@ -15,23 +15,20 @@ final class DStreamGraph extends Serializable with Logging {
private[streaming] var rememberDuration: Time = null
private[streaming] var checkpointInProgress = false
- def start(time: Time) {
+ private[streaming] def start(time: Time) {
this.synchronized {
if (zeroTime != null) {
throw new Exception("DStream graph computation already started")
}
zeroTime = time
outputStreams.foreach(_.initialize(zeroTime))
- outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values
- if (rememberDuration != null) {
- // if custom rememberDuration has been provided, set the rememberDuration
- outputStreams.foreach(_.setRememberDuration(rememberDuration))
- }
+ outputStreams.foreach(_.setRememberDuration(rememberDuration))
+ outputStreams.foreach(_.validate)
inputStreams.par.foreach(_.start())
}
}
- def stop() {
+ private[streaming] def stop() {
this.synchronized {
inputStreams.par.foreach(_.stop())
}
@@ -43,7 +40,7 @@ final class DStreamGraph extends Serializable with Logging {
}
}
- def setBatchDuration(duration: Time) {
+ private[streaming] def setBatchDuration(duration: Time) {
this.synchronized {
if (batchDuration != null) {
throw new Exception("Batch duration already set as " + batchDuration +
@@ -53,7 +50,7 @@ final class DStreamGraph extends Serializable with Logging {
batchDuration = duration
}
- def setRememberDuration(duration: Time) {
+ private[streaming] def setRememberDuration(duration: Time) {
this.synchronized {
if (rememberDuration != null) {
throw new Exception("Batch duration already set as " + batchDuration +
@@ -63,37 +60,49 @@ final class DStreamGraph extends Serializable with Logging {
rememberDuration = duration
}
- def addInputStream(inputStream: InputDStream[_]) {
+ private[streaming] def addInputStream(inputStream: InputDStream[_]) {
this.synchronized {
inputStream.setGraph(this)
inputStreams += inputStream
}
}
- def addOutputStream(outputStream: DStream[_]) {
+ private[streaming] def addOutputStream(outputStream: DStream[_]) {
this.synchronized {
outputStream.setGraph(this)
outputStreams += outputStream
}
}
- def getInputStreams() = inputStreams.toArray
+ private[streaming] def getInputStreams() = this.synchronized { inputStreams.toArray }
- def getOutputStreams() = outputStreams.toArray
+ private[streaming] def getOutputStreams() = this.synchronized { outputStreams.toArray }
- def generateRDDs(time: Time): Seq[Job] = {
+ private[streaming] def generateRDDs(time: Time): Seq[Job] = {
this.synchronized {
outputStreams.flatMap(outputStream => outputStream.generateJob(time))
}
}
- def forgetOldRDDs(time: Time) {
+ private[streaming] def forgetOldRDDs(time: Time) {
this.synchronized {
outputStreams.foreach(_.forgetOldRDDs(time))
}
}
- def validate() {
+ private[streaming] def updateCheckpointData(time: Time) {
+ this.synchronized {
+ outputStreams.foreach(_.updateCheckpointData(time))
+ }
+ }
+
+ private[streaming] def restoreCheckpointData() {
+ this.synchronized {
+ outputStreams.foreach(_.restoreCheckpointData())
+ }
+ }
+
+ private[streaming] def validate() {
this.synchronized {
assert(batchDuration != null, "Batch duration has not been set")
assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low")
diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
index 1c57d5f855..6df82c0df3 100644
--- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
@@ -21,15 +21,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
partitioner: Partitioner
) extends DStream[(K,V)](parent.ssc) {
- if (!_windowTime.isMultipleOf(parent.slideTime))
- throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " +
- "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")")
+ assert(_windowTime.isMultipleOf(parent.slideTime),
+ "The window duration of ReducedWindowedDStream (" + _slideTime + ") " +
+ "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")"
+ )
- if (!_slideTime.isMultipleOf(parent.slideTime))
- throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " +
- "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")")
+ assert(_slideTime.isMultipleOf(parent.slideTime),
+ "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " +
+ "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")"
+ )
- @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
+ super.persist(StorageLevel.MEMORY_ONLY)
+
+ val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
def windowTime: Time = _windowTime
@@ -37,15 +41,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
override def slideTime: Time = _slideTime
- //TODO: This is wrong. This should depend on the checkpointInterval
+ override val mustCheckpoint = true
+
override def parentRememberDuration: Time = rememberDuration + windowTime
- override def persist(
- storageLevel: StorageLevel,
- checkpointLevel: StorageLevel,
- checkpointInterval: Time): DStream[(K,V)] = {
- super.persist(storageLevel, checkpointLevel, checkpointInterval)
- reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval)
+ override def persist(storageLevel: StorageLevel): DStream[(K,V)] = {
+ super.persist(storageLevel)
+ reducedStream.persist(storageLevel)
+ this
+ }
+
+ override def checkpoint(interval: Time): DStream[(K, V)] = {
+ super.checkpoint(interval)
+ reducedStream.checkpoint(interval)
this
}
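
Because ReducedWindowedDStream now sets mustCheckpoint = true, an invertible windowed reduce gets a checkpoint interval automatically (the larger of its slide time and 10 seconds, per DStream.initialize above) unless one is set explicitly. A hedged sketch of how this surfaces to user code; `pairs` is a placeholder for any DStream[(String, Int)]:

import spark.streaming._
import spark.streaming.StreamingContext._

// Hypothetical helper; the window, slide and checkpoint durations are illustrative only.
def windowedCounts(pairs: DStream[(String, Int)]): DStream[(String, Int)] = {
  val counts = pairs.reduceByKeyAndWindow(_ + _, _ - _, Seconds(30), Seconds(1))
  // Optional override of the automatically chosen interval; it must be a multiple
  // of the slide time (Seconds(1) here) or DStream.validate() fails at start().
  counts.checkpoint(Seconds(20))
  counts
}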
diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala
index 7d52e2eddf..2b3f5a4829 100644
--- a/streaming/src/main/scala/spark/streaming/Scheduler.scala
+++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala
@@ -58,7 +58,6 @@ extends Logging {
graph.forgetOldRDDs(time)
if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) {
ssc.doCheckpoint(time)
- logInfo("Checkpointed at time " + time)
}
}
diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala
index 086752ac55..0211df1343 100644
--- a/streaming/src/main/scala/spark/streaming/StateDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala
@@ -23,51 +23,14 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife
rememberPartitioner: Boolean
) extends DStream[(K, S)](parent.ssc) {
+ super.persist(StorageLevel.MEMORY_ONLY)
+
override def dependencies = List(parent)
override def slideTime = parent.slideTime
- override def getOrCompute(time: Time): Option[RDD[(K, S)]] = {
- generatedRDDs.get(time) match {
- case Some(oldRDD) => {
- if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) {
- val r = oldRDD
- val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index)
- val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) {
- override val partitioner = oldRDD.partitioner
- }
- generatedRDDs.update(time, checkpointedRDD)
- logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id)
- Some(checkpointedRDD)
- } else {
- Some(oldRDD)
- }
- }
- case None => {
- if (isTimeValid(time)) {
- compute(time) match {
- case Some(newRDD) => {
- if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
- newRDD.persist(checkpointLevel)
- logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time)
- } else if (storageLevel != StorageLevel.NONE) {
- newRDD.persist(storageLevel)
- logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time)
- }
- generatedRDDs.put(time, newRDD)
- Some(newRDD)
- }
- case None => {
- None
- }
- }
- } else {
- None
- }
- }
- }
- }
-
+ override val mustCheckpoint = true
+
override def compute(validTime: Time): Option[RDD[(K, S)]] = {
// Try to get the previous state RDD
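
StateDStream is the other stream that now declares mustCheckpoint = true; it backs updateStateByKey, whose ever-growing RDD lineage is exactly what periodic RDD checkpointing truncates. A hedged sketch mirroring the updateStateByKey test in CheckpointSuite further down; `words` is a placeholder for any DStream[String]:

import spark.streaming._
import spark.streaming.StreamingContext._
import scala.runtime.RichInt

// Hypothetical helper: keep a running count per word, checkpointed every 2 seconds.
def runningCounts(words: DStream[String]): DStream[(String, Int)] = {
  val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
    Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
  }
  words.map(x => (x, 1))
    .updateStateByKey[RichInt](updateFunc)
    .checkpoint(Seconds(2))        // explicit interval; otherwise one is chosen automatically
    .map(t => (t._1, t._2.self))
}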
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 4a78090597..05c83d6c08 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -15,6 +15,8 @@ import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
+import org.apache.hadoop.fs.Path
+import java.util.UUID
class StreamingContext (
sc_ : SparkContext,
@@ -26,7 +28,7 @@ class StreamingContext (
def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) =
this(new SparkContext(master, frameworkName, sparkHome, jars), null)
- def this(file: String) = this(null, Checkpoint.loadFromFile(file))
+ def this(path: String) = this(null, Checkpoint.load(path))
def this(cp_ : Checkpoint) = this(null, cp_)
@@ -51,8 +53,8 @@ class StreamingContext (
val graph: DStreamGraph = {
if (isCheckpointPresent) {
-
cp_.graph.setContext(this)
+ cp_.graph.restoreCheckpointData()
cp_.graph
} else {
new DStreamGraph()
@@ -62,7 +64,15 @@ class StreamingContext (
val nextNetworkInputStreamId = new AtomicInteger(0)
var networkInputTracker: NetworkInputTracker = null
- private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null
+ private[streaming] var checkpointDir: String = {
+ if (isCheckpointPresent) {
+ sc.setCheckpointDir(cp_.checkpointDir, true)
+ cp_.checkpointDir
+ } else {
+ null
+ }
+ }
+
private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null
private[streaming] var receiverJobThread: Thread = null
private[streaming] var scheduler: Scheduler = null
@@ -75,9 +85,15 @@ class StreamingContext (
graph.setRememberDuration(duration)
}
- def setCheckpointDetails(file: String, interval: Time) {
- checkpointFile = file
- checkpointInterval = interval
+ def checkpoint(dir: String, interval: Time) {
+ if (dir != null) {
+ sc.setCheckpointDir(new Path(dir, "rdds-" + UUID.randomUUID.toString).toString)
+ checkpointDir = dir
+ checkpointInterval = interval
+ } else {
+ checkpointDir = null
+ checkpointInterval = null
+ }
}
private[streaming] def getInitialCheckpoint(): Checkpoint = {
@@ -181,16 +197,12 @@ class StreamingContext (
graph.addOutputStream(outputStream)
}
- def validate() {
- assert(graph != null, "Graph is null")
- graph.validate()
- }
-
/**
* This function starts the execution of the streams.
*/
def start() {
- validate()
+ assert(graph != null, "Graph is null")
+ graph.validate()
val networkInputStreams = graph.getInputStreams().filter(s => s match {
case n: NetworkInputDStream[_] => true
@@ -218,16 +230,16 @@ class StreamingContext (
if (scheduler != null) scheduler.stop()
if (networkInputTracker != null) networkInputTracker.stop()
if (receiverJobThread != null) receiverJobThread.interrupt()
- sc.stop()
+ sc.stop()
+ logInfo("StreamingContext stopped successfully")
} catch {
case e: Exception => logWarning("Error while stopping", e)
}
-
- logInfo("StreamingContext stopped")
}
def doCheckpoint(currentTime: Time) {
- new Checkpoint(this, currentTime).saveToFile(checkpointFile)
+ graph.updateCheckpointData(currentTime)
+ new Checkpoint(this, currentTime).save(checkpointDir)
}
}
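
With checkpoint(dir, interval) and the path-based constructor in place, the driver-restart pattern is: build a fresh context on the first run, and rebuild the whole graph from the checkpoint directory on restart (as FileStreamWithCheckpoint below does). A hypothetical sketch of that pattern; master, directory names and the stream definitions are placeholders:

import spark.streaming._

// Hypothetical factory; on restart, Checkpoint.load reads "<dir>/graph" (or "graph.bk")
// and the restored graph's checkpointed RDDs are re-created from their saved files.
def createContext(restarting: Boolean, checkpointDir: String): StreamingContext = {
  if (restarting) {
    new StreamingContext(checkpointDir)
  } else {
    val ssc = new StreamingContext("local[2]", "MyApp")
    ssc.setBatchDuration(Seconds(1))
    ssc.checkpoint(checkpointDir, Seconds(10))
    // ... define input streams and transformations here ...
    ssc
  }
}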
diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala
index 9ddb65249a..480d292d7c 100644
--- a/streaming/src/main/scala/spark/streaming/Time.scala
+++ b/streaming/src/main/scala/spark/streaming/Time.scala
@@ -25,6 +25,10 @@ case class Time(millis: Long) {
def isMultipleOf(that: Time): Boolean =
(this.millis % that.millis == 0)
+ def min(that: Time): Time = if (this < that) this else that
+
+ def max(that: Time): Time = if (this > that) this else that
+
def isZero: Boolean = (this.millis == 0)
override def toString: String = (millis.toString + " ms")
@@ -39,7 +43,7 @@ object Time {
implicit def toTime(long: Long) = Time(long)
- implicit def toLong(time: Time) = time.milliseconds
+ implicit def toLong(time: Time) = time.milliseconds
}
object Milliseconds {
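
The new min and max helpers are what DStream.initialize uses to pick a default checkpoint interval. A small REPL-style illustration, assuming Seconds(n) constructs Time(n * 1000):

import spark.streaming._

val slide = Seconds(2)                   // Time(2000)
val interval = slide.max(Seconds(10))    // Time(10000): the automatically chosen interval
interval.isMultipleOf(slide)             // true, since 10000 % 2000 == 0
slide.min(interval)                      // Time(2000)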
diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
index df96a811da..21a83c0fde 100644
--- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
@@ -10,20 +10,20 @@ object FileStreamWithCheckpoint {
def main(args: Array[String]) {
if (args.size != 3) {
- println("FileStreamWithCheckpoint <master> <directory> <checkpoint file>")
- println("FileStreamWithCheckpoint restart <directory> <checkpoint file>")
+ println("FileStreamWithCheckpoint <master> <directory> <checkpoint dir>")
+ println("FileStreamWithCheckpoint restart <directory> <checkpoint dir>")
System.exit(-1)
}
val directory = new Path(args(1))
- val checkpointFile = args(2)
+ val checkpointDir = args(2)
val ssc: StreamingContext = {
if (args(0) == "restart") {
// Recreated streaming context from specified checkpoint file
- new StreamingContext(checkpointFile)
+ new StreamingContext(checkpointDir)
} else {
@@ -34,7 +34,7 @@ object FileStreamWithCheckpoint {
// Create new streaming context
val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint")
ssc_.setBatchDuration(Seconds(1))
- ssc_.setCheckpointDetails(checkpointFile, Seconds(1))
+ ssc_.checkpoint(checkpointDir, Seconds(1))
// Setup the streaming computation
val inputStream = ssc_.textFileStream(directory.toString)
diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
index 57fd10f0a5..750cb7445f 100644
--- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
@@ -41,9 +41,8 @@ object TopKWordCountRaw {
val windowedCounts = union.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces)
- windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- Milliseconds(chkptMs))
- //windowedCounts.print() // TODO: something else?
+ windowedCounts.persist().checkpoint(Milliseconds(chkptMs))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs))
def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = {
val taken = new Array[(String, Long)](k)
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
index 0d2e62b955..865026033e 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
@@ -100,10 +100,9 @@ object WordCount2 {
val windowedCounts = sentences
.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt)
- windowedCounts.persist(StorageLevel.MEMORY_ONLY,
- StorageLevel.MEMORY_ONLY_2,
- //new StorageLevel(false, true, true, 3),
- Milliseconds(chkptMillis.toLong))
+
+ windowedCounts.persist().checkpoint(Milliseconds(chkptMillis.toLong))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong))
windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
ssc.start()
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
index abfd12890f..d1ea9a9cd5 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
@@ -41,9 +41,9 @@ object WordCountRaw {
val windowedCounts = union.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces)
- windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- Milliseconds(chkptMs))
- //windowedCounts.print() // TODO: something else?
+ windowedCounts.persist().checkpoint(chkptMs)
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs))
+
windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
ssc.start()
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
index 9d44da2b11..6a9c8a9a69 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
@@ -57,11 +57,13 @@ object WordMax2 {
val windowedCounts = sentences
.mapPartitions(splitAndCountPartitions)
.reduceByKey(add _, reduceTasks.toInt)
- .persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- Milliseconds(chkptMillis.toLong))
+ .persist()
+ .checkpoint(Milliseconds(chkptMillis.toLong))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong))
.reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt)
- //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- // Milliseconds(chkptMillis.toLong))
+ .persist()
+ .checkpoint(Milliseconds(chkptMillis.toLong))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong))
windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
ssc.start()
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index 6dcedcf463..9fdfd50be2 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -2,52 +2,195 @@ package spark.streaming
import spark.streaming.StreamingContext._
import java.io.File
+import runtime.RichInt
+import org.scalatest.BeforeAndAfter
+import org.apache.commons.io.FileUtils
+import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import util.{Clock, ManualClock}
-class CheckpointSuite extends TestSuiteBase {
+class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
+
+ before {
+ FileUtils.deleteDirectory(new File(checkpointDir))
+ }
+
+ after {
+ FileUtils.deleteDirectory(new File(checkpointDir))
+ }
+
+ override def framework = "CheckpointSuite"
+
+ override def batchDuration = Milliseconds(500)
+
+ override def checkpointDir = "checkpoint"
+
+ override def checkpointInterval = batchDuration
+
+ override def actuallyWait = true
+
+ test("basic stream+rdd recovery") {
+
+ assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 500 milliseconds")
+ assert(checkpointInterval === batchDuration, "checkpointInterval for this test must be the same as batchDuration")
+
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
+ val stateStreamCheckpointInterval = Seconds(2)
+
+ // this ensures checkpointing occurs at least once
+ val firstNumBatches = (stateStreamCheckpointInterval.millis / batchDuration.millis) * 2
+ val secondNumBatches = firstNumBatches
+
+ // Setup the streams
+ val input = (1 to 10).map(_ => Seq("a")).toSeq
+ val operation = (st: DStream[String]) => {
+ val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
+ Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ }
+ st.map(x => (x, 1))
+ .updateStateByKey[RichInt](updateFunc)
+ .checkpoint(stateStreamCheckpointInterval)
+ .map(t => (t._1, t._2.self))
+ }
+ val ssc = setupStreams(input, operation)
+ val stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head
+
+ // Run till a time such that at least one RDD in the stream should have been checkpointed
+ ssc.start()
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ advanceClock(clock, firstNumBatches)
+
+ // Check whether some RDD has been checkpointed or not
+ logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]")
+ assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream before first failure")
+ stateStream.checkpointData.foreach {
+ case (time, data) => {
+ val file = new File(data.toString)
+ assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist")
+ }
+ }
+
+ // Run till a further time such that previous checkpoint files in the stream would be deleted
+ // and check whether the earlier checkpoint files are deleted
+ val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString))
+ advanceClock(clock, secondNumBatches)
+ checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted"))
+
+ // Restart stream computation using the checkpoint file and check whether
+ // checkpointed RDDs have been restored or not
+ ssc.stop()
+ val sscNew = new StreamingContext(checkpointDir)
+ val stateStreamNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head
+ logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]")
+ assert(!stateStreamNew.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from first failure")
+
+
+ // Run one batch to generate a new checkpoint file
+ sscNew.start()
+ val clockNew = sscNew.scheduler.clock.asInstanceOf[ManualClock]
+ advanceClock(clockNew, 1)
+
+ // Check whether some RDD is present in the checkpoint data or not
+ assert(!stateStreamNew.checkpointData.isEmpty, "No checkpointed RDDs in state stream before second failure")
+ stateStream.checkpointData.foreach {
+ case (time, data) => {
+ val file = new File(data.toString)
+ assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist")
+ }
+ }
+
+ // Restart stream computation from the new checkpoint file to see whether that file has
+ // correct checkpoint data
+ sscNew.stop()
+ val sscNewNew = new StreamingContext(checkpointDir)
+ val stateStreamNewNew = sscNewNew.graph.getOutputStreams().head.dependencies.head.dependencies.head
+ logInfo("Restored data of state stream = \n[" + stateStreamNewNew.generatedRDDs.mkString("\n") + "]")
+ assert(!stateStreamNewNew.generatedRDDs.isEmpty, "No restored RDDs in state stream after recovery from second failure")
+ sscNewNew.start()
+ advanceClock(sscNewNew.scheduler.clock.asInstanceOf[ManualClock], 1)
+ sscNewNew.stop()
+ }
+
+ test("map and reduceByKey") {
+ testCheckpointedOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
+ Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
+ 3
+ )
+ }
+
+ test("reduceByKeyAndWindowInv") {
+ val n = 10
+ val w = 4
+ val input = (1 to n).map(x => Seq("a")).toSeq
+ val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4)))
+ val operation = (st: DStream[String]) => {
+ st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, batchDuration * 4, batchDuration)
+ }
+ for (i <- Seq(2, 3, 4)) {
+ testCheckpointedOperation(input, operation, output, i)
+ }
+ }
+
+ test("updateStateByKey") {
+ val input = (1 to 10).map(_ => Seq("a")).toSeq
+ val output = (1 to 10).map(x => Seq(("a", x))).toSeq
+ val operation = (st: DStream[String]) => {
+ val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
+ Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ }
+ st.map(x => (x, 1))
+ .updateStateByKey[RichInt](updateFunc)
+ .checkpoint(Seconds(2))
+ .map(t => (t._1, t._2.self))
+ }
+ for (i <- Seq(2, 3, 4)) {
+ testCheckpointedOperation(input, operation, output, i)
+ }
+ }
- override def framework() = "CheckpointSuite"
- override def checkpointFile() = "checkpoint"
def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
- input: Seq[Seq[U]],
- operation: DStream[U] => DStream[V],
- expectedOutput: Seq[Seq[V]],
- useSet: Boolean = false
- ) {
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ initialNumBatches: Int
+ ) {
// Current code assumes that:
// number of inputs = number of outputs = number of batches to be run
-
val totalNumBatches = input.size
- val initialNumBatches = input.size / 2
val nextNumBatches = totalNumBatches - initialNumBatches
val initialNumExpectedOutputs = initialNumBatches
+ val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs
// Do half the computation (half the number of batches), create checkpoint file and quit
+
val ssc = setupStreams[U, V](input, operation)
val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
- verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet)
+ verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
Thread.sleep(1000)
// Restart and complete the computation from checkpoint file
- val sscNew = new StreamingContext(checkpointFile)
- sscNew.setCheckpointDetails(null, null)
- val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size)
- verifyOutput[V](outputNew, expectedOutput, useSet)
-
- new File(checkpointFile).delete()
- new File(checkpointFile + ".bk").delete()
- new File("." + checkpointFile + ".crc").delete()
- new File("." + checkpointFile + ".bk.crc").delete()
+ logInfo(
+ "\n-------------------------------------------\n" +
+ " Restarting stream computation " +
+ "\n-------------------------------------------\n"
+ )
+ val sscNew = new StreamingContext(checkpointDir)
+ val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs)
+ verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
}
- test("simple per-batch operation") {
- testCheckpointedOperation(
- Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
- (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
- Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
- true
- )
+ def advanceClock(clock: ManualClock, numBatches: Long) {
+ logInfo("Manual clock before advancing = " + clock.time)
+ for (i <- 1 to numBatches.toInt) {
+ clock.addToTime(batchDuration.milliseconds)
+ Thread.sleep(batchDuration.milliseconds)
+ }
+ logInfo("Manual clock after advancing = " + clock.time)
+ Thread.sleep(batchDuration.milliseconds)
}
}
\ No newline at end of file
diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
index c9bc454f91..b8c7f99603 100644
--- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
@@ -5,10 +5,16 @@ import util.ManualClock
import collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
import collection.mutable.SynchronizedBuffer
+import java.io.{ObjectInputStream, IOException}
+
+/**
+ * This is an input stream just for the test suites. This is equivalent to a checkpointable,
+ * replayable, reliable message queue like Kafka. It requires a sequence as input, and
+ * returns the i_th element at the i_th batch under manual clock.
+ */
class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int)
extends InputDStream[T](ssc_) {
- var currentIndex = 0
def start() {}
@@ -23,34 +29,49 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[
ssc.sc.makeRDD(Seq[T](), numPartitions)
}
logInfo("Created RDD " + rdd.id)
- //currentIndex += 1
Some(rdd)
}
}
+/**
+ * This is an output stream just for the test suites. All the output is collected into an
+ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ */
class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
val collected = rdd.collect()
output += collected
- })
+ }) {
+
+ // This is to clear the output buffer every time it is read from a checkpoint
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ ois.defaultReadObject()
+ output.clear()
+ }
+}
+/**
+ * This is the base trait for Spark Streaming test suites. This provides basic functionality
+ * to run a user-defined set of input on user-defined stream operations, and verify the output.
+ */
trait TestSuiteBase extends FunSuite with Logging {
- System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+ def framework = "TestSuiteBase"
- def framework() = "TestSuiteBase"
+ def master = "local[2]"
- def master() = "local[2]"
+ def batchDuration = Seconds(1)
- def batchDuration() = Seconds(1)
+ def checkpointDir = null.asInstanceOf[String]
- def checkpointFile() = null.asInstanceOf[String]
+ def checkpointInterval = batchDuration
- def checkpointInterval() = batchDuration
+ def numInputPartitions = 2
- def numInputPartitions() = 2
+ def maxWaitTimeMillis = 10000
- def maxWaitTimeMillis() = 10000
+ def actuallyWait = false
def setupStreams[U: ClassManifest, V: ClassManifest](
input: Seq[Seq[U]],
@@ -60,8 +81,8 @@ trait TestSuiteBase extends FunSuite with Logging {
// Create StreamingContext
val ssc = new StreamingContext(master, framework)
ssc.setBatchDuration(batchDuration)
- if (checkpointFile != null) {
- ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+ if (checkpointDir != null) {
+ ssc.checkpoint(checkpointDir, checkpointInterval)
}
// Setup the stream computation
@@ -82,8 +103,8 @@ trait TestSuiteBase extends FunSuite with Logging {
// Create StreamingContext
val ssc = new StreamingContext(master, framework)
ssc.setBatchDuration(batchDuration)
- if (checkpointFile != null) {
- ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+ if (checkpointDir != null) {
+ ssc.checkpoint(checkpointDir, checkpointInterval)
}
// Setup the stream computation
@@ -97,12 +118,19 @@ trait TestSuiteBase extends FunSuite with Logging {
ssc
}
+ /**
+ * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and
+ * returns the collected output. It waits until `numExpectedOutput` output items have
+ * been collected or the timeout (set by `maxWaitTimeMillis`) is reached.
+ */
def runStreams[V: ClassManifest](
ssc: StreamingContext,
numBatches: Int,
numExpectedOutput: Int
): Seq[Seq[V]] = {
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
assert(numBatches > 0, "Number of batches to run stream computation is zero")
assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
@@ -118,7 +146,15 @@ trait TestSuiteBase extends FunSuite with Logging {
// Advance manual clock
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
logInfo("Manual clock before advancing = " + clock.time)
- clock.addToTime(numBatches * batchDuration.milliseconds)
+ if (actuallyWait) {
+ for (i <- 1 to numBatches) {
+ logInfo("Actually waiting for " + batchDuration)
+ clock.addToTime(batchDuration.milliseconds)
+ Thread.sleep(batchDuration.milliseconds)
+ }
+ } else {
+ clock.addToTime(numBatches * batchDuration.milliseconds)
+ }
logInfo("Manual clock after advancing = " + clock.time)
// Wait until expected number of output items have been generated
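
The renamed, parameter-less configuration methods (framework, batchDuration, checkpointDir, actuallyWait, ...) are meant to be overridden by concrete suites, as CheckpointSuite and WindowOperationsSuite do. A hypothetical minimal suite built on these helpers; the "double every element" operation and its data are made up for illustration:

package spark.streaming

// Hypothetical test suite using setupStreams/runStreams/verifyOutput as shown above.
class MapOperationSuite extends TestSuiteBase {

  override def framework = "MapOperationSuite"

  test("map") {
    val input = Seq(Seq(1, 2), Seq(3, 4))
    val expectedOutput = Seq(Seq(2, 4), Seq(6, 8))
    val ssc = setupStreams[Int, Int](input, (s: DStream[Int]) => s.map(_ * 2))
    val output = runStreams[Int](ssc, input.size, expectedOutput.size)
    verifyOutput[Int](output, expectedOutput, false)   // false: compare as ordered sequences
  }
}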
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
index d7d8d5bd36..3e20e16708 100644
--- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -5,11 +5,11 @@ import collection.mutable.ArrayBuffer
class WindowOperationsSuite extends TestSuiteBase {
- override def framework() = "WindowOperationsSuite"
+ override def framework = "WindowOperationsSuite"
- override def maxWaitTimeMillis() = 20000
+ override def maxWaitTimeMillis = 20000
- override def batchDuration() = Seconds(1)
+ override def batchDuration = Seconds(1)
val largerSlideInput = Seq(
Seq(("a", 1)),
@@ -283,7 +283,9 @@ class WindowOperationsSuite extends TestSuiteBase {
test("reduceByKeyAndWindowInv - " + name) {
val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
val operation = (s: DStream[(String, Int)]) => {
- s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist()
+ s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime)
+ .persist()
+ .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing
}
testOperation(input, operation, expectedOutput, numBatches, true)
}