path: root/streaming
author     Tathagata Das <tathagata.das1565@gmail.com>   2012-11-04 12:12:06 -0800
committer  Tathagata Das <tathagata.das1565@gmail.com>   2012-11-04 12:12:06 -0800
commit     d1542387891018914fdd6b647f17f0b05acdd40e (patch)
tree       51d0bdbd9014daa6f6f87bb9547acdf110300463 /streaming
parent     596154eabe51961733789a18a47067748fb72e8e (diff)
Made checkpointing of the DStream graph work with checkpointing of RDDs. For streams requiring checkpointing of their RDDs, the default checkpoint interval is set to 10 seconds.
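
For orientation, a minimal usage sketch of the API this patch introduces, assembled from the diff below (StreamingContext.checkpoint(dir, interval), DStream.checkpoint(interval), and the persist()/checkpoint() split). The master URL, paths, and word-count pipeline are illustrative assumptions, not part of the commit:

import spark.streaming._
import spark.streaming.StreamingContext._

object CheckpointUsageSketch {
  def main(args: Array[String]) {
    val ssc = new StreamingContext("local[2]", "CheckpointUsageSketch")
    ssc.setBatchDuration(Seconds(1))

    // Directory holding both the serialized DStream graph ("graph" file)
    // and the RDD checkpoint files ("rdds-<uuid>" subdirectory).
    ssc.checkpoint("/tmp/checkpoint", Seconds(10))

    val counts = ssc.textFileStream("/tmp/input")
      .flatMap(_.split(" "))
      .map(word => (word, 1))
      .reduceByKey(_ + _)

    // Streams that must checkpoint their RDDs (state and inverse-window streams)
    // default the interval to max(slideTime, Seconds(10)); it can also be set explicitly.
    counts.persist().checkpoint(Seconds(10))
    counts.foreachRDD(rdd => println("Distinct words: " + rdd.count()))

    ssc.start()
  }
}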
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/spark/streaming/Checkpoint.scala                          83
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala                            156
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStreamGraph.scala                         7
-rw-r--r--  streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala               36
-rw-r--r--  streaming/src/main/scala/spark/streaming/StateDStream.scala                         45
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala                     38
-rw-r--r--  streaming/src/main/scala/spark/streaming/Time.scala                                  4
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala    10
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala             5
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCount2.scala                   7
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala                 6
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordMax2.scala                    10
-rw-r--r--  streaming/src/test/scala/spark/streaming/CheckpointSuite.scala                      77
-rw-r--r--  streaming/src/test/scala/spark/streaming/TestSuiteBase.scala                        37
-rw-r--r--  streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala                 4
15 files changed, 337 insertions, 188 deletions
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index 83a43d15cb..cf04c7031e 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -1,6 +1,6 @@
package spark.streaming
-import spark.Utils
+import spark.{Logging, Utils}
import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
@@ -8,13 +8,14 @@ import org.apache.hadoop.conf.Configuration
import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream}
-class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable {
+class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
+ extends Logging with Serializable {
val master = ssc.sc.master
val framework = ssc.sc.jobName
val sparkHome = ssc.sc.sparkHome
val jars = ssc.sc.jars
val graph = ssc.graph
- val checkpointFile = ssc.checkpointFile
+ val checkpointDir = ssc.checkpointDir
val checkpointInterval = ssc.checkpointInterval
validate()
@@ -24,22 +25,25 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext
assert(framework != null, "Checkpoint.framework is null")
assert(graph != null, "Checkpoint.graph is null")
assert(checkpointTime != null, "Checkpoint.checkpointTime is null")
+ logInfo("Checkpoint for time " + checkpointTime + " validated")
}
- def saveToFile(file: String = checkpointFile) {
- val path = new Path(file)
+ def save(path: String) {
+ val file = new Path(path, "graph")
val conf = new Configuration()
- val fs = path.getFileSystem(conf)
- if (fs.exists(path)) {
- val bkPath = new Path(path.getParent, path.getName + ".bk")
- FileUtil.copy(fs, path, fs, bkPath, true, true, conf)
- //logInfo("Moved existing checkpoint file to " + bkPath)
+ val fs = file.getFileSystem(conf)
+ logDebug("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'")
+ if (fs.exists(file)) {
+ val bkFile = new Path(file.getParent, file.getName + ".bk")
+ FileUtil.copy(fs, file, fs, bkFile, true, true, conf)
+ logDebug("Moved existing checkpoint file to " + bkFile)
}
- val fos = fs.create(path)
+ val fos = fs.create(file)
val oos = new ObjectOutputStream(fos)
oos.writeObject(this)
oos.close()
fs.close()
+ logInfo("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'")
}
def toBytes(): Array[Byte] = {
@@ -50,30 +54,41 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext
object Checkpoint {
- def loadFromFile(file: String): Checkpoint = {
- try {
- val path = new Path(file)
- val conf = new Configuration()
- val fs = path.getFileSystem(conf)
- if (!fs.exists(path)) {
- throw new Exception("Checkpoint file '" + file + "' does not exist")
+ def load(path: String): Checkpoint = {
+
+ val fs = new Path(path).getFileSystem(new Configuration())
+ val attempts = Seq(new Path(path), new Path(path, "graph"), new Path(path, "graph.bk"))
+ var lastException: Exception = null
+ var lastExceptionFile: String = null
+
+ attempts.foreach(file => {
+ if (fs.exists(file)) {
+ try {
+ val fis = fs.open(file)
+ // ObjectInputStream uses the last defined user-defined class loader in the stack
+ // to find classes, which may be the wrong class loader. Hence, an inherited version
+ // of ObjectInputStream is used to explicitly use the current thread's default class
+ // loader to find and load classes. This is a well-known Java issue and has popped up
+ // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+ val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
+ val cp = ois.readObject.asInstanceOf[Checkpoint]
+ ois.close()
+ fs.close()
+ cp.validate()
+ println("Checkpoint successfully loaded from file " + file)
+ return cp
+ } catch {
+ case e: Exception =>
+ lastException = e
+ lastExceptionFile = file.toString
+ }
}
- val fis = fs.open(path)
- // ObjectInputStream uses the last defined user-defined class loader in the stack
- // to find classes, which maybe the wrong class loader. Hence, a inherited version
- // of ObjectInputStream is used to explicitly use the current thread's default class
- // loader to find and load classes. This is a well know Java issue and has popped up
- // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
- val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
- val cp = ois.readObject.asInstanceOf[Checkpoint]
- ois.close()
- fs.close()
- cp.validate()
- cp
- } catch {
- case e: Exception =>
- e.printStackTrace()
- throw new Exception("Could not load checkpoint file '" + file + "'", e)
+ })
+
+ if (lastException == null) {
+ throw new Exception("Could not load checkpoint from path '" + path + "'")
+ } else {
+ throw new Exception("Error loading checkpoint from path '" + lastExceptionFile + "'", lastException)
}
}
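
The ObjectInputStreamWithLoader used above is not defined in this diff. As a reference, the pattern the comment describes (resolving classes through an explicitly supplied class loader) can be sketched as follows; this is an assumption-level illustration, not the class as it exists in the Spark tree:

import java.io.{InputStream, ObjectInputStream, ObjectStreamClass}

class ObjectInputStreamWithLoader(in: InputStream, loader: ClassLoader)
  extends ObjectInputStream(in) {
  // Resolve classes through the supplied loader (e.g. the current thread's context
  // class loader) instead of the "latest user-defined loader" that
  // ObjectInputStream.resolveClass picks by default.
  override def resolveClass(desc: ObjectStreamClass): Class[_] = {
    try {
      return loader.loadClass(desc.getName)
    } catch {
      case e: Exception =>
    }
    super.resolveClass(desc)
  }
}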
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index a4921bb1a2..de51c5d34a 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -13,6 +13,7 @@ import scala.collection.mutable.HashMap
import java.util.concurrent.ArrayBlockingQueue
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import scala.Some
+import collection.mutable
abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext)
extends Serializable with Logging {
@@ -41,53 +42,55 @@ extends Serializable with Logging {
*/
// RDDs generated, marked as protected[streaming] so that testsuites can access it
- protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] ()
+ protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
// Time zero for the DStream
- protected var zeroTime: Time = null
+ protected[streaming] var zeroTime: Time = null
// Duration for which the DStream will remember each RDD created
- protected var rememberDuration: Time = null
+ protected[streaming] var rememberDuration: Time = null
// Storage level of the RDDs in the stream
- protected var storageLevel: StorageLevel = StorageLevel.NONE
+ protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE
- // Checkpoint level and checkpoint interval
- protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint
- protected var checkpointInterval: Time = null
+ // Checkpoint details
+ protected[streaming] val mustCheckpoint = false
+ protected[streaming] var checkpointInterval: Time = null
+ protected[streaming] val checkpointData = new HashMap[Time, Any]()
// Reference to whole DStream graph
- protected var graph: DStreamGraph = null
+ protected[streaming] var graph: DStreamGraph = null
def isInitialized = (zeroTime != null)
// Duration for which the DStream requires its parent DStream to remember each RDD created
def parentRememberDuration = rememberDuration
- // Change this RDD's storage level
- def persist(
- storageLevel: StorageLevel,
- checkpointLevel: StorageLevel,
- checkpointInterval: Time): DStream[T] = {
- if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) {
- // TODO: not sure this is necessary for DStreams
+ // Set caching level for the RDDs created by this DStream
+ def persist(level: StorageLevel): DStream[T] = {
+ if (this.isInitialized) {
throw new UnsupportedOperationException(
- "Cannot change storage level of an DStream after it was already assigned a level")
+ "Cannot change storage level of an DStream after streaming context has started")
}
- this.storageLevel = storageLevel
- this.checkpointLevel = checkpointLevel
- this.checkpointInterval = checkpointInterval
+ this.storageLevel = level
this
}
- // Set caching level for the RDDs created by this DStream
- def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null)
-
def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY)
// Turn on the default caching level for this RDD
def cache(): DStream[T] = persist()
+ def checkpoint(interval: Time): DStream[T] = {
+ if (isInitialized) {
+ throw new UnsupportedOperationException(
+ "Cannot change checkpoint interval of an DStream after streaming context has started")
+ }
+ persist()
+ checkpointInterval = interval
+ this
+ }
+
/**
* This method initializes the DStream by setting the "zero" time, based on which
* the validity of future times is calculated. This method also recursively initializes
@@ -99,7 +102,67 @@ extends Serializable with Logging {
+ ", cannot initialize it again to " + time)
}
zeroTime = time
+
+ // Set the checkpoint interval to be slideTime or 10 seconds, whichever is larger
+ if (mustCheckpoint && checkpointInterval == null) {
+ checkpointInterval = slideTime.max(Seconds(10))
+ logInfo("Checkpoint interval automatically set to " + checkpointInterval)
+ }
+
+ // Set the minimum value of the rememberDuration if not already set
+ var minRememberDuration = slideTime
+ if (checkpointInterval != null && minRememberDuration <= checkpointInterval) {
+ minRememberDuration = checkpointInterval + slideTime
+ }
+ if (rememberDuration == null || rememberDuration < minRememberDuration) {
+ rememberDuration = minRememberDuration
+ }
+
+ // Initialize the dependencies
dependencies.foreach(_.initialize(zeroTime))
+ }
+
+ protected[streaming] def validate() {
+ assert(
+ !mustCheckpoint || checkpointInterval != null,
+ "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " +
+ " Please use DStream.checkpoint() to set the interval."
+ )
+
+ assert(
+ checkpointInterval == null || checkpointInterval >= slideTime,
+ "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " +
+ checkpointInterval + " which is lower than its slide time (" + slideTime + "). " +
+ "Please set it to at least " + slideTime + "."
+ )
+
+ assert(
+ checkpointInterval == null || checkpointInterval.isMultipleOf(slideTime),
+ "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " +
+ checkpointInterval + " which not a multiple of its slide time (" + slideTime + "). " +
+ "Please set it to a multiple " + slideTime + "."
+ )
+
+ assert(
+ checkpointInterval == null || storageLevel != StorageLevel.NONE,
+ "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " +
+ "level has not been set to enable persisting. Please use DStream.persist() to set the " +
+ "storage level to use memory for better checkpointing performance."
+ )
+
+ assert(
+ checkpointInterval == null || rememberDuration > checkpointInterval,
+ "The remember duration for " + this.getClass.getSimpleName + " has been set to " +
+ rememberDuration + " which is not more than the checkpoint interval (" +
+ checkpointInterval + "). Please set it to higher than " + checkpointInterval + "."
+ )
+
+ dependencies.foreach(_.validate())
+
+ logInfo("Slide time = " + slideTime)
+ logInfo("Storage level = " + storageLevel)
+ logInfo("Checkpoint interval = " + checkpointInterval)
+ logInfo("Remember duration = " + rememberDuration)
logInfo("Initialized " + this)
}
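
A worked illustration of the interval defaulting applied in initialize() above, using the Time helpers from this patch (Seconds, max, +, <=); the concrete slide time is an illustrative assumption:

import spark.streaming._

// A DStream with mustCheckpoint = true and slideTime = Seconds(2):
val slideTime = Seconds(2)
val checkpointInterval = slideTime.max(Seconds(10))     // => Time(10000), i.e. 10 s
// rememberDuration must exceed the checkpoint interval:
var minRememberDuration = slideTime
if (minRememberDuration <= checkpointInterval) {
  minRememberDuration = checkpointInterval + slideTime  // => Time(12000), i.e. 12 s
}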
@@ -120,17 +183,12 @@ extends Serializable with Logging {
dependencies.foreach(_.setGraph(graph))
}
- protected[streaming] def setRememberDuration(duration: Time = slideTime) {
- if (duration == null) {
- throw new Exception("Duration for remembering RDDs cannot be set to null for " + this)
- } else if (rememberDuration != null && duration < rememberDuration) {
- logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration
- + " to " + duration + " for " + this)
- } else {
+ protected[streaming] def setRememberDuration(duration: Time) {
+ if (duration != null && duration > rememberDuration) {
rememberDuration = duration
- dependencies.foreach(_.setRememberDuration(parentRememberDuration))
logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this)
}
+ dependencies.foreach(_.setRememberDuration(parentRememberDuration))
}
/** This method checks whether the 'time' is valid wrt slideTime for generating RDD */
@@ -163,12 +221,13 @@ extends Serializable with Logging {
if (isTimeValid(time)) {
compute(time) match {
case Some(newRDD) =>
- if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
- newRDD.persist(checkpointLevel)
- logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time)
- } else if (storageLevel != StorageLevel.NONE) {
+ if (storageLevel != StorageLevel.NONE) {
newRDD.persist(storageLevel)
- logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time)
+ logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time)
+ }
+ if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
+ newRDD.checkpoint()
+ logInfo("Marking RDD for time " + time + " for checkpointing at time " + time)
}
generatedRDDs.put(time, newRDD)
Some(newRDD)
@@ -199,7 +258,7 @@ extends Serializable with Logging {
}
}
- def forgetOldRDDs(time: Time) {
+ protected[streaming] def forgetOldRDDs(time: Time) {
val keys = generatedRDDs.keys
var numForgotten = 0
keys.foreach(t => {
@@ -213,12 +272,35 @@ extends Serializable with Logging {
dependencies.foreach(_.forgetOldRDDs(time))
}
+ protected[streaming] def updateCheckpointData() {
+ checkpointData.clear()
+ generatedRDDs.foreach {
+ case(time, rdd) => {
+ logDebug("Adding checkpointed RDD for time " + time)
+ val data = rdd.getCheckpointData()
+ if (data != null) {
+ checkpointData += ((time, data))
+ }
+ }
+ }
+ }
+
+ protected[streaming] def restoreCheckpointData() {
+ checkpointData.foreach {
+ case(time, data) => {
+ logInfo("Restoring checkpointed RDD for time " + time)
+ generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString)))
+ }
+ }
+ }
+
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
logDebug(this.getClass().getSimpleName + ".writeObject used")
if (graph != null) {
graph.synchronized {
if (graph.checkpointInProgress) {
+ updateCheckpointData()
oos.defaultWriteObject()
} else {
val msg = "Object of " + this.getClass.getName + " is being serialized " +
@@ -239,6 +321,8 @@ extends Serializable with Logging {
private def readObject(ois: ObjectInputStream) {
logDebug(this.getClass().getSimpleName + ".readObject used")
ois.defaultReadObject()
+ generatedRDDs = new HashMap[Time, RDD[T]] ()
+ restoreCheckpointData()
}
/**
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index ac44d7a2a6..f8922ec790 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -22,11 +22,8 @@ final class DStreamGraph extends Serializable with Logging {
}
zeroTime = time
outputStreams.foreach(_.initialize(zeroTime))
- outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values
- if (rememberDuration != null) {
- // if custom rememberDuration has been provided, set the rememberDuration
- outputStreams.foreach(_.setRememberDuration(rememberDuration))
- }
+ outputStreams.foreach(_.setRememberDuration(rememberDuration))
+ outputStreams.foreach(_.validate)
inputStreams.par.foreach(_.start())
}
}
diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
index 1c57d5f855..6df82c0df3 100644
--- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
@@ -21,15 +21,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
partitioner: Partitioner
) extends DStream[(K,V)](parent.ssc) {
- if (!_windowTime.isMultipleOf(parent.slideTime))
- throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " +
- "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")")
+ assert(_windowTime.isMultipleOf(parent.slideTime),
+ "The window duration of ReducedWindowedDStream (" + _slideTime + ") " +
+ "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")"
+ )
- if (!_slideTime.isMultipleOf(parent.slideTime))
- throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " +
- "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")")
+ assert(_slideTime.isMultipleOf(parent.slideTime),
+ "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " +
+ "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")"
+ )
- @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
+ super.persist(StorageLevel.MEMORY_ONLY)
+
+ val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
def windowTime: Time = _windowTime
@@ -37,15 +41,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
override def slideTime: Time = _slideTime
- //TODO: This is wrong. This should depend on the checkpointInterval
+ override val mustCheckpoint = true
+
override def parentRememberDuration: Time = rememberDuration + windowTime
- override def persist(
- storageLevel: StorageLevel,
- checkpointLevel: StorageLevel,
- checkpointInterval: Time): DStream[(K,V)] = {
- super.persist(storageLevel, checkpointLevel, checkpointInterval)
- reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval)
+ override def persist(storageLevel: StorageLevel): DStream[(K,V)] = {
+ super.persist(storageLevel)
+ reducedStream.persist(storageLevel)
+ this
+ }
+
+ override def checkpoint(interval: Time): DStream[(K, V)] = {
+ super.checkpoint(interval)
+ reducedStream.checkpoint(interval)
this
}
diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala
index 086752ac55..0211df1343 100644
--- a/streaming/src/main/scala/spark/streaming/StateDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala
@@ -23,51 +23,14 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife
rememberPartitioner: Boolean
) extends DStream[(K, S)](parent.ssc) {
+ super.persist(StorageLevel.MEMORY_ONLY)
+
override def dependencies = List(parent)
override def slideTime = parent.slideTime
- override def getOrCompute(time: Time): Option[RDD[(K, S)]] = {
- generatedRDDs.get(time) match {
- case Some(oldRDD) => {
- if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) {
- val r = oldRDD
- val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index)
- val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) {
- override val partitioner = oldRDD.partitioner
- }
- generatedRDDs.update(time, checkpointedRDD)
- logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id)
- Some(checkpointedRDD)
- } else {
- Some(oldRDD)
- }
- }
- case None => {
- if (isTimeValid(time)) {
- compute(time) match {
- case Some(newRDD) => {
- if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
- newRDD.persist(checkpointLevel)
- logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time)
- } else if (storageLevel != StorageLevel.NONE) {
- newRDD.persist(storageLevel)
- logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time)
- }
- generatedRDDs.put(time, newRDD)
- Some(newRDD)
- }
- case None => {
- None
- }
- }
- } else {
- None
- }
- }
- }
- }
-
+ override val mustCheckpoint = true
+
override def compute(validTime: Time): Option[RDD[(K, S)]] = {
// Try to get the previous state RDD
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index b3148eaa97..3838e84113 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -15,6 +15,8 @@ import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
+import org.apache.hadoop.fs.Path
+import java.util.UUID
class StreamingContext (
sc_ : SparkContext,
@@ -26,7 +28,7 @@ class StreamingContext (
def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) =
this(new SparkContext(master, frameworkName, sparkHome, jars), null)
- def this(file: String) = this(null, Checkpoint.loadFromFile(file))
+ def this(path: String) = this(null, Checkpoint.load(path))
def this(cp_ : Checkpoint) = this(null, cp_)
@@ -51,7 +53,6 @@ class StreamingContext (
val graph: DStreamGraph = {
if (isCheckpointPresent) {
-
cp_.graph.setContext(this)
cp_.graph
} else {
@@ -62,7 +63,15 @@ class StreamingContext (
val nextNetworkInputStreamId = new AtomicInteger(0)
var networkInputTracker: NetworkInputTracker = null
- private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null
+ private[streaming] var checkpointDir: String = {
+ if (isCheckpointPresent) {
+ sc.setCheckpointDir(cp_.checkpointDir, true)
+ cp_.checkpointDir
+ } else {
+ null
+ }
+ }
+
private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null
private[streaming] var receiverJobThread: Thread = null
private[streaming] var scheduler: Scheduler = null
@@ -75,9 +84,15 @@ class StreamingContext (
graph.setRememberDuration(duration)
}
- def setCheckpointDetails(file: String, interval: Time) {
- checkpointFile = file
- checkpointInterval = interval
+ def checkpoint(dir: String, interval: Time) {
+ if (dir != null) {
+ sc.setCheckpointDir(new Path(dir, "rdds-" + UUID.randomUUID.toString).toString)
+ checkpointDir = dir
+ checkpointInterval = interval
+ } else {
+ checkpointDir = null
+ checkpointInterval = null
+ }
}
private[streaming] def getInitialCheckpoint(): Checkpoint = {
@@ -170,16 +185,12 @@ class StreamingContext (
graph.addOutputStream(outputStream)
}
- def validate() {
- assert(graph != null, "Graph is null")
- graph.validate()
- }
-
/**
* This function starts the execution of the streams.
*/
def start() {
- validate()
+ assert(graph != null, "Graph is null")
+ graph.validate()
val networkInputStreams = graph.getInputStreams().filter(s => s match {
case n: NetworkInputDStream[_] => true
@@ -216,7 +227,8 @@ class StreamingContext (
}
def doCheckpoint(currentTime: Time) {
- new Checkpoint(this, currentTime).saveToFile(checkpointFile)
+ new Checkpoint(this, currentTime).save(checkpointDir)
+
}
}
diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala
index 9ddb65249a..2ba6502971 100644
--- a/streaming/src/main/scala/spark/streaming/Time.scala
+++ b/streaming/src/main/scala/spark/streaming/Time.scala
@@ -25,6 +25,10 @@ case class Time(millis: Long) {
def isMultipleOf(that: Time): Boolean =
(this.millis % that.millis == 0)
+ def min(that: Time): Time = if (this < that) this else that
+
+ def max(that: Time): Time = if (this > that) this else that
+
def isZero: Boolean = (this.millis == 0)
override def toString: String = (millis.toString + " ms")
diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
index df96a811da..21a83c0fde 100644
--- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
@@ -10,20 +10,20 @@ object FileStreamWithCheckpoint {
def main(args: Array[String]) {
if (args.size != 3) {
- println("FileStreamWithCheckpoint <master> <directory> <checkpoint file>")
- println("FileStreamWithCheckpoint restart <directory> <checkpoint file>")
+ println("FileStreamWithCheckpoint <master> <directory> <checkpoint dir>")
+ println("FileStreamWithCheckpoint restart <directory> <checkpoint dir>")
System.exit(-1)
}
val directory = new Path(args(1))
- val checkpointFile = args(2)
+ val checkpointDir = args(2)
val ssc: StreamingContext = {
if (args(0) == "restart") {
// Recreated streaming context from specified checkpoint file
- new StreamingContext(checkpointFile)
+ new StreamingContext(checkpointDir)
} else {
@@ -34,7 +34,7 @@ object FileStreamWithCheckpoint {
// Create new streaming context
val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint")
ssc_.setBatchDuration(Seconds(1))
- ssc_.setCheckpointDetails(checkpointFile, Seconds(1))
+ ssc_.checkpoint(checkpointDir, Seconds(1))
// Setup the streaming computation
val inputStream = ssc_.textFileStream(directory.toString)
diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
index 57fd10f0a5..750cb7445f 100644
--- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
@@ -41,9 +41,8 @@ object TopKWordCountRaw {
val windowedCounts = union.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces)
- windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- Milliseconds(chkptMs))
- //windowedCounts.print() // TODO: something else?
+ windowedCounts.persist().checkpoint(Milliseconds(chkptMs))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs))
def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = {
val taken = new Array[(String, Long)](k)
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
index 0d2e62b955..865026033e 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
@@ -100,10 +100,9 @@ object WordCount2 {
val windowedCounts = sentences
.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt)
- windowedCounts.persist(StorageLevel.MEMORY_ONLY,
- StorageLevel.MEMORY_ONLY_2,
- //new StorageLevel(false, true, true, 3),
- Milliseconds(chkptMillis.toLong))
+
+ windowedCounts.persist().checkpoint(Milliseconds(chkptMillis.toLong))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong))
windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
ssc.start()
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
index abfd12890f..d1ea9a9cd5 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
@@ -41,9 +41,9 @@ object WordCountRaw {
val windowedCounts = union.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces)
- windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- Milliseconds(chkptMs))
- //windowedCounts.print() // TODO: something else?
+ windowedCounts.persist().checkpoint(chkptMs)
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs))
+
windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
ssc.start()
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
index 9d44da2b11..6a9c8a9a69 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
@@ -57,11 +57,13 @@ object WordMax2 {
val windowedCounts = sentences
.mapPartitions(splitAndCountPartitions)
.reduceByKey(add _, reduceTasks.toInt)
- .persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- Milliseconds(chkptMillis.toLong))
+ .persist()
+ .checkpoint(Milliseconds(chkptMillis.toLong))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong))
.reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt)
- //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- // Milliseconds(chkptMillis.toLong))
+ .persist()
+ .checkpoint(Milliseconds(chkptMillis.toLong))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong))
windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
ssc.start()
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index 6dcedcf463..dfe31b5771 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -2,52 +2,95 @@ package spark.streaming
import spark.streaming.StreamingContext._
import java.io.File
+import collection.mutable.ArrayBuffer
+import runtime.RichInt
+import org.scalatest.BeforeAndAfter
+import org.apache.hadoop.fs.Path
+import org.apache.commons.io.FileUtils
-class CheckpointSuite extends TestSuiteBase {
+class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
+
+ before {
+ FileUtils.deleteDirectory(new File(checkpointDir))
+ }
+
+ after {
+ FileUtils.deleteDirectory(new File(checkpointDir))
+ }
override def framework() = "CheckpointSuite"
- override def checkpointFile() = "checkpoint"
+ override def batchDuration() = Seconds(1)
+
+ override def checkpointDir() = "checkpoint"
+
+ override def checkpointInterval() = batchDuration
def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
- useSet: Boolean = false
+ initialNumBatches: Int
) {
// Current code assumes that:
// number of inputs = number of outputs = number of batches to be run
-
val totalNumBatches = input.size
- val initialNumBatches = input.size / 2
val nextNumBatches = totalNumBatches - initialNumBatches
val initialNumExpectedOutputs = initialNumBatches
+ val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs
// Do half the computation (half the number of batches), create checkpoint file and quit
val ssc = setupStreams[U, V](input, operation)
val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
- verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet)
+ verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
Thread.sleep(1000)
// Restart and complete the computation from checkpoint file
- val sscNew = new StreamingContext(checkpointFile)
- sscNew.setCheckpointDetails(null, null)
- val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size)
- verifyOutput[V](outputNew, expectedOutput, useSet)
-
- new File(checkpointFile).delete()
- new File(checkpointFile + ".bk").delete()
- new File("." + checkpointFile + ".crc").delete()
- new File("." + checkpointFile + ".bk.crc").delete()
+ val sscNew = new StreamingContext(checkpointDir)
+ //sscNew.checkpoint(null, null)
+ val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs)
+ verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
}
- test("simple per-batch operation") {
+
+ test("map and reduceByKey") {
testCheckpointedOperation(
Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
(s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
- true
+ 3
)
}
+
+ test("reduceByKeyAndWindowInv") {
+ val n = 10
+ val w = 4
+ val input = (1 to n).map(x => Seq("a")).toSeq
+ val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4)))
+ val operation = (st: DStream[String]) => {
+ st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, Seconds(w), Seconds(1))
+ }
+ for (i <- Seq(3, 5, 7)) {
+ testCheckpointedOperation(input, operation, output, i)
+ }
+ }
+
+ test("updateStateByKey") {
+ val input = (1 to 10).map(_ => Seq("a")).toSeq
+ val output = (1 to 10).map(x => Seq(("a", x))).toSeq
+ val operation = (st: DStream[String]) => {
+ val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
+ Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ }
+ st.map(x => (x, 1))
+ .updateStateByKey[RichInt](updateFunc)
+ .checkpoint(Seconds(5))
+ .map(t => (t._1, t._2.self))
+ }
+ for (i <- Seq(3, 5, 7)) {
+ testCheckpointedOperation(input, operation, output, i)
+ }
+ }
+
} \ No newline at end of file
diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
index c9bc454f91..e441feea19 100644
--- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
@@ -5,10 +5,16 @@ import util.ManualClock
import collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
import collection.mutable.SynchronizedBuffer
+import java.io.{ObjectInputStream, IOException}
+
+/**
+ * This is an input stream just for the testsuites. This is equivalent to a checkpointable,
+ * replayable, reliable message queue like Kafka. It requires a sequence as input, and
+ * returns the i_th element at the i_th batch under the manual clock.
+ */
class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int)
extends InputDStream[T](ssc_) {
- var currentIndex = 0
def start() {}
@@ -23,17 +29,32 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[
ssc.sc.makeRDD(Seq[T](), numPartitions)
}
logInfo("Created RDD " + rdd.id)
- //currentIndex += 1
Some(rdd)
}
}
+/**
+ * This is an output stream just for the testsuites. All the output is collected into an
+ * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint.
+ */
class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
val collected = rdd.collect()
output += collected
- })
+ }) {
+
+ // This is to clear the output buffer every time it is read from a checkpoint
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ ois.defaultReadObject()
+ output.clear()
+ }
+}
+/**
+ * This is the base trait for Spark Streaming testsuites. It provides basic functionality
+ * to run a user-defined set of inputs on user-defined stream operations, and verify the output.
+ */
trait TestSuiteBase extends FunSuite with Logging {
System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
@@ -44,7 +65,7 @@ trait TestSuiteBase extends FunSuite with Logging {
def batchDuration() = Seconds(1)
- def checkpointFile() = null.asInstanceOf[String]
+ def checkpointDir() = null.asInstanceOf[String]
def checkpointInterval() = batchDuration
@@ -60,8 +81,8 @@ trait TestSuiteBase extends FunSuite with Logging {
// Create StreamingContext
val ssc = new StreamingContext(master, framework)
ssc.setBatchDuration(batchDuration)
- if (checkpointFile != null) {
- ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+ if (checkpointDir != null) {
+ ssc.checkpoint(checkpointDir, checkpointInterval())
}
// Setup the stream computation
@@ -82,8 +103,8 @@ trait TestSuiteBase extends FunSuite with Logging {
// Create StreamingContext
val ssc = new StreamingContext(master, framework)
ssc.setBatchDuration(batchDuration)
- if (checkpointFile != null) {
- ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+ if (checkpointDir != null) {
+ ssc.checkpoint(checkpointDir, checkpointInterval())
}
// Setup the stream computation
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
index d7d8d5bd36..e282f0fdd5 100644
--- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -283,7 +283,9 @@ class WindowOperationsSuite extends TestSuiteBase {
test("reduceByKeyAndWindowInv - " + name) {
val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
val operation = (s: DStream[(String, Int)]) => {
- s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist()
+ s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime)
+ .persist()
+ .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing
}
testOperation(input, operation, expectedOutput, numBatches, true)
}