Diffstat (limited to 'streaming')
12 files changed, 290 insertions, 135 deletions
diff --git a/streaming/src/main/scala/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
index 3bd8fd5a27..b38911b646 100644
--- a/streaming/src/main/scala/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/spark/streaming/Checkpoint.scala
@@ -5,11 +5,12 @@ import spark.Utils
 import org.apache.hadoop.fs.{FileUtil, Path}
 import org.apache.hadoop.conf.Configuration
 
-import java.io.{ObjectInputStream, ObjectOutputStream}
+import java.io.{InputStream, ObjectStreamClass, ObjectInputStream, ObjectOutputStream}
 
-class Checkpoint(@transient ssc: StreamingContext) extends Serializable {
+
+class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) extends Serializable {
   val master = ssc.sc.master
-  val frameworkName = ssc.sc.frameworkName
+  val framework = ssc.sc.frameworkName
   val sparkHome = ssc.sc.sparkHome
   val jars = ssc.sc.jars
   val graph = ssc.graph
@@ -17,7 +18,16 @@ class Checkpoint(@transient ssc: StreamingContext) extends Serializable {
   val checkpointFile = ssc.checkpointFile
   val checkpointInterval = ssc.checkpointInterval
 
-  def saveToFile(file: String) {
+  validate()
+
+  def validate() {
+    assert(master != null, "Checkpoint.master is null")
+    assert(framework != null, "Checkpoint.framework is null")
+    assert(graph != null, "Checkpoint.graph is null")
+    assert(batchDuration != null, "Checkpoint.batchDuration is null")
+  }
+
+  def saveToFile(file: String = checkpointFile) {
     val path = new Path(file)
     val conf = new Configuration()
     val fs = path.getFileSystem(conf)
@@ -34,8 +44,7 @@ class Checkpoint(@transient ssc: StreamingContext) extends Serializable {
   }
 
   def toBytes(): Array[Byte] = {
-    val cp = new Checkpoint(ssc)
-    val bytes = Utils.serialize(cp)
+    val bytes = Utils.serialize(this)
     bytes
   }
 }
@@ -43,50 +52,41 @@ class Checkpoint(@transient ssc: StreamingContext) extends Serializable {
 object Checkpoint {
 
   def loadFromFile(file: String): Checkpoint = {
-    val path = new Path(file)
-    val conf = new Configuration()
-    val fs = path.getFileSystem(conf)
-    if (!fs.exists(path)) {
-      throw new Exception("Could not read checkpoint file " + path)
+    try {
+      val path = new Path(file)
+      val conf = new Configuration()
+      val fs = path.getFileSystem(conf)
+      if (!fs.exists(path)) {
+        throw new Exception("Checkpoint file '" + file + "' does not exist")
+      }
+      val fis = fs.open(path)
+      val ois = new ObjectInputStreamWithLoader(fis, Thread.currentThread().getContextClassLoader)
+      val cp = ois.readObject.asInstanceOf[Checkpoint]
+      ois.close()
+      fs.close()
+      cp.validate()
+      cp
+    } catch {
+      case e: Exception =>
+        e.printStackTrace()
+        throw new Exception("Could not load checkpoint file '" + file + "'", e)
     }
-    val fis = fs.open(path)
-    val ois = new ObjectInputStream(fis)
-    val cp = ois.readObject.asInstanceOf[Checkpoint]
-    ois.close()
-    fs.close()
-    cp
   }
 
   def fromBytes(bytes: Array[Byte]): Checkpoint = {
-    Utils.deserialize[Checkpoint](bytes)
-  }
-
-  /*def toBytes(ssc: StreamingContext): Array[Byte] = {
-    val cp = new Checkpoint(ssc)
-    val bytes = Utils.serialize(cp)
-    bytes
+    val cp = Utils.deserialize[Checkpoint](bytes)
+    cp.validate()
+    cp
   }
+}
 
-
-  def saveContext(ssc: StreamingContext, file: String) {
-    val cp = new Checkpoint(ssc)
-    val path = new Path(file)
-    val conf = new Configuration()
-    val fs = path.getFileSystem(conf)
-    if (fs.exists(path)) {
-      val bkPath = new Path(path.getParent, path.getName + ".bk")
-      FileUtil.copy(fs, path, fs, bkPath, true, true, conf)
-      println("Moved existing checkpoint file to " + bkPath)
+class ObjectInputStreamWithLoader(inputStream_ : InputStream, loader: ClassLoader) extends ObjectInputStream(inputStream_) {
+  override def resolveClass(desc: ObjectStreamClass): Class[_] = {
+    try {
+      return loader.loadClass(desc.getName())
+    } catch {
+      case e: Exception =>
     }
-    val fos = fs.create(path)
-    val oos = new ObjectOutputStream(fos)
-    oos.writeObject(cp)
-    oos.close()
-    fs.close()
-  }
-
-  def loadContext(file: String): StreamingContext = {
-    loadCheckpoint(file).createNewContext()
+    return super.resolveClass(desc)
   }
-  */
 }
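Why resolveClass is overridden: a plain ObjectInputStream resolves classes through the classloader that defined the stream class itself, which can miss classes visible only to the application's context classloader (common when the driver runs embedded or under a test runner). A minimal round-trip sketch, assuming the ObjectInputStreamWithLoader added above is on the classpath; Record and LoaderRoundTrip are illustrative names, not part of the patch:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectOutputStream}
import spark.streaming.ObjectInputStreamWithLoader

case class Record(x: Int) extends Serializable

object LoaderRoundTrip {
  def main(args: Array[String]) {
    // Serialize a value to bytes.
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(Record(42))
    oos.close()

    // Deserialize through the context classloader, as loadFromFile now does,
    // so classes visible only to the application loader still resolve.
    val loader = Thread.currentThread().getContextClassLoader
    val ois = new ObjectInputStreamWithLoader(new ByteArrayInputStream(bos.toByteArray), loader)
    println(ois.readObject())   // prints Record(42)
    ois.close()
  }
}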
diff --git a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala
index 9bc204dd09..80150708fd 100644
--- a/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala
+++ b/streaming/src/main/scala/spark/streaming/ConstantInputDStream.scala
@@ -5,8 +5,8 @@ import spark.RDD
 /**
  * An input stream that always returns the same RDD on each timestep. Useful for testing.
  */
-class ConstantInputDStream[T: ClassManifest](ssc: StreamingContext, rdd: RDD[T])
-  extends InputDStream[T](ssc) {
+class ConstantInputDStream[T: ClassManifest](ssc_ : StreamingContext, rdd: RDD[T])
+  extends InputDStream[T](ssc_) {
 
   override def start() {}
 
diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index 78e4c57647..0a43a042d0 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -180,7 +180,7 @@ extends Serializable with Logging {
 
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream) {
-    println(this.getClass().getSimpleName + ".writeObject used")
+    logDebug(this.getClass().getSimpleName + ".writeObject used")
     if (graph != null) {
       graph.synchronized {
         if (graph.checkpointInProgress) {
@@ -202,7 +202,7 @@ extends Serializable with Logging {
 
   @throws(classOf[IOException])
   private def readObject(ois: ObjectInputStream) {
-    println(this.getClass().getSimpleName + ".readObject used")
+    logDebug(this.getClass().getSimpleName + ".readObject used")
     ois.defaultReadObject()
   }
 
diff --git a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
index 67859e0131..bcd365e932 100644
--- a/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/spark/streaming/DStreamGraph.scala
@@ -2,8 +2,10 @@ package spark.streaming
 
 import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
 import collection.mutable.ArrayBuffer
+import spark.Logging
 
-final class DStreamGraph extends Serializable {
+final class DStreamGraph extends Serializable with Logging {
+  initLogging()
 
   private val inputStreams = new ArrayBuffer[InputDStream[_]]()
   private val outputStreams = new ArrayBuffer[DStream[_]]()
@@ -11,18 +13,15 @@ final class DStreamGraph extends Serializable {
   private[streaming] var zeroTime: Time = null
   private[streaming] var checkpointInProgress = false;
 
-  def started() = (zeroTime != null)
-
   def start(time: Time) {
     this.synchronized {
-      if (started) {
+      if (zeroTime != null) {
        throw new Exception("DStream graph computation already started")
       }
       zeroTime = time
       outputStreams.foreach(_.initialize(zeroTime))
       inputStreams.par.foreach(_.start())
     }
-
   }
 
   def stop() {
@@ -60,21 +59,21 @@ final class DStreamGraph extends Serializable {
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream) {
     this.synchronized {
+      logDebug("DStreamGraph.writeObject used")
       checkpointInProgress = true
       oos.defaultWriteObject()
       checkpointInProgress = false
     }
-    println("DStreamGraph.writeObject used")
   }
 
   @throws(classOf[IOException])
   private def readObject(ois: ObjectInputStream) {
     this.synchronized {
+      logDebug("DStreamGraph.readObject used")
       checkpointInProgress = true
       ois.defaultReadObject()
       checkpointInProgress = false
     }
-    println("DStreamGraph.readObject used")
   }
 }
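The writeObject/readObject pair on DStreamGraph is the standard Java serialization hook; the flag is flipped inside the synchronized block so that any DStream serialized as part of the graph can tell a deliberate checkpoint apart from an accidental capture in a task closure (DStream.writeObject checks graph.checkpointInProgress for exactly this). A stripped-down sketch of the pattern, with Guarded as an illustrative stand-in:

import java.io.{IOException, ObjectOutputStream}

class Guarded extends Serializable {
  var checkpointInProgress = false

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream) {
    this.synchronized {
      checkpointInProgress = true   // visible to anything serialized below
      oos.defaultWriteObject()      // default field serialization
      checkpointInProgress = false
    }
  }
}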
logDebug("DStreamGraph.writeObject used") checkpointInProgress = true oos.defaultWriteObject() checkpointInProgress = false } - println("DStreamGraph.writeObject used") } @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream) { this.synchronized { + logDebug("DStreamGraph.readObject used") checkpointInProgress = true ois.defaultReadObject() checkpointInProgress = false } - println("DStreamGraph.readObject used") } } diff --git a/streaming/src/main/scala/spark/streaming/Scheduler.scala b/streaming/src/main/scala/spark/streaming/Scheduler.scala index d62b7e7140..1e1425a88a 100644 --- a/streaming/src/main/scala/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/spark/streaming/Scheduler.scala @@ -1,7 +1,6 @@ package spark.streaming -import spark.streaming.util.RecurringTimer -import spark.streaming.util.Clock +import util.{ManualClock, RecurringTimer, Clock} import spark.SparkEnv import spark.Logging @@ -23,13 +22,24 @@ extends Logging { val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock] val timer = new RecurringTimer(clock, ssc.batchDuration, generateRDDs(_)) - def start() { - if (graph.started) { + // If context was started from checkpoint, then restart timer such that + // this timer's triggers occur at the same time as the original timer. + // Otherwise just start the timer from scratch, and initialize graph based + // on this first trigger time of the timer. + if (ssc.isCheckpointPresent) { + // If manual clock is being used for testing, then + // set manual clock to the last checkpointed time + if (clock.isInstanceOf[ManualClock]) { + val lastTime = ssc.getInitialCheckpoint.checkpointTime.milliseconds + clock.asInstanceOf[ManualClock].setTime(lastTime) + } timer.restart(graph.zeroTime.milliseconds) + logInfo("Scheduler's timer restarted") } else { val zeroTime = Time(timer.start()) graph.start(zeroTime) + logInfo("Scheduler's timer started") } logInfo("Scheduler started") } @@ -47,7 +57,7 @@ extends Logging { graph.generateRDDs(time).foreach(submitJob) logInfo("Generated RDDs for time " + time) if (ssc.checkpointInterval != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointInterval)) { - ssc.checkpoint() + ssc.doCheckpoint(time) } } diff --git a/streaming/src/main/scala/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/spark/streaming/StreamingContext.scala index 1499ef4ea2..e072f15c93 100644 --- a/streaming/src/main/scala/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/spark/streaming/StreamingContext.scala @@ -46,7 +46,7 @@ class StreamingContext ( val sc: SparkContext = { if (isCheckpointPresent) { - new SparkContext(cp_.master, cp_.frameworkName, cp_.sparkHome, cp_.jars) + new SparkContext(cp_.master, cp_.framework, cp_.sparkHome, cp_.jars) } else { sc_ } @@ -85,9 +85,13 @@ class StreamingContext ( checkpointFile = file checkpointInterval = interval } - + + private[streaming] def getInitialCheckpoint(): Checkpoint = { + if (isCheckpointPresent) cp_ else null + } + private[streaming] def getNewNetworkStreamId() = nextNetworkInputStreamId.getAndIncrement() - + def createNetworkTextStream(hostname: String, port: Int): DStream[String] = { createNetworkObjectStream[String](hostname, port, ObjectInputReceiver.bytesToLines) } @@ -156,10 +160,10 @@ class StreamingContext ( inputStream } - def createQueueStream[T: ClassManifest](iterator: Array[RDD[T]]): DStream[T] = { + def createQueueStream[T: ClassManifest](array: Array[RDD[T]]): DStream[T] = { val queue = new Queue[RDD[T]] val 
diff --git a/streaming/src/main/scala/spark/streaming/util/Clock.scala b/streaming/src/main/scala/spark/streaming/util/Clock.scala
index 72e786e0c3..ed087e4ea8 100644
--- a/streaming/src/main/scala/spark/streaming/util/Clock.scala
+++ b/streaming/src/main/scala/spark/streaming/util/Clock.scala
@@ -56,10 +56,17 @@ class SystemClock() extends Clock {
 
 class ManualClock() extends Clock {
 
-  var time = 0L
-
+  var time = 0L
+
   def currentTime() = time
 
+  def setTime(timeToSet: Long) = {
+    this.synchronized {
+      time = timeToSet
+      this.notifyAll()
+    }
+  }
+
   def addToTime(timeToAdd: Long) = {
     this.synchronized {
       time += timeToAdd
diff --git a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
index 7f19b26a79..dc55fd902b 100644
--- a/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/spark/streaming/util/RecurringTimer.scala
@@ -25,13 +25,13 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) =>
   }
 
   def start(): Long = {
-    val startTime = math.ceil(clock.currentTime / period).toLong * period
+    val startTime = (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period
     start(startTime)
   }
 
   def restart(originalStartTime: Long): Long = {
     val gap = clock.currentTime - originalStartTime
-    val newStartTime = math.ceil(gap / period).toLong * period + originalStartTime
+    val newStartTime = (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
     start(newStartTime)
   }
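The RecurringTimer change fixes an integer-division bug: in math.ceil(clock.currentTime / period) both operands are Long, so the division truncates before math.ceil ever sees it, and the "next" trigger could land in the past (currentTime = 1500, period = 1000 gave ceil(1) * 1000 = 1000 instead of 2000). The new form divides as Double and adds one full period. A tiny standalone check; TimerMathCheck and its literals are illustrative only:

object TimerMathCheck {
  def main(args: Array[String]) {
    val period = 1000L
    val currentTime = 1500L

    // Old expression: 1500 / 1000 truncates to 1, ceil(1.0) = 1,
    // so the computed start time is 1000 -- already in the past.
    val oldStart = math.ceil(currentTime / period).toLong * period

    // New expression: floor(1.5) + 1 = 2, so the start time is 2000.
    val newStart = (math.floor(currentTime.toDouble / period) + 1).toLong * period

    println("old = " + oldStart + ", new = " + newStart)  // old = 1000, new = 2000
  }
}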
diff --git a/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
new file mode 100644
index 0000000000..11cecf9822
--- /dev/null
+++ b/streaming/src/test/scala/spark/streaming/CheckpointSuite.scala
@@ -0,0 +1,48 @@
+package spark.streaming
+
+import spark.streaming.StreamingContext._
+
+class CheckpointSuite extends DStreamSuiteBase {
+
+  override def framework() = "CheckpointSuite"
+
+  override def checkpointFile() = "checkpoint"
+
+  def testCheckpointedOperation[U: ClassManifest, V: ClassManifest](
+      input: Seq[Seq[U]],
+      operation: DStream[U] => DStream[V],
+      expectedOutput: Seq[Seq[V]],
+      useSet: Boolean = false
+    ) {
+    System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
+    // Current code assumes that:
+    // number of inputs = number of outputs = number of batches to be run
+
+    // Do half the computation (half the number of batches), create checkpoint file and quit
+    val totalNumBatches = input.size
+    val initialNumBatches = input.size / 2
+    val nextNumBatches = totalNumBatches - initialNumBatches
+    val initialNumExpectedOutputs = initialNumBatches
+
+    val ssc = setupStreams[U, V](input, operation)
+    val output = runStreams[V](ssc, initialNumBatches, initialNumExpectedOutputs)
+    verifyOutput[V](output, expectedOutput.take(initialNumBatches), useSet)
+    Thread.sleep(1000)
+
+    // Restart and complete the computation from checkpoint file
+    val sscNew = new StreamingContext(checkpointFile)
+    sscNew.setCheckpointDetails(null, null)
+    val outputNew = runStreams[V](sscNew, nextNumBatches, expectedOutput.size)
+    verifyOutput[V](outputNew, expectedOutput, useSet)
+  }
+
+  test("simple per-batch operation") {
+    testCheckpointedOperation(
+      Seq( Seq("a", "a", "b"), Seq("", ""), Seq(), Seq("a", "a", "b"), Seq("", ""), Seq() ),
+      (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
+      Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq(), Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
+      true
+    )
+  }
+}
\ No newline at end of file
diff --git a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala
index 965b58c03f..f8ca7febe7 100644
--- a/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/DStreamBasicSuite.scala
@@ -22,9 +22,9 @@ class DStreamBasicSuite extends DStreamSuiteBase {
   test("shuffle-based operations") {
     // reduceByKey
     testOperation(
-      Seq(Seq("a", "a", "b"), Seq("", ""), Seq()),
+      Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ),
       (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
-      Seq(Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq()),
+      Seq( Seq(("a", 2), ("b", 1)), Seq(("", 2)), Seq() ),
       true
     )
 
@@ -62,7 +62,6 @@ class DStreamBasicSuite extends DStreamSuiteBase {
       var newState = 0
       if (values != null && values.size > 0) newState += values.reduce(_ + _)
       if (state != null) newState += state.self
-      //println("values = " + values + ", state = " + state + ", " + " new state = " + newState)
       new RichInt(newState)
     }
     s.map(x => (x, 1)).updateStateByKey[RichInt](updateFunc).map(t => (t._1, t._2.self))
diff --git a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala
index 59fe36baf0..cb95c36782 100644
--- a/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala
+++ b/streaming/src/test/scala/spark/streaming/DStreamSuiteBase.scala
@@ -4,70 +4,157 @@ import spark.{RDD, Logging}
 import util.ManualClock
 import collection.mutable.ArrayBuffer
 import org.scalatest.FunSuite
-import scala.collection.mutable.Queue
+import collection.mutable.SynchronizedBuffer
 
+class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, val input: Seq[Seq[T]])
+  extends InputDStream[T](ssc_) {
+
+  var currentIndex = 0
+
+  def start() {}
+
+  def stop() {}
+
+  def compute(validTime: Time): Option[RDD[T]] = {
+    logInfo("Computing RDD for time " + validTime)
+    val rdd = if (currentIndex < input.size) {
+      ssc.sc.makeRDD(input(currentIndex), 2)
+    } else {
+      ssc.sc.makeRDD(Seq[T](), 2)
+    }
+    logInfo("Created RDD " + rdd.id)
+    currentIndex += 1
+    Some(rdd)
+  }
+}
+
+class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]])
+  extends PerRDDForEachDStream[T](parent, (rdd: RDD[T], t: Time) => {
+    val collected = rdd.collect()
+    output += collected
+  })
 
 trait DStreamSuiteBase extends FunSuite with Logging {
 
-  def batchDuration() = Seconds(1)
+  def framework() = "DStreamSuiteBase"
 
-  def maxWaitTimeMillis() = 10000
+  def master() = "local[2]"
 
-  def testOperation[U: ClassManifest, V: ClassManifest](
-      input: Seq[Seq[U]],
-      operation: DStream[U] => DStream[V],
-      expectedOutput: Seq[Seq[V]],
-      useSet: Boolean = false
-    ) {
+  def batchDuration() = Seconds(1)
 
-    val manualClock = true
+  def checkpointFile() = null.asInstanceOf[String]
 
-    if (manualClock) {
-      System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
-    }
+  def checkpointInterval() = batchDuration
 
-    val ssc = new StreamingContext("local", "test")
+  def maxWaitTimeMillis() = 10000
 
-    try {
-      ssc.setBatchDuration(Milliseconds(batchDuration))
+  def setupStreams[U: ClassManifest, V: ClassManifest](
+      input: Seq[Seq[U]],
+      operation: DStream[U] => DStream[V]
+    ): StreamingContext = {
+
+    // Create StreamingContext
+    val ssc = new StreamingContext(master, framework)
+    ssc.setBatchDuration(batchDuration)
+    if (checkpointFile != null) {
+      ssc.setCheckpointDetails(checkpointFile, checkpointInterval())
+    }
 
-      val inputQueue = new Queue[RDD[U]]()
-      inputQueue ++= input.map(ssc.sc.makeRDD(_, 2))
-      val emptyRDD = ssc.sc.makeRDD(Seq[U](), 2)
+    // Setup the stream computation
+    val inputStream = new TestInputStream(ssc, input)
+    ssc.registerInputStream(inputStream)
+    val operatedStream = operation(inputStream)
+    val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]])
+    ssc.registerOutputStream(outputStream)
+    ssc
+  }
 
-      val inputStream = ssc.createQueueStream(inputQueue, true, emptyRDD)
-      val outputStream = operation(inputStream)
+  def runStreams[V: ClassManifest](
+      ssc: StreamingContext,
+      numBatches: Int,
+      numExpectedOutput: Int
+    ): Seq[Seq[V]] = {
+
+    logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput)
 
-      val output = new ArrayBuffer[Seq[V]]()
-      outputStream.foreachRDD(rdd => output += rdd.collect())
+    // Get the output buffer
+    val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]]
+    val output = outputStream.output
 
+    try {
+      // Start computation
       ssc.start()
 
-      val clock = ssc.scheduler.clock
-      if (clock.isInstanceOf[ManualClock]) {
-        clock.asInstanceOf[ManualClock].addToTime((input.size - 1) * batchDuration.milliseconds)
-      }
+      // Advance manual clock
+      val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+      logInfo("Manual clock before advancing = " + clock.time)
+      clock.addToTime(numBatches * batchDuration.milliseconds)
+      logInfo("Manual clock after advancing = " + clock.time)
 
+      // Wait until expected number of output items have been generated
       val startTime = System.currentTimeMillis()
-      while (output.size < expectedOutput.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
-        println("output.size = " + output.size + ", expectedOutput.size = " + expectedOutput.size)
-        Thread.sleep(500)
+      while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
+        logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput)
+        Thread.sleep(100)
       }
+      val timeTaken = System.currentTimeMillis() - startTime
 
-      println("output.size = " + output.size)
-      println("output")
-      output.foreach(x => println("[" + x.mkString(",") + "]"))
-
-      assert(output.size === expectedOutput.size)
-      for (i <- 0 until output.size) {
-        if (useSet) {
-          assert(output(i).toSet === expectedOutput(i).toSet)
-        } else {
-          assert(output(i).toList === expectedOutput(i).toList)
-        }
-      }
+      assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
+      assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
+    } catch {
+      case e: Exception => e.printStackTrace(); throw e;
     } finally {
       ssc.stop()
     }
+
+    output
+  }
+
+  def verifyOutput[V: ClassManifest](
+      output: Seq[Seq[V]],
+      expectedOutput: Seq[Seq[V]],
+      useSet: Boolean
+    ) {
+    logInfo("--------------------------------")
+    logInfo("output.size = " + output.size)
+    logInfo("output")
+    output.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+    logInfo("expected output.size = " + expectedOutput.size)
+    logInfo("expected output")
+    expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+    logInfo("--------------------------------")
+
+    // Match the output with the expected output
+    assert(output.size === expectedOutput.size, "Number of outputs do not match")
+    for (i <- 0 until output.size) {
+      if (useSet) {
+        assert(output(i).toSet === expectedOutput(i).toSet)
+      } else {
+        assert(output(i).toList === expectedOutput(i).toList)
+      }
+    }
+    logInfo("Output verified successfully")
+  }
+
+  def testOperation[U: ClassManifest, V: ClassManifest](
+      input: Seq[Seq[U]],
+      operation: DStream[U] => DStream[V],
+      expectedOutput: Seq[Seq[V]],
+      useSet: Boolean = false
+    ) {
+    testOperation[U, V](input, operation, expectedOutput, -1, useSet)
+  }
+
+  def testOperation[U: ClassManifest, V: ClassManifest](
+      input: Seq[Seq[U]],
+      operation: DStream[U] => DStream[V],
+      expectedOutput: Seq[Seq[V]],
+      numBatches: Int,
+      useSet: Boolean
+    ) {
+    System.setProperty("spark.streaming.clock", "spark.streaming.util.ManualClock")
+
+    val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size
+    val ssc = setupStreams[U, V](input, operation)
+    val output = runStreams[V](ssc, numBatches_, expectedOutput.size)
+    verifyOutput[V](output, expectedOutput, useSet)
   }
 }
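With the harness split into setupStreams, runStreams, and verifyOutput, a concrete suite reduces to declaring input batches, the operation under test, and expected output batches; the manual clock is advanced once per batch. A minimal example in the same style (MapSuite and its literals are illustrative, not part of the patch):

package spark.streaming

class MapSuite extends DStreamSuiteBase {
  test("map") {
    testOperation(
      Seq( Seq(1, 2), Seq(3) ),              // input batches
      (s: DStream[Int]) => s.map(_ * 10),    // operation under test
      Seq( Seq(10, 20), Seq(30) )            // expected output batches
    )
  }
}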
diff --git a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala
index 061cab2cbb..8dd18f491a 100644
--- a/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/DStreamWindowSuite.scala
@@ -4,6 +4,10 @@ import spark.streaming.StreamingContext._
 
 class DStreamWindowSuite extends DStreamSuiteBase {
 
+  override def framework() = "DStreamWindowSuite"
+
+  override def maxWaitTimeMillis() = 20000
+
   val largerSlideInput = Seq(
     Seq(("a", 1)),  // 1st window from here
     Seq(("a", 2)),
@@ -81,16 +85,15 @@ class DStreamWindowSuite extends DStreamSuiteBase {
       name: String,
       input: Seq[Seq[(String, Int)]],
       expectedOutput: Seq[Seq[(String, Int)]],
-      windowTime: Time = Seconds(2),
-      slideTime: Time = Seconds(1)
+      windowTime: Time = batchDuration * 2,
+      slideTime: Time = batchDuration
     ) {
     test("reduceByKeyAndWindow - " + name) {
-      testOperation(
-        input,
-        (s: DStream[(String, Int)]) => s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist(),
-        expectedOutput,
-        true
-      )
+      val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+      val operation = (s: DStream[(String, Int)]) => {
+        s.reduceByKeyAndWindow(_ + _, windowTime, slideTime).persist()
+      }
+      testOperation(input, operation, expectedOutput, numBatches, true)
     }
   }
 
@@ -98,16 +101,15 @@ class DStreamWindowSuite extends DStreamSuiteBase {
       name: String,
       input: Seq[Seq[(String, Int)]],
       expectedOutput: Seq[Seq[(String, Int)]],
-      windowTime: Time = Seconds(2),
-      slideTime: Time = Seconds(1)
+      windowTime: Time = batchDuration * 2,
+      slideTime: Time = batchDuration
     ) {
     test("reduceByKeyAndWindowInv - " + name) {
-      testOperation(
-        input,
-        (s: DStream[(String, Int)]) => s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist(),
-        expectedOutput,
-        true
-      )
+      val numBatches = expectedOutput.size * (slideTime.millis / batchDuration.millis).toInt
+      val operation = (s: DStream[(String, Int)]) => {
+        s.reduceByKeyAndWindow(_ + _, _ - _, windowTime, slideTime).persist()
+      }
+      testOperation(input, operation, expectedOutput, numBatches, true)
    }
  }
 
@@ -116,8 +118,8 @@ class DStreamWindowSuite extends DStreamSuiteBase {
 
   testReduceByKeyAndWindow(
     "basic reduction",
-    Seq(Seq(("a", 1), ("a", 3)) ),
-    Seq(Seq(("a", 4)) )
+    Seq( Seq(("a", 1), ("a", 3)) ),
+    Seq( Seq(("a", 4)) )
   )
 
   testReduceByKeyAndWindow(
@@ -126,7 +128,6 @@ class DStreamWindowSuite extends DStreamSuiteBase {
     "key already in window and new value added into window",
     Seq( Seq(("a", 1)), Seq(("a", 1)) ),
     Seq( Seq(("a", 1)), Seq(("a", 2)) )
   )
 
-
   testReduceByKeyAndWindow(
     "new key added into window",
     Seq( Seq(("a", 1)), Seq(("a", 1), ("b", 1)) ),