-rw-r--r--  core/src/main/scala/spark/RDD.scala                                               |  32
-rw-r--r--  core/src/main/scala/spark/SparkContext.scala                                      |  13
-rw-r--r--  streaming/src/main/scala/spark/streaming/Checkpoint.scala                         |  83
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala                            | 156
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStreamGraph.scala                       |   7
-rw-r--r--  streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala             |  36
-rw-r--r--  streaming/src/main/scala/spark/streaming/StateDStream.scala                       |  45
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala                   |  38
-rw-r--r--  streaming/src/main/scala/spark/streaming/Time.scala                               |   4
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala  |  10
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala          |   5
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCount2.scala                |   7
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala              |   6
-rw-r--r--  streaming/src/main/scala/spark/streaming/examples/WordMax2.scala                  |  10
-rw-r--r--  streaming/src/test/scala/spark/streaming/CheckpointSuite.scala                    |  77
-rw-r--r--  streaming/src/test/scala/spark/streaming/TestSuiteBase.scala                      |  37
-rw-r--r--  streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala              |   4
17 files changed, 367 insertions(+), 203 deletions(-)
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 7b59a6f09e..63048d5df0 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -119,22 +119,23 @@ abstract class RDD[T: ClassManifest](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Returns the first parent RDD */
- private[spark] def firstParent[U: ClassManifest] = {
+ protected[spark] def firstParent[U: ClassManifest] = {
dependencies.head.rdd.asInstanceOf[RDD[U]]
}
/** Returns the `i` th parent RDD */
- private[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]]
+ protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]]
// Variables relating to checkpointing
- val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD
- var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing
- var isCheckpointInProgress = false // set to true when checkpointing is in progress
- var isCheckpointed = false // set to true after checkpointing is completed
+ protected val isCheckpointable = true // override to set this to false to avoid checkpointing an RDD
- var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed
- var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file
- var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD
+ protected var shouldCheckpoint = false // set to true when an RDD is marked for checkpointing
+ protected var isCheckpointInProgress = false // set to true when checkpointing is in progress
+ protected[spark] var isCheckpointed = false // set to true after checkpointing is completed
+
+ protected[spark] var checkpointFile: String = null // set to the checkpoint file after checkpointing is completed
+ protected var checkpointRDD: RDD[T] = null // set to the HadoopRDD of the checkpoint file
+ protected var checkpointRDDSplits: Seq[Split] = null // set to the splits of the Hadoop RDD
// Methods available on all RDDs:
@@ -176,6 +177,9 @@ abstract class RDD[T: ClassManifest](
if (isCheckpointed || shouldCheckpoint || isCheckpointInProgress) {
// do nothing
} else if (isCheckpointable) {
+ if (sc.checkpointDir == null) {
+ throw new Exception("Checkpoint directory has not been set in the SparkContext.")
+ }
shouldCheckpoint = true
} else {
throw new Exception(this + " cannot be checkpointed")
@@ -183,6 +187,16 @@ abstract class RDD[T: ClassManifest](
}
}
+ def getCheckpointData(): Any = {
+ synchronized {
+ if (isCheckpointed) {
+ checkpointFile
+ } else {
+ null
+ }
+ }
+ }
+
/**
* Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler after a job
* using this RDD has completed (therefore the RDD has been materialized and
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 79ceab5f4f..d7326971a9 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -584,14 +584,15 @@ class SparkContext(
* overwriting existing files may be overwritten). The directory will be deleted on exit
* if indicated.
*/
- def setCheckpointDir(dir: String, deleteOnExit: Boolean = false) {
+ def setCheckpointDir(dir: String, useExisting: Boolean = false) {
val path = new Path(dir)
val fs = path.getFileSystem(new Configuration())
- if (fs.exists(path)) {
- throw new Exception("Checkpoint directory '" + path + "' already exists.")
- } else {
- fs.mkdirs(path)
- if (deleteOnExit) fs.deleteOnExit(path)
+ if (!useExisting) {
+ if (fs.exists(path)) {
+ throw new Exception("Checkpoint directory '" + path + "' already exists.")
+ } else {
+ fs.mkdirs(path)
+ }
}
checkpointDir = dir
}
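
Taken together, the RDD.scala and SparkContext.scala hunks above change the user-facing checkpointing flow: a checkpoint directory must be set on the SparkContext before RDD.checkpoint() is called, and setCheckpointDir() can now reuse an existing directory via the new useExisting flag. A minimal sketch of that flow, assuming a local master and a placeholder /tmp path (parallelize, map and count come from the existing core API, not from this diff):

    import spark.SparkContext

    object RDDCheckpointSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "RDDCheckpointSketch")
        // Reuse the directory if it already exists (new useExisting flag); the path is a placeholder.
        sc.setCheckpointDir("/tmp/spark-rdd-checkpoints", useExisting = true)
        val rdd = sc.parallelize(1 to 1000, 10).map(_ * 2)
        rdd.checkpoint()                   // throws if no checkpoint directory has been set
        rdd.count()                        // the actual checkpoint happens after a job materializes the RDD
        println(rdd.getCheckpointData())   // checkpoint file path once completed, null before
      }
    }
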
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index 83a43d15cb..cf04c7031e 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -1,6 +1,6 @@
package spark.streaming
-import spark.Utils
+import spark.{Logging, Utils}
import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
@@ -8,13 +8,14 @@ import org.apache.hadoop.conf.Configuration
import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream}
-class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable {
+class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
+ extends Logging with Serializable {
val master = ssc.sc.master
val framework = ssc.sc.jobName
val sparkHome = ssc.sc.sparkHome
val jars = ssc.sc.jars
val graph = ssc.graph
- val checkpointFile = ssc.checkpointFile
+ val checkpointDir = ssc.checkpointDir
val checkpointInterval = ssc.checkpointInterval
validate()
@@ -24,22 +25,25 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext
assert(framework != null, "Checkpoint.framework is null")
assert(graph != null, "Checkpoint.graph is null")
assert(checkpointTime != null, "Checkpoint.checkpointTime is null")
+ logInfo("Checkpoint for time " + checkpointTime + " validated")
}
- def saveToFile(file: String = checkpointFile) {
- val path = new Path(file)
+ def save(path: String) {
+ val file = new Path(path, "graph")
val conf = new Configuration()
- val fs = path.getFileSystem(conf)
- if (fs.exists(path)) {
- val bkPath = new Path(path.getParent, path.getName + ".bk")
- FileUtil.copy(fs, path, fs, bkPath, true, true, conf)
- //logInfo("Moved existing checkpoint file to " + bkPath)
+ val fs = file.getFileSystem(conf)
+ logDebug("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'")
+ if (fs.exists(file)) {
+ val bkFile = new Path(file.getParent, file.getName + ".bk")
+ FileUtil.copy(fs, file, fs, bkFile, true, true, conf)
+ logDebug("Moved existing checkpoint file to " + bkFile)
}
- val fos = fs.create(path)
+ val fos = fs.create(file)
val oos = new ObjectOutputStream(fos)
oos.writeObject(this)
oos.close()
fs.close()
+ logInfo("Saved checkpoint for time " + checkpointTime + " to file '" + file + "'")
}
def toBytes(): Array[Byte] = {
@@ -50,30 +54,41 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) ext
object Checkpoint {
- def loadFromFile(file: String): Checkpoint = {
- try {
- val path = new Path(file)
- val conf = new Configuration()
- val fs = path.getFileSystem(conf)
- if (!fs.exists(path)) {
- throw new Exception("Checkpoint file '" + file + "' does not exist")
+ def load(path: String): Checkpoint = {
+
+ val fs = new Path(path).getFileSystem(new Configuration())
+ val attempts = Seq(new Path(path), new Path(path, "graph"), new Path(path, "graph.bk"))
+ var lastException: Exception = null
+ var lastExceptionFile: String = null
+
+ attempts.foreach(file => {
+ if (fs.exists(file)) {
+ try {
+ val fis = fs.open(file)
+ // ObjectInputStream uses the last defined user-defined class loader in the stack
+        // to find classes, which may be the wrong class loader. Hence, an inherited version
+ // of ObjectInputStream is used to explicitly use the current thread's default class
+        // loader to find and load classes. This is a well-known Java issue and has popped up
+ // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+ val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
+ val cp = ois.readObject.asInstanceOf[Checkpoint]
+ ois.close()
+ fs.close()
+ cp.validate()
+ println("Checkpoint successfully loaded from file " + file)
+ return cp
+ } catch {
+ case e: Exception =>
+ lastException = e
+ lastExceptionFile = file.toString
+ }
}
- val fis = fs.open(path)
- // ObjectInputStream uses the last defined user-defined class loader in the stack
- // to find classes, which maybe the wrong class loader. Hence, a inherited version
- // of ObjectInputStream is used to explicitly use the current thread's default class
- // loader to find and load classes. This is a well know Java issue and has popped up
- // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
- val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
- val cp = ois.readObject.asInstanceOf[Checkpoint]
- ois.close()
- fs.close()
- cp.validate()
- cp
- } catch {
- case e: Exception =>
- e.printStackTrace()
- throw new Exception("Could not load checkpoint file '" + file + "'", e)
+ })
+
+ if (lastException == null) {
+ throw new Exception("Could not load checkpoint from path '" + path + "'")
+ } else {
+ throw new Exception("Error loading checkpoint from path '" + lastExceptionFile + "'", lastException)
}
}
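
For orientation, the rewritten save()/load() pair stores the serialized graph under the checkpoint directory instead of at a single user-supplied file, and recovery probes a few candidate locations. A rough sketch of the resulting layout (the directory name is only an example):

    // <checkpointDir>/graph     -- latest serialized Checkpoint written by save()
    // <checkpointDir>/graph.bk  -- previous graph file, moved aside before each save
    //
    // load() tries, in order, the path itself, path/graph and path/graph.bk,
    // and only rethrows the last failure if none of them deserializes cleanly.
    val cp = Checkpoint.load("/tmp/streaming-checkpoint")
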
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index a4921bb1a2..de51c5d34a 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -13,6 +13,7 @@ import scala.collection.mutable.HashMap
import java.util.concurrent.ArrayBlockingQueue
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import scala.Some
+import collection.mutable
abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext)
extends Serializable with Logging {
@@ -41,53 +42,55 @@ extends Serializable with Logging {
*/
// RDDs generated, marked as protected[streaming] so that testsuites can access it
- protected[streaming] val generatedRDDs = new HashMap[Time, RDD[T]] ()
+ protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
// Time zero for the DStream
- protected var zeroTime: Time = null
+ protected[streaming] var zeroTime: Time = null
// Duration for which the DStream will remember each RDD created
- protected var rememberDuration: Time = null
+ protected[streaming] var rememberDuration: Time = null
// Storage level of the RDDs in the stream
- protected var storageLevel: StorageLevel = StorageLevel.NONE
+ protected[streaming] var storageLevel: StorageLevel = StorageLevel.NONE
- // Checkpoint level and checkpoint interval
- protected var checkpointLevel: StorageLevel = StorageLevel.NONE // NONE means don't checkpoint
- protected var checkpointInterval: Time = null
+ // Checkpoint details
+ protected[streaming] val mustCheckpoint = false
+ protected[streaming] var checkpointInterval: Time = null
+ protected[streaming] val checkpointData = new HashMap[Time, Any]()
// Reference to whole DStream graph
- protected var graph: DStreamGraph = null
+ protected[streaming] var graph: DStreamGraph = null
def isInitialized = (zeroTime != null)
// Duration for which the DStream requires its parent DStream to remember each RDD created
def parentRememberDuration = rememberDuration
- // Change this RDD's storage level
- def persist(
- storageLevel: StorageLevel,
- checkpointLevel: StorageLevel,
- checkpointInterval: Time): DStream[T] = {
- if (this.storageLevel != StorageLevel.NONE && this.storageLevel != storageLevel) {
- // TODO: not sure this is necessary for DStreams
+ // Set caching level for the RDDs created by this DStream
+ def persist(level: StorageLevel): DStream[T] = {
+ if (this.isInitialized) {
throw new UnsupportedOperationException(
- "Cannot change storage level of an DStream after it was already assigned a level")
+ "Cannot change storage level of an DStream after streaming context has started")
}
- this.storageLevel = storageLevel
- this.checkpointLevel = checkpointLevel
- this.checkpointInterval = checkpointInterval
+ this.storageLevel = level
this
}
- // Set caching level for the RDDs created by this DStream
- def persist(newLevel: StorageLevel): DStream[T] = persist(newLevel, StorageLevel.NONE, null)
-
def persist(): DStream[T] = persist(StorageLevel.MEMORY_ONLY)
// Turn on the default caching level for this RDD
def cache(): DStream[T] = persist()
+ def checkpoint(interval: Time): DStream[T] = {
+ if (isInitialized) {
+ throw new UnsupportedOperationException(
+ "Cannot change checkpoint interval of an DStream after streaming context has started")
+ }
+ persist()
+ checkpointInterval = interval
+ this
+ }
+
/**
* This method initializes the DStream by setting the "zero" time, based on which
* the validity of future times is calculated. This method also recursively initializes
@@ -99,7 +102,67 @@ extends Serializable with Logging {
+ ", cannot initialize it again to " + time)
}
zeroTime = time
+
+    // Set the checkpoint interval to be slideTime or 10 seconds, whichever is larger
+ if (mustCheckpoint && checkpointInterval == null) {
+ checkpointInterval = slideTime.max(Seconds(10))
+ logInfo("Checkpoint interval automatically set to " + checkpointInterval)
+ }
+
+ // Set the minimum value of the rememberDuration if not already set
+ var minRememberDuration = slideTime
+ if (checkpointInterval != null && minRememberDuration <= checkpointInterval) {
+ minRememberDuration = checkpointInterval + slideTime
+ }
+ if (rememberDuration == null || rememberDuration < minRememberDuration) {
+ rememberDuration = minRememberDuration
+ }
+
+ // Initialize the dependencies
dependencies.foreach(_.initialize(zeroTime))
+ }
+
+ protected[streaming] def validate() {
+ assert(
+ !mustCheckpoint || checkpointInterval != null,
+ "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set. " +
+ " Please use DStream.checkpoint() to set the interval."
+ )
+
+ assert(
+ checkpointInterval == null || checkpointInterval >= slideTime,
+ "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " +
+ checkpointInterval + " which is lower than its slide time (" + slideTime + "). " +
+ "Please set it to at least " + slideTime + "."
+ )
+
+ assert(
+ checkpointInterval == null || checkpointInterval.isMultipleOf(slideTime),
+ "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " +
+ checkpointInterval + " which not a multiple of its slide time (" + slideTime + "). " +
+ "Please set it to a multiple " + slideTime + "."
+ )
+
+ assert(
+ checkpointInterval == null || storageLevel != StorageLevel.NONE,
+ "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " +
+ "level has not been set to enable persisting. Please use DStream.persist() to set the " +
+ "storage level to use memory for better checkpointing performance."
+ )
+
+ assert(
+ checkpointInterval == null || rememberDuration > checkpointInterval,
+ "The remember duration for " + this.getClass.getSimpleName + " has been set to " +
+ rememberDuration + " which is not more than the checkpoint interval (" +
+ checkpointInterval + "). Please set it to higher than " + checkpointInterval + "."
+ )
+
+ dependencies.foreach(_.validate())
+
+ logInfo("Slide time = " + slideTime)
+ logInfo("Storage level = " + storageLevel)
+ logInfo("Checkpoint interval = " + checkpointInterval)
+ logInfo("Remember duration = " + rememberDuration)
logInfo("Initialized " + this)
}
@@ -120,17 +183,12 @@ extends Serializable with Logging {
dependencies.foreach(_.setGraph(graph))
}
- protected[streaming] def setRememberDuration(duration: Time = slideTime) {
- if (duration == null) {
- throw new Exception("Duration for remembering RDDs cannot be set to null for " + this)
- } else if (rememberDuration != null && duration < rememberDuration) {
- logWarning("Duration for remembering RDDs cannot be reduced from " + rememberDuration
- + " to " + duration + " for " + this)
- } else {
+ protected[streaming] def setRememberDuration(duration: Time) {
+ if (duration != null && duration > rememberDuration) {
rememberDuration = duration
- dependencies.foreach(_.setRememberDuration(parentRememberDuration))
logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this)
}
+ dependencies.foreach(_.setRememberDuration(parentRememberDuration))
}
/** This method checks whether the 'time' is valid wrt slideTime for generating RDD */
@@ -163,12 +221,13 @@ extends Serializable with Logging {
if (isTimeValid(time)) {
compute(time) match {
case Some(newRDD) =>
- if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
- newRDD.persist(checkpointLevel)
- logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time)
- } else if (storageLevel != StorageLevel.NONE) {
+ if (storageLevel != StorageLevel.NONE) {
newRDD.persist(storageLevel)
- logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time)
+ logInfo("Persisting RDD for time " + time + " to " + storageLevel + " at time " + time)
+ }
+ if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
+ newRDD.checkpoint()
+ logInfo("Marking RDD for time " + time + " for checkpointing at time " + time)
}
generatedRDDs.put(time, newRDD)
Some(newRDD)
@@ -199,7 +258,7 @@ extends Serializable with Logging {
}
}
- def forgetOldRDDs(time: Time) {
+ protected[streaming] def forgetOldRDDs(time: Time) {
val keys = generatedRDDs.keys
var numForgotten = 0
keys.foreach(t => {
@@ -213,12 +272,35 @@ extends Serializable with Logging {
dependencies.foreach(_.forgetOldRDDs(time))
}
+ protected[streaming] def updateCheckpointData() {
+ checkpointData.clear()
+ generatedRDDs.foreach {
+ case(time, rdd) => {
+ logDebug("Adding checkpointed RDD for time " + time)
+ val data = rdd.getCheckpointData()
+ if (data != null) {
+ checkpointData += ((time, data))
+ }
+ }
+ }
+ }
+
+ protected[streaming] def restoreCheckpointData() {
+ checkpointData.foreach {
+ case(time, data) => {
+ logInfo("Restoring checkpointed RDD for time " + time)
+ generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString)))
+ }
+ }
+ }
+
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
logDebug(this.getClass().getSimpleName + ".writeObject used")
if (graph != null) {
graph.synchronized {
if (graph.checkpointInProgress) {
+ updateCheckpointData()
oos.defaultWriteObject()
} else {
val msg = "Object of " + this.getClass.getName + " is being serialized " +
@@ -239,6 +321,8 @@ extends Serializable with Logging {
private def readObject(ois: ObjectInputStream) {
logDebug(this.getClass().getSimpleName + ".readObject used")
ois.defaultReadObject()
+ generatedRDDs = new HashMap[Time, RDD[T]] ()
+ restoreCheckpointData()
}
/**
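
On the DStream side, the old three-argument persist() is replaced by persist(level) plus a separate checkpoint(interval), and validate() now enforces the interval constraints when the context starts. A hedged sketch of the new call pattern (the input directory and interval values are placeholders; textFileStream and foreachRDD are existing streaming API):

    import spark.streaming._
    import spark.streaming.StreamingContext._   // pair-DStream operations such as reduceByKey

    object DStreamCheckpointSketch {
      def main(args: Array[String]) {
        val ssc = new StreamingContext("local", "DStreamCheckpointSketch")
        ssc.setBatchDuration(Seconds(1))
        ssc.checkpoint("/tmp/streaming-checkpoint", Seconds(10))   // graph checkpoint dir + interval
        val counts = ssc.textFileStream("/tmp/words")
          .flatMap(_.split(" "))
          .map(x => (x, 1))
          .reduceByKey(_ + _)
          .persist()                 // storage level for the generated RDDs
          .checkpoint(Seconds(10))   // must be at least the slide time and a multiple of it
        counts.foreachRDD(r => println("Element count: " + r.count()))
        ssc.start()
      }
    }
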
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index ac44d7a2a6..f8922ec790 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -22,11 +22,8 @@ final class DStreamGraph extends Serializable with Logging {
}
zeroTime = time
outputStreams.foreach(_.initialize(zeroTime))
- outputStreams.foreach(_.setRememberDuration()) // first set the rememberDuration to default values
- if (rememberDuration != null) {
- // if custom rememberDuration has been provided, set the rememberDuration
- outputStreams.foreach(_.setRememberDuration(rememberDuration))
- }
+ outputStreams.foreach(_.setRememberDuration(rememberDuration))
+ outputStreams.foreach(_.validate)
inputStreams.par.foreach(_.start())
}
}
diff --git a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
index 1c57d5f855..6df82c0df3 100644
--- a/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/ReducedWindowedDStream.scala
@@ -21,15 +21,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
partitioner: Partitioner
) extends DStream[(K,V)](parent.ssc) {
- if (!_windowTime.isMultipleOf(parent.slideTime))
- throw new Exception("The window duration of ReducedWindowedDStream (" + _slideTime + ") " +
- "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")")
+ assert(_windowTime.isMultipleOf(parent.slideTime),
+ "The window duration of ReducedWindowedDStream (" + _slideTime + ") " +
+ "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")"
+ )
- if (!_slideTime.isMultipleOf(parent.slideTime))
- throw new Exception("The slide duration of ReducedWindowedDStream (" + _slideTime + ") " +
- "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")")
+ assert(_slideTime.isMultipleOf(parent.slideTime),
+ "The slide duration of ReducedWindowedDStream (" + _slideTime + ") " +
+ "must be multiple of the slide duration of parent DStream (" + parent.slideTime + ")"
+ )
- @transient val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
+ super.persist(StorageLevel.MEMORY_ONLY)
+
+ val reducedStream = parent.reduceByKey(reduceFunc, partitioner)
def windowTime: Time = _windowTime
@@ -37,15 +41,19 @@ class ReducedWindowedDStream[K: ClassManifest, V: ClassManifest](
override def slideTime: Time = _slideTime
- //TODO: This is wrong. This should depend on the checkpointInterval
+ override val mustCheckpoint = true
+
override def parentRememberDuration: Time = rememberDuration + windowTime
- override def persist(
- storageLevel: StorageLevel,
- checkpointLevel: StorageLevel,
- checkpointInterval: Time): DStream[(K,V)] = {
- super.persist(storageLevel, checkpointLevel, checkpointInterval)
- reducedStream.persist(storageLevel, checkpointLevel, checkpointInterval)
+ override def persist(storageLevel: StorageLevel): DStream[(K,V)] = {
+ super.persist(storageLevel)
+ reducedStream.persist(storageLevel)
+ this
+ }
+
+ override def checkpoint(interval: Time): DStream[(K, V)] = {
+ super.checkpoint(interval)
+ reducedStream.checkpoint(interval)
this
}
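
Because ReducedWindowedDStream now sets mustCheckpoint = true, a windowed reduction with an inverse reduce function is checkpointed automatically (slide time or 10 seconds, whichever is larger) unless an interval is set explicitly, which is what the updated examples and WindowOperationsSuite do. A hedged sketch of that pattern (the input stream and durations are placeholders):

    import spark.streaming._
    import spark.streaming.StreamingContext._

    // Assumes `pairs` was built elsewhere from an input stream, e.g. words.map(x => (x, 1)).
    def windowedCounts(pairs: DStream[(String, Int)]): DStream[(String, Int)] =
      pairs.reduceByKeyAndWindow(_ + _, _ - _, Seconds(30), Seconds(1))
        .persist()
        .checkpoint(Seconds(100))   // large interval, as in WindowOperationsSuite, to keep RDD checkpointing cheap
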
diff --git a/streaming/src/main/scala/spark/streaming/StateDStream.scala b/streaming/src/main/scala/spark/streaming/StateDStream.scala
index 086752ac55..0211df1343 100644
--- a/streaming/src/main/scala/spark/streaming/StateDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/StateDStream.scala
@@ -23,51 +23,14 @@ class StateDStream[K: ClassManifest, V: ClassManifest, S <: AnyRef : ClassManife
rememberPartitioner: Boolean
) extends DStream[(K, S)](parent.ssc) {
+ super.persist(StorageLevel.MEMORY_ONLY)
+
override def dependencies = List(parent)
override def slideTime = parent.slideTime
- override def getOrCompute(time: Time): Option[RDD[(K, S)]] = {
- generatedRDDs.get(time) match {
- case Some(oldRDD) => {
- if (checkpointInterval != null && time > zeroTime && (time - zeroTime).isMultipleOf(checkpointInterval) && oldRDD.dependencies.size > 0) {
- val r = oldRDD
- val oldRDDBlockIds = oldRDD.splits.map(s => "rdd:" + r.id + ":" + s.index)
- val checkpointedRDD = new BlockRDD[(K, S)](ssc.sc, oldRDDBlockIds) {
- override val partitioner = oldRDD.partitioner
- }
- generatedRDDs.update(time, checkpointedRDD)
- logInfo("Checkpointed RDD " + oldRDD.id + " of time " + time + " with its new RDD " + checkpointedRDD.id)
- Some(checkpointedRDD)
- } else {
- Some(oldRDD)
- }
- }
- case None => {
- if (isTimeValid(time)) {
- compute(time) match {
- case Some(newRDD) => {
- if (checkpointInterval != null && (time - zeroTime).isMultipleOf(checkpointInterval)) {
- newRDD.persist(checkpointLevel)
- logInfo("Persisting " + newRDD + " to " + checkpointLevel + " at time " + time)
- } else if (storageLevel != StorageLevel.NONE) {
- newRDD.persist(storageLevel)
- logInfo("Persisting " + newRDD + " to " + storageLevel + " at time " + time)
- }
- generatedRDDs.put(time, newRDD)
- Some(newRDD)
- }
- case None => {
- None
- }
- }
- } else {
- None
- }
- }
- }
- }
-
+ override val mustCheckpoint = true
+
override def compute(validTime: Time): Option[RDD[(K, S)]] = {
// Try to get the previous state RDD
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index b3148eaa97..3838e84113 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -15,6 +15,8 @@ import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
+import org.apache.hadoop.fs.Path
+import java.util.UUID
class StreamingContext (
sc_ : SparkContext,
@@ -26,7 +28,7 @@ class StreamingContext (
def this(master: String, frameworkName: String, sparkHome: String = null, jars: Seq[String] = Nil) =
this(new SparkContext(master, frameworkName, sparkHome, jars), null)
- def this(file: String) = this(null, Checkpoint.loadFromFile(file))
+ def this(path: String) = this(null, Checkpoint.load(path))
def this(cp_ : Checkpoint) = this(null, cp_)
@@ -51,7 +53,6 @@ class StreamingContext (
val graph: DStreamGraph = {
if (isCheckpointPresent) {
-
cp_.graph.setContext(this)
cp_.graph
} else {
@@ -62,7 +63,15 @@ class StreamingContext (
val nextNetworkInputStreamId = new AtomicInteger(0)
var networkInputTracker: NetworkInputTracker = null
- private[streaming] var checkpointFile: String = if (isCheckpointPresent) cp_.checkpointFile else null
+ private[streaming] var checkpointDir: String = {
+ if (isCheckpointPresent) {
+ sc.setCheckpointDir(cp_.checkpointDir, true)
+ cp_.checkpointDir
+ } else {
+ null
+ }
+ }
+
private[streaming] var checkpointInterval: Time = if (isCheckpointPresent) cp_.checkpointInterval else null
private[streaming] var receiverJobThread: Thread = null
private[streaming] var scheduler: Scheduler = null
@@ -75,9 +84,15 @@ class StreamingContext (
graph.setRememberDuration(duration)
}
- def setCheckpointDetails(file: String, interval: Time) {
- checkpointFile = file
- checkpointInterval = interval
+ def checkpoint(dir: String, interval: Time) {
+ if (dir != null) {
+ sc.setCheckpointDir(new Path(dir, "rdds-" + UUID.randomUUID.toString).toString)
+ checkpointDir = dir
+ checkpointInterval = interval
+ } else {
+ checkpointDir = null
+ checkpointInterval = null
+ }
}
private[streaming] def getInitialCheckpoint(): Checkpoint = {
@@ -170,16 +185,12 @@ class StreamingContext (
graph.addOutputStream(outputStream)
}
- def validate() {
- assert(graph != null, "Graph is null")
- graph.validate()
- }
-
/**
* This function starts the execution of the streams.
*/
def start() {
- validate()
+ assert(graph != null, "Graph is null")
+ graph.validate()
val networkInputStreams = graph.getInputStreams().filter(s => s match {
case n: NetworkInputDStream[_] => true
@@ -216,7 +227,8 @@ class StreamingContext (
}
def doCheckpoint(currentTime: Time) {
- new Checkpoint(this, currentTime).saveToFile(checkpointFile)
+ new Checkpoint(this, currentTime).save(checkpointDir)
+
}
}
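
The renamed StreamingContext.checkpoint(dir, interval) plus the StreamingContext(path) constructor give the restart pattern that FileStreamWithCheckpoint switches to further down in this diff; a condensed, hedged version (the flag and paths are placeholders):

    import spark.streaming._

    object RestartableApp {
      def main(args: Array[String]) {
        val checkpointDir = "/tmp/streaming-checkpoint"
        val restarting = args.length > 0 && args(0) == "restart"
        val ssc =
          if (restarting) {
            // Rebuild the graph (and the rdds-<uuid> RDD checkpoint dir) from the saved checkpoint
            new StreamingContext(checkpointDir)
          } else {
            val ssc_ = new StreamingContext("local", "RestartableApp")
            ssc_.setBatchDuration(Seconds(1))
            ssc_.checkpoint(checkpointDir, Seconds(1))
            // ... define the input streams and output operations on ssc_ here ...
            ssc_
          }
        ssc.start()
      }
    }
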
diff --git a/streaming/src/main/scala/spark/streaming/Time.scala b/streaming/src/main/scala/spark/streaming/Time.scala
index 9ddb65249a..2ba6502971 100644
--- a/streaming/src/main/scala/spark/streaming/Time.scala
+++ b/streaming/src/main/scala/spark/streaming/Time.scala
@@ -25,6 +25,10 @@ case class Time(millis: Long) {
def isMultipleOf(that: Time): Boolean =
(this.millis % that.millis == 0)
+ def min(that: Time): Time = if (this < that) this else that
+
+ def max(that: Time): Time = if (this > that) this else that
+
def isZero: Boolean = (this.millis == 0)
override def toString: String = (millis.toString + " ms")
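
The new min/max helpers on Time are what DStream.initialize() above uses to pick a default checkpoint interval; for illustration (the slide times are arbitrary):

    // slideTime.max(Seconds(10)) from DStream.initialize():
    Seconds(2).max(Seconds(10))    // => Time(10000), printed as "10000 ms"
    Seconds(15).max(Seconds(10))   // => Time(15000), printed as "15000 ms"
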
diff --git a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
index df96a811da..21a83c0fde 100644
--- a/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/FileStreamWithCheckpoint.scala
@@ -10,20 +10,20 @@ object FileStreamWithCheckpoint {
def main(args: Array[String]) {
if (args.size != 3) {
- println("FileStreamWithCheckpoint <master> <directory> <checkpoint file>")
- println("FileStreamWithCheckpoint restart <directory> <checkpoint file>")
+ println("FileStreamWithCheckpoint <master> <directory> <checkpoint dir>")
+ println("FileStreamWithCheckpoint restart <directory> <checkpoint dir>")
System.exit(-1)
}
val directory = new Path(args(1))
- val checkpointFile = args(2)
+ val checkpointDir = args(2)
val ssc: StreamingContext = {
if (args(0) == "restart") {
// Recreated streaming context from specified checkpoint file
- new StreamingContext(checkpointFile)
+ new StreamingContext(checkpointDir)
} else {
@@ -34,7 +34,7 @@ object FileStreamWithCheckpoint {
// Create new streaming context
val ssc_ = new StreamingContext(args(0), "FileStreamWithCheckpoint")
ssc_.setBatchDuration(Seconds(1))
- ssc_.setCheckpointDetails(checkpointFile, Seconds(1))
+ ssc_.checkpoint(checkpointDir, Seconds(1))
// Setup the streaming computation
val inputStream = ssc_.textFileStream(directory.toString)
diff --git a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
index 57fd10f0a5..750cb7445f 100644
--- a/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/TopKWordCountRaw.scala
@@ -41,9 +41,8 @@ object TopKWordCountRaw {
val windowedCounts = union.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces)
- windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- Milliseconds(chkptMs))
- //windowedCounts.print() // TODO: something else?
+ windowedCounts.persist().checkpoint(Milliseconds(chkptMs))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs))
def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = {
val taken = new Array[(String, Long)](k)
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
index 0d2e62b955..865026033e 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCount2.scala
@@ -100,10 +100,9 @@ object WordCount2 {
val windowedCounts = sentences
.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), batchDuration, reduceTasks.toInt)
- windowedCounts.persist(StorageLevel.MEMORY_ONLY,
- StorageLevel.MEMORY_ONLY_2,
- //new StorageLevel(false, true, true, 3),
- Milliseconds(chkptMillis.toLong))
+
+ windowedCounts.persist().checkpoint(Milliseconds(chkptMillis.toLong))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong))
windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
ssc.start()
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
index abfd12890f..d1ea9a9cd5 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordCountRaw.scala
@@ -41,9 +41,9 @@ object WordCountRaw {
val windowedCounts = union.mapPartitions(splitAndCountPartitions)
.reduceByKeyAndWindow(add _, subtract _, Seconds(30), Milliseconds(batchMs), reduces)
- windowedCounts.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- Milliseconds(chkptMs))
- //windowedCounts.print() // TODO: something else?
+ windowedCounts.persist().checkpoint(chkptMs)
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMs))
+
windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
ssc.start()
diff --git a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
index 9d44da2b11..6a9c8a9a69 100644
--- a/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
+++ b/streaming/src/main/scala/spark/streaming/examples/WordMax2.scala
@@ -57,11 +57,13 @@ object WordMax2 {
val windowedCounts = sentences
.mapPartitions(splitAndCountPartitions)
.reduceByKey(add _, reduceTasks.toInt)
- .persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- Milliseconds(chkptMillis.toLong))
+ .persist()
+ .checkpoint(Milliseconds(chkptMillis.toLong))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong))
.reduceByKeyAndWindow(max _, Seconds(10), batchDuration, reduceTasks.toInt)
- //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2,
- // Milliseconds(chkptMillis.toLong))
+ .persist()
+ .checkpoint(Milliseconds(chkptMillis.toLong))
+ //.persist(StorageLevel.MEMORY_ONLY, StorageLevel.MEMORY_ONLY_2, Milliseconds(chkptMillis.toLong))
windowedCounts.foreachRDD(r => println("Element count: " + r.count()))
ssc.start()
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index 6dcedcf463..dfe31b5771 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -2,52 +2,95 @@ package spark.streaming
import spark.streaming.StreamingContext._
import java.io.File
+import collection.mutable.ArrayBuffer
+import runtime.RichInt
+import org.scalatest.BeforeAndAfter
+import org.apache.hadoop.fs.Path
+import org.apache.commons.io.FileUtils
-class CheckpointSuite extends TestSuiteBase {
+class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
+
+ before {
+ FileUtils.deleteDirectory(new File(checkpointDir))
+ }
+
+ after {
+ FileUtils.deleteDirectory(new File(checkpointDir))
+ }
override def framework() = "CheckpointSuite"
- override def checkpointFile() = "checkpoint"
+ override def batchDuration() = Seconds(1)
+
+ override def checkpointDir() = "checkpoint"
+
+ override def checkpointInterval() = batchDuration
def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
- useSet: Boolean = false
+ initialNumBatches: Int
) {
// Current code assumes that:
// number of inputs = number of outputs = number of batches to be run
-
val totalNumBatches = input.size
- val initialNumBatches = input.size / 2
val nextNumBatches = totalNumBatches - initialNumBatches
val initialNumExpectedOutputs = initialNumBatches
+ val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs
// Do half the computation (half the number of batches), create checkpoint file and quit
val ssc = setupStreams[U, V](input, operation)
val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
- verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet)
+ verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
Thread.sleep(1000)
// Restart and complete the computation from checkpoint file
- val sscNew = new StreamingContext(checkpointFile)
- sscNew.setCheckpointDetails(null, null)
- val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size)
- verifyOutput[V](outputNew, expectedOutput, useSet)
-
- new File(checkpointFile).delete()
- new File(checkpointFile + ".bk").delete()
- new File("." + checkpointFile + ".crc").delete()
- new File("." + checkpointFile + ".bk.crc").delete()
+ val sscNew = new StreamingContext(checkpointDir)
+ //sscNew.checkpoint(null, null)
+ val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs)
+ verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
}
- test("simple per-batch operation") {
+
+ test("map and reduceByKey") {
testCheckpointedOperation(
Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
(s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
- true
+ 3
)
}
+
+ test("reduceByKeyAndWindowInv") {
+ val n = 10
+ val w = 4
+ val input = (1 to n).map(x => Seq("a")).toSeq
+ val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4)))
+ val operation = (st: DStream[String]) => {
+ st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, Seconds(w), Seconds(1))
+ }
+ for (i <- Seq(3, 5, 7)) {
+ testCheckpointedOperation(input, operation, output, i)
+ }
+ }
+
+ test("updateStateByKey") {
+ val input = (1 to 10).map(_ => Seq("a")).toSeq
+ val output = (1 to 10).map(x => Seq(("a", x))).toSeq
+ val operation = (st: DStream[String]) => {
+ val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
+ Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ }
+ st.map(x => (x, 1))
+ .updateStateByKey[RichInt](updateFunc)
+ .checkpoint(Seconds(5))
+ .map(t => (t._1, t._2.self))
+ }
+ for (i <- Seq(3, 5, 7)) {
+ testCheckpointedOperation(input, operation, output, i)
+ }
+ }
+
} \ No newline at end of file
diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
index c9bc454f91..e441feea19 100644
--- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
@@ -5,10 +5,16 @@ import util.ManualClock
import collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
import collection.mutable.SynchronizedBuffer
+import java.io.{ObjectInputStream, IOException}
+
+/**
+ * This is an input stream just for the testsuites. It is equivalent to a checkpointable,
+ * replayable, reliable message queue like Kafka. It requires a sequence as input, and
+ * returns the i_th element at the i_th batch under the manual clock.
+ */
class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[T]], numPartitions: Int)
extends InputDStream[T](ssc_) {
- var currentIndex = 0
def start() {}
@@ -23,17 +29,32 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[
ssc.sc.makeRDD(Seq[T](), numPartitions)
}
logInfo("Created RDD " + rdd.id)
- //currentIndex += 1
Some(rdd)
}
}
+/**
+ * This is an output stream just for the testsuites. All the output is collected into an
+ * ArrayBuffer. This buffer is wiped clean when it is restored from a checkpoint.
+ */
class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
val collected = rdd.collect()
output += collected
- })
+ }) {
+
+  // This is to clear the output buffer every time it is read back from a checkpoint
+ @throws(classOf[IOException])
+ private def readObject(ois: ObjectInputStream) {
+ ois.defaultReadObject()
+ output.clear()
+ }
+}
+/**
+ * This is the base trait for Spark Streaming testsuites. This provides basic functionality
+ * to run user-defined set of input on user-defined stream operations, and verify the output.
+ */
trait TestSuiteBase extends FunSuite with Logging {
System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
@@ -44,7 +65,7 @@ trait TestSuiteBase extends FunSuite with Logging {
def batchDuration() = Seconds(1)
- def checkpointFile() = null.asInstanceOf[String]
+ def checkpointDir() = null.asInstanceOf[String]
def checkpointInterval() = batchDuration
@@ -60,8 +81,8 @@ trait TestSuiteBase extends FunSuite with Logging {
// Create StreamingContext
val ssc = new StreamingContext(master, framework)
ssc.setBatchDuration(batchDuration)
- if (checkpointFile != null) {
- ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+ if (checkpointDir != null) {
+ ssc.checkpoint(checkpointDir, checkpointInterval())
}
// Setup the stream computation
@@ -82,8 +103,8 @@ trait TestSuiteBase extends FunSuite with Logging {
// Create StreamingContext
val ssc = new StreamingContext(master, framework)
ssc.setBatchDuration(batchDuration)
- if (checkpointFile != null) {
- ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+ if (checkpointDir != null) {
+ ssc.checkpoint(checkpointDir, checkpointInterval())
}
// Setup the stream computation
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
index d7d8d5bd36..e282f0fdd5 100644
--- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -283,7 +283,9 @@ class WindowOperationsSuite extends TestSuiteBase {
test("reduceByKeyAndWindowInv - " + name) {
val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
val operation = (s: DStream[(String, Int)]) => {
- s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist()
+ s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime)
+ .persist()
+ .checkpoint(Seconds(100)) // Large value to avoid effect of RDD checkpointing
}
testOperation(input, operation, expectedOutput, numBatches, true)
}