aboutsummaryrefslogtreecommitdiff
path: root/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
diff options
context:
space:
mode:
Diffstat (limited to 'streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala')
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala19
1 files changed, 16 insertions, 3 deletions
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index fa975a1462..dbab708861 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -359,14 +359,20 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
* output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached.
*
* Returns a sequence of items for each RDD.
+ *
+ * @param ssc The StreamingContext
+ * @param numBatches The number of batches to run
+ * @param numExpectedOutput The number of expected outputs
+ * @param preStop The function to run before stopping the StreamingContext
*/
def runStreams[V: ClassTag](
ssc: StreamingContext,
numBatches: Int,
- numExpectedOutput: Int
+ numExpectedOutput: Int,
+ preStop: () => Unit = () => {}
): Seq[Seq[V]] = {
// Flatten each RDD into a single Seq
- runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq)
+ runStreamsWithPartitions(ssc, numBatches, numExpectedOutput, preStop).map(_.flatten.toSeq)
}
/**
@@ -376,11 +382,17 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
*
* Returns a sequence of RDD's. Each RDD is represented as several sequences of items, each
* representing one partition.
+ *
+ * @param ssc The StreamingContext
+ * @param numBatches The number of batches to run
+ * @param numExpectedOutput The number of expected outputs
+ * @param preStop The function to run before stopping the StreamingContext
*/
def runStreamsWithPartitions[V: ClassTag](
ssc: StreamingContext,
numBatches: Int,
- numExpectedOutput: Int
+ numExpectedOutput: Int,
+ preStop: () => Unit = () => {}
): Seq[Seq[Seq[V]]] = {
assert(numBatches > 0, "Number of batches to run stream computation is zero")
assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero")
@@ -424,6 +436,7 @@ trait TestSuiteBase extends SparkFunSuite with BeforeAndAfter with Logging {
assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
Thread.sleep(100) // Give some time for the forgetting old RDDs to complete
+ preStop()
} finally {
ssc.stop(stopSparkContext = true)
}