diff options
Diffstat (limited to 'streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala')
-rw-r--r-- | streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index fa975a1462..dbab708861 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -359,14 +359,20 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. * * Returns a sequence of items for each RDD. + * + * @param ssc The StreamingContext + * @param numBatches The number of batches should be run + * @param numExpectedOutput The number of expected output + * @param preStop The function to run before stopping StreamingContext */ def runStreams[V: ClassTag]( ssc: StreamingContext, numBatches: Int, - numExpectedOutput: Int + numExpectedOutput: Int, + preStop: () => Unit = () => {} ): Seq[Seq[V]] = { // Flatten each RDD into a single Seq - runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq) + runStreamsWithPartitions(ssc, numBatches, numExpectedOutput, preStop).map(_.flatten.toSeq) } /** @@ -376,11 +382,17 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { * * Returns a sequence of RDD's. Each RDD is represented as several sequences of items, each * representing one partition. + * + * @param ssc The StreamingContext + * @param numBatches The number of batches should be run + * @param numExpectedOutput The number of expected output + * @param preStop The function to run before stopping StreamingContext */ def runStreamsWithPartitions[V: ClassTag]( ssc: StreamingContext, numBatches: Int, - numExpectedOutput: Int + numExpectedOutput: Int, + preStop: () => Unit = () => {} ): Seq[Seq[Seq[V]]] = { assert(numBatches > 0, "Number of batches to run stream computation is zero") assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") @@ -424,6 +436,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging { assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") Thread.sleep(100) // Give some time for the forgetting old RDDs to complete + preStop() } finally { ssc.stop(stopSparkContext = true) } |