diff options
-rw-r--r-- | streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index dada495843..ca716cf4e6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -133,6 +133,17 @@ trait DStreamCheckpointTester { self: SparkFunSuite => new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) } + /** + * Get the first TestOutputStreamWithPartitions, does not check the provided generic type. + */ + protected def getTestOutputStream[V: ClassTag](streams: Array[DStream[_]]): + TestOutputStreamWithPartitions[V] = { + streams.collect { + case ds: TestOutputStreamWithPartitions[V @unchecked] => ds + }.head + } + + protected def generateOutput[V: ClassTag]( ssc: StreamingContext, targetBatchTime: Time, @@ -150,9 +161,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => clock.setTime(targetBatchTime.milliseconds) logInfo("Manual clock after advancing = " + clock.getTimeMillis()) - val outputStream = ssc.graph.getOutputStreams().filter { dstream => - dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] - }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + val outputStream = getTestOutputStream[V](ssc.graph.getOutputStreams()) eventually(timeout(10 seconds)) { ssc.awaitTerminationOrTimeout(10) @@ -908,9 +917,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester logInfo("Manual clock after advancing = " + clock.getTimeMillis()) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams().filter { dstream => - dstream.isInstanceOf[TestOutputStreamWithPartitions[V]] - }.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + val outputStream = getTestOutputStream[V](ssc.graph.getOutputStreams()) outputStream.output.asScala.map(_.flatten) } } |