author    Tathagata Das <tathagata.das1565@gmail.com>  2012-10-14 21:39:30 -0700
committer Tathagata Das <tathagata.das1565@gmail.com>  2012-10-14 21:39:30 -0700
commit    3f1aae5c71a220564adc9039dbc0e4b22aea315d (patch)
tree      062abae114d095bf2fa1644289c7f774e6bb5b42 /streaming
parent    b08708e6fcb59a09b36c5b8e3e7a4aa98f7ad050 (diff)
Refactored DStreamSuiteBase to create CheckpointSuite - a test suite for testing checkpointing under different operations.
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/spark/streaming/Checkpoint.scala | 88
-rw-r--r--  streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala | 4
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStream.scala | 4
-rw-r--r--  streaming/src/main/scala/spark/streaming/DStreamGraph.scala | 13
-rw-r--r--  streaming/src/main/scala/spark/streaming/Scheduler.scala | 20
-rw-r--r--  streaming/src/main/scala/spark/streaming/StreamingContext.scala | 18
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/Clock.scala | 11
-rw-r--r--  streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala | 4
-rw-r--r--  streaming/src/test/scala/spark/streaming/CheckpointSuite.scala | 48
-rw-r--r--  streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala | 5
-rw-r--r--  streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala | 171
-rw-r--r--  streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala | 39
12 files changed, 290 insertions(+), 135 deletions(-)
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index 3bd8fd5a27..b38911b646 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -5,11 +5,12 @@ import spark.Utils
import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
-import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream}
-class Checkpoint(@transient ssc: StreamingContext) extends Serializable {
+
+class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable {
val master = ssc.sc.master
- val frameworkName = ssc.sc.frameworkName
+ val framework = ssc.sc.frameworkName
val sparkHome = ssc.sc.sparkHome
val jars = ssc.sc.jars
val graph = ssc.graph
@@ -17,7 +18,16 @@ class Checkpoint(@transient ssc: StreamingContext) extends Serializable {
val checkpointFile = ssc.checkpointFile
val checkpointInterval = ssc.checkpointInterval
- def saveToFile(file: String) {
+ validate()
+
+ def validate() {
+ assert(master != null, "Checkpoint.master is null")
+ assert(framework != null, "Checkpoint.framework is null")
+ assert(graph != null, "Checkpoint.graph is null")
+ assert(batchDuration != null, "Checkpoint.batchDuration is null")
+ }
+
+ def saveToFile(file: String = checkpointFile) {
val path = new Path(file)
val conf = new Configuration()
val fs = path.getFileSystem(conf)
@@ -34,8 +44,7 @@ class Checkpoint(@transient ssc: StreamingContext) extends Serializable {
}
def toBytes(): Array[Byte] = {
- val cp = new Checkpoint(ssc)
- val bytes = Utils.serialize(cp)
+ val bytes = Utils.serialize(this)
bytes
}
}
@@ -43,50 +52,41 @@ class Checkpoint(@transient ssc: StreamingContext) extends Serializable {
object Checkpoint {
def loadFromFile(file: String): Checkpoint = {
- val path = new Path(file)
- val conf = new Configuration()
- val fs = path.getFileSystem(conf)
- if (!fs.exists(path)) {
- throw new Exception("Could not read checkpoint file " + path)
+ try {
+ val path = new Path(file)
+ val conf = new Configuration()
+ val fs = path.getFileSystem(conf)
+ if (!fs.exists(path)) {
+ throw new Exception("Checkpoint file '" + file + "' does not exist")
+ }
+ val fis = fs.open(path)
+ val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
+ val cp = ois.readObject.asInstanceOf[Checkpoint]
+ ois.close()
+ fs.close()
+ cp.validate()
+ cp
+ } catch {
+ case e: Exception =>
+ e.printStackTrace()
+ throw new Exception("Could not load checkpoint file '" + file + "'", e)
}
- val fis = fs.open(path)
- val ois = new ObjectInputStream(fis)
- val cp = ois.readObject.asInstanceOf[Checkpoint]
- ois.close()
- fs.close()
- cp
}
def fromBytes(bytes: Array[Byte]): Checkpoint = {
- Utils.deserialize[Checkpoint](bytes)
- }
-
- /*def toBytes(ssc: StreamingContext): Array[Byte] = {
- val cp = new Checkpoint(ssc)
- val bytes = Utils.serialize(cp)
- bytes
+ val cp = Utils.deserialize[Checkpoint](bytes)
+ cp.validate()
+ cp
}
+}
-
- def saveContext(ssc: StreamingContext, file: String) {
- val cp = new Checkpoint(ssc)
- val path = new Path(file)
- val conf = new Configuration()
- val fs = path.getFileSystem(conf)
- if (fs.exists(path)) {
- val bkPath = new Path(path.getParent, path.getName + ".bk")
- FileUtil.copy(fs, path, fs, bkPath, true, true, conf)
- println("Moved existing checkpoint file to " + bkPath)
+class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) extends ObjectInputStream(inputStream_) {
+ override def resolveClass(desc: ObjectStreamClass): Class[_] = {
+ try {
+ return loader.loadClass(desc.getName())
+ } catch {
+ case e: Exception =>
}
- val fos = fs.create(path)
- val oos = new ObjectOutputStream(fos)
- oos.writeObject(cp)
- oos.close()
- fs.close()
- }
-
- def loadContext(file: String): StreamingContext = {
- loadCheckpoint(file).createNewContext()
+ return super.resolveClass(desc)
}
- */
}
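A note on the Checkpoint changes above: a Checkpoint now records the batch time it was taken at (checkpointTime), validates its required fields both on construction and after loading, saveToFile() defaults to the configured checkpointFile, and deserialization goes through ObjectInputStreamWithLoader so classes resolve against the current thread's context classloader. The following is a minimal sketch (not part of the patch) of how the reworked save/load round trip might be exercised; the context setup and file name are illustrative, and it is assumed to run inside the spark.streaming package, like the test suites, so package-visible members are accessible.

    // Illustrative sketch only; names and values are made up for the example.
    val ssc = new StreamingContext("local", "CheckpointExample")
    ssc.setBatchDuration(Seconds(1))
    ssc.setCheckpointDetails("checkpoint", Seconds(1))

    // Write a checkpoint for a given batch time; saveToFile() now defaults
    // to ssc.checkpointFile.
    new Checkpoint(ssc, Time(1000)).saveToFile()

    // Read it back; loadFromFile deserializes with ObjectInputStreamWithLoader
    // (context classloader) and re-runs validate() on the result.
    val cp = Checkpoint.loadFromFile("checkpoint")
    assert(cp.checkpointTime.milliseconds == 1000)
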
diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala
index 9bc204dd09..80150708fd 100644
--- a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala
@@ -5,8 +5,8 @@ import spark.RDD
/**
* An input stream that always returns the same RDD on each timestep. Useful for testing.
*/
-class ConstantInputDStream[T: ClassManifest](ssc: StreamingContext, rdd: RDD[T])
- extends InputDStream[T](ssc) {
+class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T])
+ extends InputDStream[T](ssc_) {
override def start() {}
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 78e4c57647..0a43a042d0 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -180,7 +180,7 @@ extends Serializable with Logging {
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
- println(this.getClass().getSimpleName + ".writeObject used")
+ logDebug(this.getClass().getSimpleName + ".writeObject used")
if (graph != null) {
graph.synchronized {
if (graph.checkpointInProgress) {
@@ -202,7 +202,7 @@ extends Serializable with Logging {
@throws(classOf[IOException])
private def readObject(ois: ObjectInputStream) {
- println(this.getClass().getSimpleName + ".readObject used")
+ logDebug(this.getClass().getSimpleName + ".readObject used")
ois.defaultReadObject()
}
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index 67859e0131..bcd365e932 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -2,8 +2,10 @@ package spark.streaming
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import collection.mutable.ArrayBuffer
+import spark.Logging
-final class DStreamGraph extends Serializable {
+final class DStreamGraph extends Serializable with Logging {
+ initLogging()
private val inputStreams = new ArrayBuffer[InputDStream[_]]()
private val outputStreams = new ArrayBuffer[DStream[_]]()
@@ -11,18 +13,15 @@ final class DStreamGraph extends Serializable {
private[streaming] var zeroTime: Time = null
private[streaming] var checkpointInProgress = false;
- def started() = (zeroTime != null)
-
def start(time: Time) {
this.synchronized {
- if (started) {
+ if (zeroTime != null) {
throw new Exception("DStream graph computation already started")
}
zeroTime = time
outputStreams.foreach(_.initialize(zeroTime))
inputStreams.par.foreach(_.start())
}
-
}
def stop() {
@@ -60,21 +59,21 @@ final class DStreamGraph extends Serializable {
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream) {
this.synchronized {
+ logDebug("DStreamGraph.writeObject used")
checkpointInProgress = true
oos.defaultWriteObject()
checkpointInProgress = false
}
- println("DStreamGraph.writeObject used")
}
@throws(classOf[IOException])
private def readObject(ois: ObjectInputStream) {
this.synchronized {
+ logDebug("DStreamGraph.readObject used")
checkpointInProgress = true
ois.defaultReadObject()
checkpointInProgress = false
}
- println("DStreamGraph.readObject used")
}
}
diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala
index d62b7e7140..1e1425a88a 100644
--- a/streaming/src/main/scala/spark/streaming/Scheduler.scala
+++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala
@@ -1,7 +1,6 @@
package spark.streaming
-import spark.streaming.util.RecurringTimer
-import spark.streaming.util.Clock
+import util.{ManualClock, RecurringTimer, Clock}
import spark.SparkEnv
import spark.Logging
@@ -23,13 +22,24 @@ extends Logging {
val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock]
val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_))
-
def start() {
- if (graph.started) {
+ // If context was started from checkpoint, then restart timer such that
+ // this timer's triggers occur at the same time as the original timer.
+ // Otherwise just start the timer from scratch, and initialize graph based
+ // on this first trigger time of the timer.
+ if (ssc.isCheckpointPresent) {
+ // If manual clock is being used for testing, then
+ // set manual clock to the last checkpointed time
+ if (clock.isInstanceOf[ManualClock]) {
+ val lastTime = ssc.getInitialCheckpoint.checkpointTime.milliseconds
+ clock.asInstanceOf[ManualClock].setTime(lastTime)
+ }
timer.restart(graph.zeroTime.milliseconds)
+ logInfo("Scheduler's timer restarted")
} else {
val zeroTime = Time(timer.start())
graph.start(zeroTime)
+ logInfo("Scheduler's timer started")
}
logInfo("Scheduler started")
}
@@ -47,7 +57,7 @@ extends Logging {
graph.generateRDDs(time).foreach(submitJob)
logInfo("Generated RDDs for time " + time)
if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) {
- ssc.checkpoint()
+ ssc.doCheckpoint(time)
}
}
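On the Scheduler change above: when a context is recovered from a checkpoint, the scheduler first pins the ManualClock (if one is configured via spark.streaming.clock) to the last checkpointed batch time, then calls timer.restart(graph.zeroTime.milliseconds) so that subsequent triggers stay aligned with the original batch schedule instead of starting a fresh one. A small worked example, with made-up times, of the arithmetic this drives in RecurringTimer.restart:

    // Illustrative numbers only.
    val period            = 1000L  // batch duration in ms
    val originalStartTime = 0L     // graph.zeroTime from the checkpoint
    val checkpointTime    = 3000L  // last checkpointed batch time

    // With the ManualClock pinned to checkpointTime, restart() computes:
    val gap          = checkpointTime - originalStartTime              // 3000
    val newStartTime = (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
    assert(newStartTime == 4000L)  // first trigger after recovery is the next batch boundary
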
diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
index 1499ef4ea2..e072f15c93 100644
--- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala
@@ -46,7 +46,7 @@ class StreamingContext (
val sc: SparkContext = {
if (isCheckpointPresent) {
- new SparkContext(cp_.master, cp_.frameworkName, cp_.sparkHome, cp_.jars)
+ new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars)
} else {
sc_
}
@@ -85,9 +85,13 @@ class StreamingContext (
checkpointFile = file
checkpointInterval = interval
}
-
+
+ private[streaming] def getInitialCheckpoint(): Checkpoint = {
+ if (isCheckpointPresent) cp_ else null
+ }
+
private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement()
-
+
def createNetworkTextStream(hostname: String, port: Int): DStream[String] = {
createNetworkObjectStream[String](hostname, port, ObjectInputReceiver.bytesToLines)
}
@@ -156,10 +160,10 @@ class StreamingContext (
inputStream
}
- def createQueueStream[T: ClassManifest](iterator: Array[RDD[T]]): DStream[T] = {
+ def createQueueStream[T: ClassManifest](array: Array[RDD[T]]): DStream[T] = {
val queue = new Queue[RDD[T]]
val inputStream = createQueueStream(queue, true, null)
- queue ++= iterator
+ queue ++= array
inputStream
}
@@ -233,8 +237,8 @@ class StreamingContext (
logInfo("StreamingContext stopped")
}
- def checkpoint() {
- new Checkpoint(this).saveToFile(checkpointFile)
+ def doCheckpoint(currentTime: Time) {
+ new Checkpoint(this, currentTime).saveToFile(checkpointFile)
}
}
diff --git a/streaming/src/main/scala/spark/streaming/util/Clock.scala b/streaming/src/main/scala/spark/streaming/util/Clock.scala
index 72e786e0c3..ed087e4ea8 100644
--- a/streaming/src/main/scala/spark/streaming/util/Clock.scala
+++ b/streaming/src/main/scala/spark/streaming/util/Clock.scala
@@ -56,10 +56,17 @@ class SystemClock() extends Clock {
class ManualClock() extends Clock {
- var time = 0L
-
+ var time = 0L
+
def currentTime() = time
+ def setTime(timeToSet: Long) = {
+ this.synchronized {
+ time = timeToSet
+ this.notifyAll()
+ }
+ }
+
def addToTime(timeToAdd: Long) = {
this.synchronized {
time += timeToAdd
diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
index 7f19b26a79..dc55fd902b 100644
--- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
@@ -25,13 +25,13 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) =>
}
def start(): Long = {
- val startTime = math.ceil(clock.currentTime / period).toLong * period
+ val startTime = (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period
start(startTime)
}
def restart(originalStartTime: Long): Long = {
val gap = clock.currentTime - originalStartTime
- val newStartTime = math.ceil(gap / period).toLong * period + originalStartTime
+ val newStartTime = (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
start(newStartTime)
}
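Why the RecurringTimer formulas changed: in the old code, clock.currentTime / period and gap / period are Long divisions, so the quotient is truncated before math.ceil ever sees it, and the computed "next" start time could land at or before the current time. The floor-plus-one form always picks the next strict multiple of the period. A worked example with illustrative numbers:

    val currentTime = 2500L
    val period      = 1000L

    // Old: Long division truncates 2500/1000 to 2, ceil is then a no-op,
    // so the timer would "start" at 2000 ms, i.e. in the past.
    val oldStart = math.ceil(currentTime / period).toLong * period                   // 2000
    // New: always the next strict multiple of the period after currentTime.
    val newStart = (math.floor(currentTime.toDouble / period) + 1).toLong * period   // 3000
    assert(oldStart == 2000L && newStart == 3000L)
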
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
new file mode 100644
index 0000000000..11cecf9822
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -0,0 +1,48 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+
+class CheckpointSuite extends DStreamSuiteBase {
+
+ override def framework() = "CheckpointSuite"
+
+ override def checkpointFile() = "checkpoint"
+
+ def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ useSet: Boolean = false
+ ) {
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
+ // Current code assumes that:
+ // number of inputs = number of outputs = number of batches to be run
+
+ // Do half the computation (half the number of batches), create checkpoint file and quit
+ val totalNumBatches = input.size
+ val initialNumBatches = input.size / 2
+ val nextNumBatches = totalNumBatches - initialNumBatches
+ val initialNumExpectedOutputs = initialNumBatches
+
+ val ssc = setupStreams[U, V](input, operation)
+ val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
+ verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet)
+ Thread.sleep(1000)
+
+ // Restart and complete the computation from checkpoint file
+ val sscNew = new StreamingContext(checkpointFile)
+ sscNew.setCheckpointDetails(null, null)
+ val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size)
+ verifyOutput[V](outputNew, expectedOutput, useSet)
+ }
+
+ test("simple per-batch operation") {
+ testCheckpointedOperation(
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
+ (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
+ Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
+ true
+ )
+ }
+}
\ No newline at end of file
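What the new CheckpointSuite does: testCheckpointedOperation runs roughly the first half of the batches (6 inputs gives initialNumBatches = 3), checkpointing along the way, checks the partial output, then rebuilds a StreamingContext from the checkpoint file and runs the remaining 3 batches before verifying the complete expected output. Below is a sketch (not part of the patch) of another per-batch operation run through the same cycle; the input and expected values are made up for illustration.

    // Hypothetical additional test inside CheckpointSuite.
    test("map per-batch operation") {
      testCheckpointedOperation(
        Seq( Seq(1, 2), Seq(3), Seq(), Seq(4, 5), Seq(6), Seq() ),
        (s: DStream[Int]) => s.map(_ * 10),
        Seq( Seq(10, 20), Seq(30), Seq(), Seq(40, 50), Seq(60), Seq() ),
        true   // compare each batch as a set
      )
    }
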
diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala
index 965b58c03f..f8ca7febe7 100644
--- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala
@@ -22,9 +22,9 @@ class DStreamBasicSuite extends DStreamSuiteBase {
test("shuffle-based operations") {
// reduceByKey
testOperation(
- Seq(Seq("a", "a", "b"), Seq("", ""), Seq()),
+ Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
(s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
- Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()),
+ Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
true
)
@@ -62,7 +62,6 @@ class DStreamBasicSuite extends DStreamSuiteBase {
var newState = 0
if (values != null && values.size > 0) newState += values.reduce(_ + _)
if (state != null) newState += state.self
- //println("values = " + values + ", state = " + state + ", " + " new state = " + newState)
new RichInt(newState)
}
s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self))
diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala
index 59fe36baf0..cb95c36782 100644
--- a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala
+++ b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala
@@ -4,70 +4,157 @@ import spark.{RDD, Logging}
import util.ManualClock
import collection.mutable.ArrayBuffer
import org.scalatest.FunSuite
-import scala.collection.mutable.Queue
+import collection.mutable.SynchronizedBuffer
+class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, val input: Seq[Seq[T]])
+ extends InputDStream[T](ssc_) {
+ var currentIndex = 0
+
+ def start() {}
+
+ def stop() {}
+
+ def compute(validTime: Time): Option[RDD[T]] = {
+ logInfo("Computing RDD for time " + validTime)
+ val rdd = if (currentIndex < input.size) {
+ ssc.sc.makeRDD(input(currentIndex), 2)
+ } else {
+ ssc.sc.makeRDD(Seq[T](), 2)
+ }
+ logInfo("Created RDD " + rdd.id)
+ currentIndex += 1
+ Some(rdd)
+ }
+}
+
+class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
+ extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+ val collected = rdd.collect()
+ output += collected
+ })
trait DStreamSuiteBase extends FunSuite with Logging {
- def batchDuration() = Seconds(1)
+ def framework() = "DStreamSuiteBase"
- def maxWaitTimeMillis() = 10000
+ def master() = "local[2]"
- def testOperation[U: ClassManifest, V: ClassManifest](
- input: Seq[Seq[U]],
- operation: DStream[U] => DStream[V],
- expectedOutput: Seq[Seq[V]],
- useSet: Boolean = false
- ) {
+ def batchDuration() = Seconds(1)
- val manualClock = true
+ def checkpointFile() = null.asInstanceOf[String]
- if (manualClock) {
- System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
- }
+ def checkpointInterval() = batchDuration
- val ssc = new StreamingContext("local", "test")
+ def maxWaitTimeMillis() = 10000
- try {
- ssc.setBatchDuration(Milliseconds(batchDuration))
+ def setupStreams[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V]
+ ): StreamingContext = {
+
+ // Create StreamingContext
+ val ssc = new StreamingContext(master, framework)
+ ssc.setBatchDuration(batchDuration)
+ if (checkpointFile != null) {
+ ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+ }
- val inputQueue = new Queue[RDD[U]]()
- inputQueue ++= input.map(ssc.sc.makeRDD(_, 2))
- val emptyRDD = ssc.sc.makeRDD(Seq[U](), 2)
+ // Setup the stream computation
+ val inputStream = new TestInputStream(ssc, input)
+ ssc.registerInputStream(inputStream)
+ val operatedStream = operation(inputStream)
+ val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
+ ssc.registerOutputStream(outputStream)
+ ssc
+ }
- val inputStream = ssc.createQueueStream(inputQueue, true, emptyRDD)
- val outputStream = operation(inputStream)
+ def runStreams[V: ClassManifest](
+ ssc: StreamingContext,
+ numBatches: Int,
+ numExpectedOutput: Int
+ ): Seq[Seq[V]] = {
+ logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
- val output = new ArrayBuffer[Seq[V]]()
- outputStream.foreachRDD(rdd => output += rdd.collect())
+ // Get the output buffer
+ val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+ val output = outputStream.output
+ try {
+ // Start computation
ssc.start()
- val clock = ssc.scheduler.clock
- if (clock.isInstanceOf[ManualClock]) {
- clock.asInstanceOf[ManualClock].addToTime((input.size - 1) * batchDuration.milliseconds)
- }
+ // Advance manual clock
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ logInfo("Manual clock before advancing = " + clock.time)
+ clock.addToTime(numBatches * batchDuration.milliseconds)
+ logInfo("Manual clock after advancing = " + clock.time)
+ // Wait until expected number of output items have been generated
val startTime = System.currentTimeMillis()
- while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
- println("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size)
- Thread.sleep(500)
+ while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
+ logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput)
+ Thread.sleep(100)
}
+ val timeTaken = System.currentTimeMillis() - startTime
- println("output.size = " + output.size)
- println("output")
- output.foreach(x => println("[" + x.mkString(",") + "]"))
-
- assert(output.size === expectedOutput.size)
- for (i <- 0 until output.size) {
- if (useSet) {
- assert(output(i).toSet === expectedOutput(i).toSet)
- } else {
- assert(output(i).toList === expectedOutput(i).toList)
- }
- }
+ assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
+ assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
+ } catch {
+ case e: Exception => e.printStackTrace(); throw e;
} finally {
ssc.stop()
}
+
+ output
+ }
+
+ def verifyOutput[V: ClassManifest](
+ output: Seq[Seq[V]],
+ expectedOutput: Seq[Seq[V]],
+ useSet: Boolean
+ ) {
+ logInfo("--------------------------------")
+ logInfo("output.size = " + output.size)
+ logInfo("output")
+ output.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("expected output.size = " + expectedOutput.size)
+ logInfo("expected output")
+ expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+ logInfo("--------------------------------")
+
+ // Match the output with the expected output
+ assert(output.size === expectedOutput.size, "Number of outputs do not match")
+ for (i <- 0 until output.size) {
+ if (useSet) {
+ assert(output(i).toSet === expectedOutput(i).toSet)
+ } else {
+ assert(output(i).toList === expectedOutput(i).toList)
+ }
+ }
+ logInfo("Output verified successfully")
+ }
+
+ def testOperation[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ useSet: Boolean = false
+ ) {
+ testOperation[U, V](input, operation, expectedOutput, -1, useSet)
+ }
+
+ def testOperation[U: ClassManifest, V: ClassManifest](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ numBatches: Int,
+ useSet: Boolean
+ ) {
+ System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
+ val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
+ val ssc = setupStreams[U, V](input, operation)
+ val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
+ verifyOutput[V](output, expectedOutput, useSet)
}
}
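How the refactored DStreamSuiteBase fits together: the harness is now split into setupStreams (build a StreamingContext wired to a TestInputStream and TestOutputStream), runStreams (advance the ManualClock by numBatches batch intervals and wait for numExpectedOutput results), and verifyOutput (compare against the expected output as lists or sets); testOperation simply composes the three, and CheckpointSuite reuses the pieces to stop and resume mid-stream. A minimal sketch (not part of the patch) of a suite built on the base class; the suite name, operation, and data are made up.

    package spark.streaming

    class MapSuite extends DStreamSuiteBase {

      override def framework() = "MapSuite"

      test("map") {
        testOperation(
          Seq( Seq(1, 2, 3), Seq(4, 5) ),
          (s: DStream[Int]) => s.map(_ + 1),
          Seq( Seq(2, 3, 4), Seq(5, 6) )
        )
      }
    }
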
diff --git a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala
index 061cab2cbb..8dd18f491a 100644
--- a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala
@@ -4,6 +4,10 @@ import spark.streaming.StreamingContext._
class DStreamWindowSuite extends DStreamSuiteBase {
+ override def framework() = "DStreamWindowSuite"
+
+ override def maxWaitTimeMillis() = 20000
+
val largerSlideInput = Seq(
Seq(("a", 1)), // 1st window from here
Seq(("a", 2)),
@@ -81,16 +85,15 @@ class DStreamWindowSuite extends DStreamSuiteBase {
name: String,
input: Seq[Seq[(String, Int)]],
expectedOutput: Seq[Seq[(String, Int)]],
- windowTime: Time = Seconds(2),
- slideTime: Time = Seconds(1)
+ windowTime: Time = batchDuration * 2,
+ slideTime: Time = batchDuration
) {
test("reduceByKeyAndWindow - " + name) {
- testOperation(
- input,
- (s: DStream[(String, Int)]) => s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist(),
- expectedOutput,
- true
- )
+ val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+ val operation = (s: DStream[(String, Int)]) => {
+ s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist()
+ }
+ testOperation(input, operation, expectedOutput, numBatches, true)
}
}
@@ -98,16 +101,15 @@ class DStreamWindowSuite extends DStreamSuiteBase {
name: String,
input: Seq[Seq[(String, Int)]],
expectedOutput: Seq[Seq[(String, Int)]],
- windowTime: Time = Seconds(2),
- slideTime: Time = Seconds(1)
+ windowTime: Time = batchDuration * 2,
+ slideTime: Time = batchDuration
) {
test("reduceByKeyAndWindowInv - " + name) {
- testOperation(
- input,
- (s: DStream[(String, Int)]) => s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist(),
- expectedOutput,
- true
- )
+ val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+ val operation = (s: DStream[(String, Int)]) => {
+ s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist()
+ }
+ testOperation(input, operation, expectedOutput, numBatches, true)
}
}
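On the windowed-test batch count above: the window tests no longer assume one batch per expected output; the number of batches to drive is derived from how many batches fit in one slide interval. A worked example with illustrative numbers:

    // batchDuration = 1 s, slideTime = batchDuration * 2, 3 expected outputs.
    val batchMs           = 1000L
    val slideMs           = 2000L
    val numExpectedOutput = 3
    // numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
    val numBatches = numExpectedOutput * (slideMs / batchMs).toInt   // 6 batches to run
    assert(numBatches == 6)
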
@@ -116,8 +118,8 @@ class DStreamWindowSuite extends DStreamSuiteBase {
testReduceByKeyAndWindow(
"basic reduction",
- Seq(Seq(("a", 1), ("a", 3)) ),
- Seq(Seq(("a", 4)) )
+ Seq( Seq(("a", 1), ("a", 3)) ),
+ Seq( Seq(("a", 4)) )
)
testReduceByKeyAndWindow(
@@ -126,7 +128,6 @@ class DStreamWindowSuite extends DStreamSuiteBase {
Seq( Seq(("a", 1)), Seq(("a", 2)) )
)
-
testReduceByKeyAndWindow(
"new key added into window",
Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ),