diff options
Diffstat (limited to 'streaming')
-rw-r--r-- | streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala | 186 | ||||
-rw-r--r-- | streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala | 52 |
2 files changed, 136 insertions, 102 deletions
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 30a359677c..86b96785d7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -470,32 +470,31 @@ class BasicOperationsSuite extends TestSuiteBase { } test("slice") { - val ssc = new StreamingContext(conf, Seconds(1)) - val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) - val stream = new TestInputStream[Int](ssc, input, 2) - stream.foreachRDD(_ => {}) // Dummy output stream - ssc.start() - Thread.sleep(2000) - def getInputFromSlice(fromMillis: Long, toMillis: Long) = { - stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet - } + withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc => + val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) + val stream = new TestInputStream[Int](ssc, input, 2) + stream.foreachRDD(_ => {}) // Dummy output stream + ssc.start() + Thread.sleep(2000) + def getInputFromSlice(fromMillis: Long, toMillis: Long) = { + stream.slice(new Time(fromMillis), new Time(toMillis)).flatMap(_.collect()).toSet + } - assert(getInputFromSlice(0, 1000) == Set(1)) - assert(getInputFromSlice(0, 2000) == Set(1, 2)) - assert(getInputFromSlice(1000, 2000) == Set(1, 2)) - assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4)) - ssc.stop() - Thread.sleep(1000) + assert(getInputFromSlice(0, 1000) == Set(1)) + assert(getInputFromSlice(0, 2000) == Set(1, 2)) + assert(getInputFromSlice(1000, 2000) == Set(1, 2)) + assert(getInputFromSlice(2000, 4000) == Set(2, 3, 4)) + } } - test("slice - has not been initialized") { - val ssc = new StreamingContext(conf, Seconds(1)) - val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) - val stream = new TestInputStream[Int](ssc, input, 2) - val thrown = intercept[SparkException] { - stream.slice(new Time(0), new Time(1000)) + withStreamingContext(new StreamingContext(conf, Seconds(1))) { ssc => + val input = Seq(Seq(1), Seq(2), Seq(3), Seq(4)) + val stream = new TestInputStream[Int](ssc, input, 2) + val thrown = intercept[SparkException] { + stream.slice(new Time(0), new Time(1000)) + } + assert(thrown.getMessage.contains("has not been initialized")) } - assert(thrown.getMessage.contains("has not been initialized")) } val cleanupTestInput = (0 until 10).map(x => Seq(x, x + 1)).toSeq @@ -555,73 +554,72 @@ class BasicOperationsSuite extends TestSuiteBase { test("rdd cleanup - input blocks and persisted RDDs") { // Actually receive data over through receiver to create BlockRDDs - // Start the server - val testServer = new TestServer() - testServer.start() - - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) - val mappedStream = networkStream.map(_ + ".").persist() - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] - val outputStream = new TestOutputStream(mappedStream, outputBuffer) - - outputStream.register() - ssc.start() - - // Feed data to the server to send to the network receiver - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3, 4, 5, 6) + withTestServer(new TestServer()) { testServer => + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + testServer.start() + // Set up the streaming context and input streams + val networkStream = + ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) + val mappedStream = networkStream.map(_ + ".").persist() + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] + val outputStream = new TestOutputStream(mappedStream, outputBuffer) + + outputStream.register() + ssc.start() + + // Feed data to the server to send to the network receiver + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + val input = Seq(1, 2, 3, 4, 5, 6) + + val blockRdds = new mutable.HashMap[Time, BlockRDD[_]] + val persistentRddIds = new mutable.HashMap[Time, Int] + + def collectRddInfo() { // get all RDD info required for verification + networkStream.generatedRDDs.foreach { case (time, rdd) => + blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]] + } + mappedStream.generatedRDDs.foreach { case (time, rdd) => + persistentRddIds(time) = rdd.id + } + } - val blockRdds = new mutable.HashMap[Time, BlockRDD[_]] - val persistentRddIds = new mutable.HashMap[Time, Int] + Thread.sleep(200) + for (i <- 0 until input.size) { + testServer.send(input(i).toString + "\n") + Thread.sleep(200) + clock.addToTime(batchDuration.milliseconds) + collectRddInfo() + } - def collectRddInfo() { // get all RDD info required for verification - networkStream.generatedRDDs.foreach { case (time, rdd) => - blockRdds(time) = rdd.asInstanceOf[BlockRDD[_]] - } - mappedStream.generatedRDDs.foreach { case (time, rdd) => - persistentRddIds(time) = rdd.id + Thread.sleep(200) + collectRddInfo() + logInfo("Stopping server") + testServer.stop() + + // verify data has been received + assert(outputBuffer.size > 0) + assert(blockRdds.size > 0) + assert(persistentRddIds.size > 0) + + import Time._ + + val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max) + val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min) + val latestBlockRdd = blockRdds(blockRdds.keySet.max) + val earliestBlockRdd = blockRdds(blockRdds.keySet.min) + // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted + assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId)) + assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId)) + + // verify that the latest input blocks are present but the earliest blocks have been removed + assert(latestBlockRdd.isValid) + assert(latestBlockRdd.collect != null) + assert(!earliestBlockRdd.isValid) + earliestBlockRdd.blockIds.foreach { blockId => + assert(!ssc.sparkContext.env.blockManager.master.contains(blockId)) + } } } - - Thread.sleep(200) - for (i <- 0 until input.size) { - testServer.send(input(i).toString + "\n") - Thread.sleep(200) - clock.addToTime(batchDuration.milliseconds) - collectRddInfo() - } - - Thread.sleep(200) - collectRddInfo() - logInfo("Stopping server") - testServer.stop() - logInfo("Stopping context") - - // verify data has been received - assert(outputBuffer.size > 0) - assert(blockRdds.size > 0) - assert(persistentRddIds.size > 0) - - import Time._ - - val latestPersistedRddId = persistentRddIds(persistentRddIds.keySet.max) - val earliestPersistedRddId = persistentRddIds(persistentRddIds.keySet.min) - val latestBlockRdd = blockRdds(blockRdds.keySet.max) - val earliestBlockRdd = blockRdds(blockRdds.keySet.min) - // verify that the latest mapped RDD is persisted but the earliest one has been unpersisted - assert(ssc.sparkContext.persistentRdds.contains(latestPersistedRddId)) - assert(!ssc.sparkContext.persistentRdds.contains(earliestPersistedRddId)) - - // verify that the latest input blocks are present but the earliest blocks have been removed - assert(latestBlockRdd.isValid) - assert(latestBlockRdd.collect != null) - assert(!earliestBlockRdd.isValid) - earliestBlockRdd.blockIds.foreach { blockId => - assert(!ssc.sparkContext.env.blockManager.master.contains(blockId)) - } - ssc.stop() } /** Test cleanup of RDDs in DStream metadata */ @@ -635,13 +633,15 @@ class BasicOperationsSuite extends TestSuiteBase { // Setup the stream computation assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second, check cleanup tests") - val ssc = setupStreams(cleanupTestInput, operation) - val operatedStream = ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]] - if (rememberDuration != null) ssc.remember(rememberDuration) - val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput) - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - assert(clock.time === Seconds(10).milliseconds) - assert(output.size === numExpectedOutput) - operatedStream + withStreamingContext(setupStreams(cleanupTestInput, operation)) { ssc => + val operatedStream = + ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]] + if (rememberDuration != null) ssc.remember(rememberDuration) + val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput) + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + assert(clock.time === Seconds(10).milliseconds) + assert(output.size === numExpectedOutput) + operatedStream + } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 2154c24abd..52972f63c6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -164,6 +164,40 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { after(afterFunction) /** + * Run a block of code with the given StreamingContext and automatically + * stop the context when the block completes or when an exception is thrown. + */ + def withStreamingContext[R](ssc: StreamingContext)(block: StreamingContext => R): R = { + try { + block(ssc) + } finally { + try { + ssc.stop(stopSparkContext = true) + } catch { + case e: Exception => + logError("Error stopping StreamingContext", e) + } + } + } + + /** + * Run a block of code with the given TestServer and automatically + * stop the server when the block completes or when an exception is thrown. + */ + def withTestServer[R](testServer: TestServer)(block: TestServer => R): R = { + try { + block(testServer) + } finally { + try { + testServer.stop() + } catch { + case e: Exception => + logError("Error stopping TestServer", e) + } + } + } + + /** * Set up required DStreams to test the DStream operation using the two sequences * of input collections. */ @@ -282,10 +316,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") Thread.sleep(100) // Give some time for the forgetting old RDDs to complete - } catch { - case e: Exception => {e.printStackTrace(); throw e} } finally { - ssc.stop() + ssc.stop(stopSparkContext = true) } output } @@ -351,9 +383,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { useSet: Boolean ) { val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size - val ssc = setupStreams[U, V](input, operation) - val output = runStreams[V](ssc, numBatches_, expectedOutput.size) - verifyOutput[V](output, expectedOutput, useSet) + withStreamingContext(setupStreams[U, V](input, operation)) { ssc => + val output = runStreams[V](ssc, numBatches_, expectedOutput.size) + verifyOutput[V](output, expectedOutput, useSet) + } } /** @@ -389,8 +422,9 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { useSet: Boolean ) { val numBatches_ = if (numBatches > 0) numBatches else expectedOutput.size - val ssc = setupStreams[U, V, W](input1, input2, operation) - val output = runStreams[W](ssc, numBatches_, expectedOutput.size) - verifyOutput[W](output, expectedOutput, useSet) + withStreamingContext(setupStreams[U, V, W](input1, input2, operation)) { ssc => + val output = runStreams[W](ssc, numBatches_, expectedOutput.size) + verifyOutput[W](output, expectedOutput, useSet) + } } } |