aboutsummaryrefslogtreecommitdiff
path: root/streaming
diff options
context:
space:
mode:
Diffstat (limited to 'streaming')
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala19
1 files changed, 13 insertions, 6 deletions
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index dada495843..ca716cf4e6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -133,6 +133,17 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
}
+ /**
+ * Get the first TestOutputStreamWithPartitions, does not check the provided generic type.
+ */
+ protected def getTestOutputStream[V: ClassTag](streams: Array[DStream[_]]):
+ TestOutputStreamWithPartitions[V] = {
+ streams.collect {
+ case ds: TestOutputStreamWithPartitions[V @unchecked] => ds
+ }.head
+ }
+
+
protected def generateOutput[V: ClassTag](
ssc: StreamingContext,
targetBatchTime: Time,
@@ -150,9 +161,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite =>
clock.setTime(targetBatchTime.milliseconds)
logInfo("Manual clock after advancing = " + clock.getTimeMillis())
- val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
- dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
- }.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+ val outputStream = getTestOutputStream[V](ssc.graph.getOutputStreams())
eventually(timeout(10 seconds)) {
ssc.awaitTerminationOrTimeout(10)
@@ -908,9 +917,7 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester
logInfo("Manual clock after advancing = " + clock.getTimeMillis())
Thread.sleep(batchDuration.milliseconds)
- val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
- dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
- }.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+ val outputStream = getTestOutputStream[V](ssc.graph.getOutputStreams())
outputStream.output.asScala.map(_.flatten)
}
}