author    Tathagata Das <tathagata.das1565@gmail.com>    2012-11-05 11:41:36 -0800
committer Tathagata Das <tathagata.das1565@gmail.com>    2012-11-05 11:41:36 -0800
commit    72b2303f99bd652fc4bdaa929f37731a7ba8f640 (patch)
tree      6038f22fc22038ee2bec739da5996b8d1ceb0dee /streaming
parent    d1542387891018914fdd6b647f17f0b05acdd40e (diff)
Fixed major bugs in checkpointing.
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/spark/streaming/Checkpoint.scala            |  24
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala               |  47
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStreamGraph.scala          |  36
-rw-r--r--  streaming/src/main/scala/spark/streaming/Scheduler.scala             |   1
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala      |   8
-rw-r--r--  streaming/src/test/scala/spark/streaming/CheckpointSuite.scala       | 139
-rw-r--r--  streaming/src/test/scala/spark/streaming/TestSuiteBase.scala         |  37
-rw-r--r--  streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala |   6
8 files changed, 213 insertions, 85 deletions
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index cf04c7031e..6b4b05103f 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -6,6 +6,7 @@ import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream}
+import sys.process.processInternal
class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
@@ -52,17 +53,17 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
}
}
-object Checkpoint {
+object Checkpoint extends Logging {
def load(path: String): Checkpoint = {
val fs = new Path(path).getFileSystem(new Configuration())
- val attempts = Seq(new Path(path), new Path(path, "graph"), new Path(path, "graph.bk"))
- var lastException: Exception = null
- var lastExceptionFile: String = null
+ val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk"))
+ var detailedLog: String = ""
attempts.foreach(file => {
if (fs.exists(file)) {
+ logInfo("Attempting to load checkpoint from file '" + file + "'")
try {
val fis = fs.open(file)
// ObjectInputStream uses the last defined user-defined class loader in the stack
@@ -75,21 +76,18 @@ object Checkpoint {
ois.close()
fs.close()
cp.validate()
- println("Checkpoint successfully loaded from file " + file)
+ logInfo("Checkpoint successfully loaded from file '" + file + "'")
return cp
} catch {
case e: Exception =>
- lastException = e
- lastExceptionFile = file.toString
+ logError("Error loading checkpoint from file '" + file + "'", e)
}
+ } else {
+ logWarning("Could not load checkpoint from file '" + file + "' as it does not exist")
}
- })
- if (lastException == null) {
- throw new Exception("Could not load checkpoint from path '" + path + "'")
- } else {
- throw new Exception("Error loading checkpoint from path '" + lastExceptionFile + "'", lastException)
- }
+ })
+ throw new Exception("Could not load checkpoint from path '" + path + "'")
}
def fromBytes(bytes: Array[Byte]): Checkpoint = {
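Editor's note: the rewritten load path above is easy to lose inside the diff, so here is a minimal self-contained sketch of the same try-each-candidate pattern. It is an illustration only, not the patched Spark code: the candidate order mirrors the diff, println stands in for Spark's Logging trait, and CheckpointLoadSketch is an invented name.

import java.io.ObjectInputStream
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

object CheckpointLoadSketch {
  def load(path: String): AnyRef = {
    val fs = new Path(path).getFileSystem(new Configuration())
    val attempts = Seq(
      new Path(path, "graph"), new Path(path, "graph.bk"),
      new Path(path), new Path(path + ".bk"))
    attempts.foreach { file =>
      if (fs.exists(file)) {
        try {
          val ois = new ObjectInputStream(fs.open(file))
          val cp = ois.readObject()
          ois.close()
          return cp                 // first candidate that deserializes wins
        } catch {
          case e: Exception => println("Error loading checkpoint from " + file + ": " + e)
        }
      }
    }
    throw new Exception("Could not load checkpoint from path '" + path + "'")
  }
}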
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index de51c5d34a..2fecbe0acf 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -14,6 +14,8 @@ import java.util.concurrent.ArrayBlockingQueue
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import scala.Some
import collection.mutable
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
abstract class DStream[T: ClassManifest] (@transient var ssc: StreamingContext)
extends Serializable with Logging {
@@ -42,6 +44,7 @@ extends Serializable with Logging {
*/
// RDDs generated, marked as protected[streaming] so that testsuites can access it
+ @transient
protected[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] ()
// Time zero for the DStream
@@ -112,7 +115,7 @@ extends Serializable with Logging {
// Set the minimum value of the rememberDuration if not already set
var minRememberDuration = slideTime
if (checkpointInterval != null && minRememberDuration <= checkpointInterval) {
- minRememberDuration = checkpointInterval + slideTime
+ minRememberDuration = checkpointInterval * 2 // times 2 just to be sure that the latest checkpoint is not forgotten
}
if (rememberDuration == null || rememberDuration < minRememberDuration) {
rememberDuration = minRememberDuration
@@ -265,33 +268,59 @@ extends Serializable with Logging {
if (t <= (time - rememberDuration)) {
generatedRDDs.remove(t)
numForgotten += 1
- //logInfo("Forgot RDD of time " + t + " from " + this)
+ logInfo("Forgot RDD of time " + t + " from " + this)
}
})
logInfo("Forgot " + numForgotten + " RDDs from " + this)
dependencies.foreach(_.forgetOldRDDs(time))
}
+ /**
+ * Refreshes the list of checkpointed RDDs that will be saved along with the checkpoint of this stream.
+ * It also deletes checkpoint files of old RDDs that are no longer needed.
+ */
protected[streaming] def updateCheckpointData() {
+
+ // TODO (tdas): This code can be simplified. It is kept verbose to aid debugging.
+ val checkpointedRDDs = generatedRDDs.filter(_._2.getCheckpointData() != null)
+ val removedCheckpointData = checkpointData.filter(x => !generatedRDDs.contains(x._1))
+
checkpointData.clear()
- generatedRDDs.foreach {
- case(time, rdd) => {
- logDebug("Adding checkpointed RDD for time " + time)
+ checkpointedRDDs.foreach {
+ case (time, rdd) => {
val data = rdd.getCheckpointData()
- if (data != null) {
- checkpointData += ((time, data))
+ assert(data != null)
+ checkpointData += ((time, data))
+ logInfo("Added checkpointed RDD " + rdd + " for time " + time + " to stream checkpoint")
+ }
+ }
+
+ dependencies.foreach(_.updateCheckpointData())
+ // If at least one checkpoint is present, then delete old checkpoints
+ if (checkpointData.size > 0) {
+ // Delete the checkpoint RDD files that are not needed any more
+ removedCheckpointData.foreach {
+ case (time: Time, file: String) => {
+ val path = new Path(file)
+ val fs = path.getFileSystem(new Configuration())
+ fs.delete(path, true)
+ logInfo("Deleted checkpoint file '" + file + "' for time " + time)
}
}
}
+
+ logInfo("Updated checkpoint data")
}
protected[streaming] def restoreCheckpointData() {
+ logInfo("Restoring checkpoint data from " + checkpointData.size + " checkpointed RDDs")
checkpointData.foreach {
case(time, data) => {
- logInfo("Restoring checkpointed RDD for time " + time)
+ logInfo("Restoring checkpointed RDD for time " + time + " from file")
generatedRDDs += ((time, ssc.sc.objectFile[T](data.toString)))
}
}
+ dependencies.foreach(_.restoreCheckpointData())
}
@throws(classOf[IOException])
@@ -300,7 +329,6 @@ extends Serializable with Logging {
if (graph != null) {
graph.synchronized {
if (graph.checkpointInProgress) {
- updateCheckpointData()
oos.defaultWriteObject()
} else {
val msg = "Object of " + this.getClass.getName + " is being serialized " +
@@ -322,7 +350,6 @@ extends Serializable with Logging {
logDebug(this.getClass().getSimpleName + ".readObject used")
ois.defaultReadObject()
generatedRDDs = new HashMap[Time, RDD[T]] ()
- restoreCheckpointData()
}
/**
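Editor's note: the core of this file's change is the new bookkeeping in updateCheckpointData. The sketch below isolates that bookkeeping under illustrative names (CheckpointBookkeepingSketch, getCheckpointFile standing in for rdd.getCheckpointData()); it is not the DStream code itself.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import scala.collection.mutable.HashMap

class CheckpointBookkeepingSketch[T](getCheckpointFile: T => String) {
  val generatedRDDs = new HashMap[Long, T]()         // time -> generated RDD (stand-in)
  val checkpointData = new HashMap[Long, String]()   // time -> checkpoint file path

  def updateCheckpointData() {
    // RDDs that have a checkpoint file right now
    val checkpointed = generatedRDDs.flatMap { case (time, rdd) =>
      Option(getCheckpointFile(rdd)).map(file => (time, file))
    }
    // Old entries whose RDDs have since been forgotten; their files can be deleted
    val removed = checkpointData.filter { case (time, _) => !generatedRDDs.contains(time) }
    checkpointData.clear()
    checkpointData ++= checkpointed
    if (checkpointData.nonEmpty) {
      removed.foreach { case (_, file) =>
        val path = new Path(file)
        path.getFileSystem(new Configuration()).delete(path, true)
      }
    }
  }
}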
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index f8922ec790..7437f4402d 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -4,7 +4,7 @@ import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import collection.mutable.ArrayBuffer
import spark.Logging
-final class DStreamGraph extends Serializable with Logging {
+final private[streaming] class DStreamGraph extends Serializable with Logging {
initLogging()
private val inputStreams = new ArrayBuffer[InputDStream[_]]()
@@ -15,7 +15,7 @@ final class DStreamGraph extends Serializable with Logging {
private[streaming] var rememberDuration: Time = null
private[streaming] var checkpointInProgress = false
- def start(time: Time) {
+ private[streaming] def start(time: Time) {
this.synchronized {
if (zeroTime != null) {
throw new Exception("DStream graph computation already started")
@@ -28,7 +28,7 @@ final class DStreamGraph extends Serializable with Logging {
}
}
- def stop() {
+ private[streaming] def stop() {
this.synchronized {
inputStreams.par.foreach(_.stop())
}
@@ -40,7 +40,7 @@ final class DStreamGraph extends Serializable with Logging {
}
}
- def setBatchDuration(duration: Time) {
+ private[streaming] def setBatchDuration(duration: Time) {
this.synchronized {
if (batchDuration != null) {
throw new Exception("Batch duration already set as " + batchDuration +
@@ -50,7 +50,7 @@ final class DStreamGraph extends Serializable with Logging {
batchDuration = duration
}
- def setRememberDuration(duration: Time) {
+ private[streaming] def setRememberDuration(duration: Time) {
this.synchronized {
if (rememberDuration != null) {
throw new Exception("Batch duration already set as " + batchDuration +
@@ -60,37 +60,49 @@ final class DStreamGraph extends Serializable with Logging {
rememberDuration = duration
}
- def addInputStream(inputStream: InputDStream[_]) {
+ private[streaming] def addInputStream(inputStream: InputDStream[_]) {
this.synchronized {
inputStream.setGraph(this)
inputStreams += inputStream
}
}
- def addOutputStream(outputStream: DStream[_]) {
+ private[streaming] def addOutputStream(outputStream: DStream[_]) {
this.synchronized {
outputStream.setGraph(this)
outputStreams += outputStream
}
}
- def getInputStreams() = inputStreams.toArray
+ private[streaming] def getInputStreams() = this.synchronized { inputStreams.toArray }
- def getOutputStreams() = outputStreams.toArray
+ private[streaming] def getOutputStreams() = this.synchronized { outputStreams.toArray }
- def generateRDDs(time: Time): Seq[Job] = {
+ private[streaming] def generateRDDs(time: Time): Seq[Job] = {
this.synchronized {
outputStreams.flatMap(outputStream => outputStream.generateJob(time))
}
}
- def forgetOldRDDs(time: Time) {
+ private[streaming] def forgetOldRDDs(time: Time) {
this.synchronized {
outputStreams.foreach(_.forgetOldRDDs(time))
}
}
- def validate() {
+ private[streaming] def updateCheckpointData() {
+ this.synchronized {
+ outputStreams.foreach(_.updateCheckpointData())
+ }
+ }
+
+ private[streaming] def restoreCheckpointData() {
+ this.synchronized {
+ outputStreams.foreach(_.restoreCheckpointData())
+ }
+ }
+
+ private[streaming] def validate() {
this.synchronized {
assert(batchDuration != null, "Batch duration has not been set")
assert(batchDuration > Milliseconds(100), "Batch duration of " + batchDuration + " is very low")
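Editor's note: the two methods added to the graph follow the same pattern as the existing ones — lock the graph, then fan the call out to every output stream. A stand-alone sketch of that pattern, with StreamLike as a stand-in trait rather than Spark's DStream:

trait StreamLike {
  def updateCheckpointData(): Unit
  def restoreCheckpointData(): Unit
}

class GraphSketch(outputStreams: Seq[StreamLike]) {
  def updateCheckpointData() = this.synchronized {
    outputStreams.foreach(_.updateCheckpointData())
  }
  def restoreCheckpointData() = this.synchronized {
    outputStreams.foreach(_.restoreCheckpointData())
  }
}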
diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala
index 7d52e2eddf..2b3f5a4829 100644
--- a/streaming/src/main/scala/spark/streaming/Scheduler.scala
+++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala
@@ -58,7 +58,6 @@ extends Logging {
graph.forgetOldRDDs(time)
if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) {
ssc.doCheckpoint(time)
- logInfo("Checkpointed at time " + time)
}
}
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 3838e84113..fb36ab9dc9 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -54,6 +54,7 @@ class StreamingContext (
val graph: DStreamGraph = {
if (isCheckpointPresent) {
cp_.graph.setContext(this)
+ cp_.graph.restoreCheckpointData()
cp_.graph
} else {
new DStreamGraph()
@@ -218,17 +219,16 @@ class StreamingContext (
if (scheduler != null) scheduler.stop()
if (networkInputTracker != null) networkInputTracker.stop()
if (receiverJobThread != null) receiverJobThread.interrupt()
- sc.stop()
+ sc.stop()
+ logInfo("StreamingContext stopped successfully")
} catch {
case e: Exception => logWarning("Error while stopping", e)
}
-
- logInfo("StreamingContext stopped")
}
def doCheckpoint(currentTime: Time) {
+ graph.updateCheckpointData()
new Checkpoint(this, currentTime).save(checkpointDir)
-
}
}
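Editor's note: the StreamingContext changes fix the ordering around checkpoints — restore the graph's checkpoint data as soon as a graph is recovered, and refresh it just before the graph is written out again. The sketch below shows only that ordering; GraphStub and ContextSketch are illustrative stand-ins, not Spark classes.

class GraphStub {
  def updateCheckpointData() { println("refresh checkpoint bookkeeping") }
  def restoreCheckpointData() { println("reload checkpointed RDDs") }
}

class ContextSketch(recoveredGraph: Option[GraphStub]) {
  val graph: GraphStub = recoveredGraph match {
    case Some(g) => g.restoreCheckpointData(); g   // restore before first use
    case None    => new GraphStub()
  }
  def doCheckpoint() {
    graph.updateCheckpointData()                   // refresh bookkeeping first...
    // ...then serialize the graph to the checkpoint directory
  }
}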
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
index dfe31b5771..aa8ded513c 100644
--- a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -2,11 +2,11 @@ package spark.streaming
import spark.streaming.StreamingContext._
import java.io.File
-import collection.mutable.ArrayBuffer
import runtime.RichInt
import org.scalatest.BeforeAndAfter
-import org.apache.hadoop.fs.Path
import org.apache.commons.io.FileUtils
+import collection.mutable.{SynchronizedBuffer, ArrayBuffer}
+import util.ManualClock
class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
@@ -18,39 +18,83 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
FileUtils.deleteDirectory(new File(checkpointDir))
}
- override def framework() = "CheckpointSuite"
+ override def framework = "CheckpointSuite"
- override def batchDuration() = Seconds(1)
+ override def batchDuration = Milliseconds(500)
- override def checkpointDir() = "checkpoint"
+ override def checkpointDir = "checkpoint"
- override def checkpointInterval() = batchDuration
+ override def checkpointInterval = batchDuration
- def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
- input: Seq[Seq[U]],
- operation: DStream[U] => DStream[V],
- expectedOutput: Seq[Seq[V]],
- initialNumBatches: Int
- ) {
+ override def actuallyWait = true
- // Current code assumes that:
- // number of inputs = number of outputs = number of batches to be run
- val totalNumBatches = input.size
- val nextNumBatches = totalNumBatches - initialNumBatches
- val initialNumExpectedOutputs = initialNumBatches
- val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs
+ test("basic stream+rdd recovery") {
- // Do half the computation (half the number of batches), create checkpoint file and quit
- val ssc = setupStreams[U, V](input, operation)
- val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
- verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
- Thread.sleep(1000)
+ assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 500 milliseconds")
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
- // Restart and complete the computation from checkpoint file
+ val checkpointingInterval = Seconds(2)
+
+ // this ensures that checkpointing occurs at least once
+ val firstNumBatches = (checkpointingInterval.millis / batchDuration.millis) * 2
+ val secondNumBatches = firstNumBatches
+
+ // Setup the streams
+ val input = (1 to 10).map(_ => Seq("a")).toSeq
+ val operation = (st: DStream[String]) => {
+ val updateFunc = (values: Seq[Int], state: Option[RichInt]) => {
+ Some(new RichInt(values.foldLeft(0)(_ + _) + state.map(_.self).getOrElse(0)))
+ }
+ st.map(x => (x, 1))
+ .updateStateByKey[RichInt](updateFunc)
+ .checkpoint(checkpointingInterval)
+ .map(t => (t._1, t._2.self))
+ }
+ val ssc = setupStreams(input, operation)
+ val stateStream = ssc.graph.getOutputStreams().head.dependencies.head.dependencies.head
+
+ // Run till a time such that at least one RDD in the stream should have been checkpointed
+ ssc.start()
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ logInfo("Manual clock before advancing = " + clock.time)
+ for (i <- 1 to firstNumBatches.toInt) {
+ clock.addToTime(batchDuration.milliseconds)
+ Thread.sleep(batchDuration.milliseconds)
+ }
+ logInfo("Manual clock after advancing = " + clock.time)
+ Thread.sleep(batchDuration.milliseconds)
+
+ // Check whether some RDD has been checkpointed or not
+ logInfo("Checkpoint data of state stream = \n[" + stateStream.checkpointData.mkString(",\n") + "]")
+ assert(!stateStream.checkpointData.isEmpty, "No checkpointed RDDs in state stream")
+ stateStream.checkpointData.foreach {
+ case (time, data) => {
+ val file = new File(data.toString)
+ assert(file.exists(), "Checkpoint file '" + file +"' for time " + time + " does not exist")
+ }
+ }
+ val checkpointFiles = stateStream.checkpointData.map(x => new File(x._2.toString))
+
+ // Run till a further time such that previous checkpoint files in the stream would be deleted
+ logInfo("Manual clock before advancing = " + clock.time)
+ for (i <- 1 to secondNumBatches.toInt) {
+ clock.addToTime(batchDuration.milliseconds)
+ Thread.sleep(batchDuration.milliseconds)
+ }
+ logInfo("Manual clock after advancing = " + clock.time)
+ Thread.sleep(batchDuration.milliseconds)
+
+ // Check whether the earlier checkpoint files are deleted
+ checkpointFiles.foreach(file => assert(!file.exists, "Checkpoint file '" + file + "' was not deleted"))
+
+ // Restart stream computation using the checkpoint file and check whether
+ // checkpointed RDDs have been restored or not
+ ssc.stop()
val sscNew = new StreamingContext(checkpointDir)
- //sscNew.checkpoint(null, null)
- val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs)
- verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
+ val stateStreamNew = sscNew.graph.getOutputStreams().head.dependencies.head.dependencies.head
+ logInfo("Restored data of state stream = \n[" + stateStreamNew.generatedRDDs.mkString("\n") + "]")
+ assert(!stateStreamNew.generatedRDDs.isEmpty, "No restored RDDs in state stream")
+ sscNew.stop()
}
@@ -69,9 +113,9 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
val input = (1 to n).map(x => Seq("a")).toSeq
val output = Seq(Seq(("a", 1)), Seq(("a", 2)), Seq(("a", 3))) ++ (1 to (n - w + 1)).map(x => Seq(("a", 4)))
val operation = (st: DStream[String]) => {
- st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, Seconds(w), Seconds(1))
+ st.map(x => (x, 1)).reduceByKeyAndWindow(_ + _, _ - _, batchDuration * 4, batchDuration)
}
- for (i <- Seq(3, 5, 7)) {
+ for (i <- Seq(2, 3, 4)) {
testCheckpointedOperation(input, operation, output, i)
}
}
@@ -85,12 +129,45 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
}
st.map(x => (x, 1))
.updateStateByKey[RichInt](updateFunc)
- .checkpoint(Seconds(5))
+ .checkpoint(Seconds(2))
.map(t => (t._1, t._2.self))
}
- for (i <- Seq(3, 5, 7)) {
+ for (i <- Seq(2, 3, 4)) {
testCheckpointedOperation(input, operation, output, i)
}
}
+
+
+ def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ initialNumBatches: Int
+ ) {
+
+ // Current code assumes that:
+ // number of inputs = number of outputs = number of batches to be run
+ val totalNumBatches = input.size
+ val nextNumBatches = totalNumBatches - initialNumBatches
+ val initialNumExpectedOutputs = initialNumBatches
+ val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs
+
+ // Do half the computation (half the number of batches), create checkpoint file and quit
+
+ val ssc = setupStreams[U, V](input, operation)
+ val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
+ verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
+ Thread.sleep(1000)
+
+ // Restart and complete the computation from checkpoint file
+ logInfo(
+ "\n-------------------------------------------\n" +
+ " Restarting stream computation " +
+ "\n-------------------------------------------\n"
+ )
+ val sscNew = new StreamingContext(checkpointDir)
+ val outputNew = runStreams[V](sscNew, nextNumBatches, nextNumExpectedOutputs)
+ verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
+ }
}
\ No newline at end of file
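Editor's note: the recovery test depends on a small piece of batch-count arithmetic — running twice as many batches as fit into one checkpoint interval guarantees at least one checkpoint write before the assertions run. A sketch of that arithmetic, with values mirroring the test above:

object CheckpointBatchMath {
  def main(args: Array[String]) {
    val batchMs      = 500L    // batchDuration = Milliseconds(500)
    val checkpointMs = 2000L   // checkpointingInterval = Seconds(2)
    val firstNumBatches = (checkpointMs / batchMs) * 2
    println("Batches to run before asserting a checkpoint exists: " + firstNumBatches)  // 8
  }
}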
diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
index e441feea19..b8c7f99603 100644
--- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala
@@ -57,21 +57,21 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu
*/
trait TestSuiteBase extends FunSuite with Logging {
- System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+ def framework = "TestSuiteBase"
- def framework() = "TestSuiteBase"
+ def master = "local[2]"
- def master() = "local[2]"
+ def batchDuration = Seconds(1)
- def batchDuration() = Seconds(1)
+ def checkpointDir = null.asInstanceOf[String]
- def checkpointDir() = null.asInstanceOf[String]
+ def checkpointInterval = batchDuration
- def checkpointInterval() = batchDuration
+ def numInputPartitions = 2
- def numInputPartitions() = 2
+ def maxWaitTimeMillis = 10000
- def maxWaitTimeMillis() = 10000
+ def actuallyWait = false
def setupStreams[U: ClassManifest, V: ClassManifest](
input: Seq[Seq[U]],
@@ -82,7 +82,7 @@ trait TestSuiteBase extends FunSuite with Logging {
val ssc = new StreamingContext(master, framework)
ssc.setBatchDuration(batchDuration)
if (checkpointDir != null) {
- ssc.checkpoint(checkpointDir, checkpointInterval())
+ ssc.checkpoint(checkpointDir, checkpointInterval)
}
// Setup the stream computation
@@ -104,7 +104,7 @@ trait TestSuiteBase extends FunSuite with Logging {
val ssc = new StreamingContext(master, framework)
ssc.setBatchDuration(batchDuration)
if (checkpointDir != null) {
- ssc.checkpoint(checkpointDir, checkpointInterval())
+ ssc.checkpoint(checkpointDir, checkpointInterval)
}
// Setup the stream computation
@@ -118,12 +118,19 @@ trait TestSuiteBase extends FunSuite with Logging {
ssc
}
+ /**
+ * Runs the streams set up in `ssc` on the manual clock for `numBatches` batches and
+ * returns the collected output. It waits until `numExpectedOutput` output items have
+ * been collected or the timeout (set by `maxWaitTimeMillis`) is reached.
+ */
def runStreams[V: ClassManifest](
ssc: StreamingContext,
numBatches: Int,
numExpectedOutput: Int
): Seq[Seq[V]] = {
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
assert(numBatches > 0, "Number of batches to run stream computation is zero")
assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
@@ -139,7 +146,15 @@ trait TestSuiteBase extends FunSuite with Logging {
// Advance manual clock
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
logInfo("Manual clock before advancing = " + clock.time)
- clock.addToTime(numBatches * batchDuration.milliseconds)
+ if (actuallyWait) {
+ for (i <- 1 to numBatches) {
+ logInfo("Actually waiting for " + batchDuration)
+ clock.addToTime(batchDuration.milliseconds)
+ Thread.sleep(batchDuration.milliseconds)
+ }
+ } else {
+ clock.addToTime(numBatches * batchDuration.milliseconds)
+ }
logInfo("Manual clock after advancing = " + clock.time)
// Wait until expected number of output items have been generated
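Editor's note: the actuallyWait flag introduces two clock-advancement modes — jump the manual clock over all batches at once, or advance one batch at a time with a real sleep so background work (such as writing checkpoint files) can keep up. The sketch below isolates that choice; ManualClockSketch is a stand-in for spark.streaming.util.ManualClock.

class ManualClockSketch {
  var time = 0L
  def addToTime(ms: Long) { time += ms }
}

object ClockAdvanceSketch {
  def advance(clock: ManualClockSketch, numBatches: Int, batchMs: Long, actuallyWait: Boolean) {
    if (actuallyWait) {
      for (i <- 1 to numBatches) {
        clock.addToTime(batchMs)
        Thread.sleep(batchMs)    // give the scheduler wall-clock time for each batch
      }
    } else {
      clock.addToTime(numBatches * batchMs)
    }
  }
}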
diff --git a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
index e282f0fdd5..3e20e16708 100644
--- a/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/WindowOperationsSuite.scala
@@ -5,11 +5,11 @@ import collection.mutable.ArrayBuffer
class WindowOperationsSuite extends TestSuiteBase {
- override def framework() = "WindowOperationsSuite"
+ override def framework = "WindowOperationsSuite"
- override def maxWaitTimeMillis() = 20000
+ override def maxWaitTimeMillis = 20000
- override def batchDuration() = Seconds(1)
+ override def batchDuration = Seconds(1)
val largerSlideInput = Seq(
Seq(("a", 1)),