Diffstat (limited to 'external/kinesis-asl/src/test/scala/org/apache')
-rw-r--r--  external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala   |  72
-rw-r--r--  external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala | 259
-rw-r--r--  external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala   | 152
-rw-r--r--  external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala            |  46
-rw-r--r--  external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala       | 210
-rw-r--r--  external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala         | 297
6 files changed, 1036 insertions, 0 deletions
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
new file mode 100644
index 0000000000..fdb270eaad
--- /dev/null
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.streaming.kinesis
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult}
+import com.google.common.util.concurrent.{FutureCallback, Futures}
+
+private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils {
+ override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
+ if (!aggregate) {
+ new SimpleDataGenerator(kinesisClient)
+ } else {
+ new KPLDataGenerator(regionName)
+ }
+ }
+}
+
+/** A wrapper for the KinesisProducer provided in the KPL. */
+private[kinesis] class KPLDataGenerator(regionName: String) extends KinesisDataGenerator {
+
+ private lazy val producer: KPLProducer = {
+ val conf = new KinesisProducerConfiguration()
+ .setRecordMaxBufferedTime(1000)
+ .setMaxConnections(1)
+ .setRegion(regionName)
+ .setMetricsLevel("none")
+
+ new KPLProducer(conf)
+ }
+
+ override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = {
+ val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]()
+ data.foreach { num =>
+ val str = num.toString
+ val data = ByteBuffer.wrap(str.getBytes())
+ val future = producer.addUserRecord(streamName, str, data)
+ val kinesisCallBack = new FutureCallback[UserRecordResult]() {
+ override def onFailure(t: Throwable): Unit = {} // do nothing
+
+ override def onSuccess(result: UserRecordResult): Unit = {
+ val shardId = result.getShardId
+ val seqNumber = result.getSequenceNumber()
+ val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId,
+ new ArrayBuffer[(Int, String)]())
+ sentSeqNumbers += ((num, seqNumber))
+ }
+ }
+ Futures.addCallback(future, kinesisCallBack)
+ }
+ producer.flushSync()
+ shardIdToSeqNumbers.toMap
+ }
+}
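For context, here is a minimal sketch of how a suite might drive the generator added above, for example inside a test body. It is illustrative only: it assumes Kinesis tests are enabled and a real stream can be created, and createStream, pushData and deleteStream come from the base KinesisTestUtils class, which is outside this diff.

    val utils = new KPLBasedKinesisTestUtils()
    utils.createStream()
    try {
      // aggregate = true routes the records through the KPL-based generator above
      val shardIdToDataAndSeqNumbers = utils.pushData(1 to 8, aggregate = true)
      // every value sent should come back paired with a sequence number under some shard id
      assert(shardIdToDataAndSeqNumbers.values.flatten.map(_._1).toSet == (1 to 8).toSet)
    } finally {
      utils.deleteStream()
    }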
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
new file mode 100644
index 0000000000..2555332d22
--- /dev/null
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
@@ -0,0 +1,259 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId}
+
+abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean)
+ extends KinesisFunSuite with BeforeAndAfterEach with LocalSparkContext {
+
+ private val testData = 1 to 8
+
+ private var testUtils: KinesisTestUtils = null
+ private var shardIds: Seq[String] = null
+ private var shardIdToData: Map[String, Seq[Int]] = null
+ private var shardIdToSeqNumbers: Map[String, Seq[String]] = null
+ private var shardIdToDataAndSeqNumbers: Map[String, Seq[(Int, String)]] = null
+ private var shardIdToRange: Map[String, SequenceNumberRange] = null
+ private var allRanges: Seq[SequenceNumberRange] = null
+
+ private var blockManager: BlockManager = null
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ runIfTestsEnabled("Prepare KinesisTestUtils") {
+ testUtils = new KPLBasedKinesisTestUtils()
+ testUtils.createStream()
+
+ shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData)
+ require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards")
+
+ shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq
+ shardIdToData = shardIdToDataAndSeqNumbers.mapValues { _.map { _._1 }}
+ shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }}
+ shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) =>
+ val seqNumRange = SequenceNumberRange(
+ testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last)
+ (shardId, seqNumRange)
+ }
+ allRanges = shardIdToRange.values.toSeq
+ }
+ }
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite")
+ sc = new SparkContext(conf)
+ blockManager = sc.env.blockManager
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ if (testUtils != null) {
+ testUtils.deleteStream()
+ }
+ } finally {
+ super.afterAll()
+ }
+ }
+
+ testIfEnabled("Basic reading from Kinesis") {
+ // Verify all data using multiple ranges in a single RDD partition
+ val receivedData1 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+ testUtils.endpointUrl, fakeBlockIds(1),
+ Array(SequenceNumberRanges(allRanges.toArray))
+ ).map { bytes => new String(bytes).toInt }.collect()
+ assert(receivedData1.toSet === testData.toSet)
+
+ // Verify all data using one range in each of the multiple RDD partitions
+ val receivedData2 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+ testUtils.endpointUrl, fakeBlockIds(allRanges.size),
+ allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
+ ).map { bytes => new String(bytes).toInt }.collect()
+ assert(receivedData2.toSet === testData.toSet)
+
+ // Verify ordering within each partition
+ val receivedData3 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+ testUtils.endpointUrl, fakeBlockIds(allRanges.size),
+ allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
+ ).map { bytes => new String(bytes).toInt }.collectPartitions()
+ assert(receivedData3.length === allRanges.size)
+ for (i <- 0 until allRanges.size) {
+ assert(receivedData3(i).toSeq === shardIdToData(allRanges(i).shardId))
+ }
+ }
+
+ testIfEnabled("Read data available in both block manager and Kinesis") {
+ testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2)
+ }
+
+ testIfEnabled("Read data available only in block manager, not in Kinesis") {
+ testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0)
+ }
+
+ testIfEnabled("Read data available only in Kinesis, not in block manager") {
+ testRDD(numPartitions = 2, numPartitionsInBM = 0, numPartitionsInKinesis = 2)
+ }
+
+ testIfEnabled("Read data available partially in block manager, rest in Kinesis") {
+ testRDD(numPartitions = 2, numPartitionsInBM = 1, numPartitionsInKinesis = 1)
+ }
+
+ testIfEnabled("Test isBlockValid skips block fetching from block manager") {
+ testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0,
+ testIsBlockValid = true)
+ }
+
+ testIfEnabled("Test whether RDD is valid after removing blocks from block anager") {
+ testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2,
+ testBlockRemove = true)
+ }
+
+ /**
+   * Test the KinesisBackedBlockRDD, by writing some partitions of the data to the
+   * block manager and the rest to Kinesis, and then reading it all back using the RDD.
+   * It can also test whether block fetching is skipped when blocks are marked
+   * invalid (testIsBlockValid), and whether the RDD can still be read after its
+   * blocks are removed from the block manager and reads fall back to Kinesis
+   * (testBlockRemove).
+   *
+ * @param numPartitions Number of partitions in RDD
+ * @param numPartitionsInBM Number of partitions to write to the BlockManager.
+ * Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager
+   * @param numPartitionsInKinesis Number of partitions to write to Kinesis.
+ * Partitions (numPartitions - 1 - numPartitionsInKinesis) to
+ * (numPartitions - 1) will be written to Kinesis
+ * @param testIsBlockValid Test whether setting isBlockValid to false skips block fetching
+   * @param testBlockRemove Test whether calling rdd.removeBlocks() makes the RDD still usable,
+   *                        with reads falling back to Kinesis
+   * Example with numPartitions = 5, numPartitionsInBM = 3, and numPartitionsInKinesis = 4
+ *
+ * numPartitionsInBM = 3
+ * |------------------|
+ * | |
+ * 0 1 2 3 4
+ * | |
+ * |-------------------------|
+ * numPartitionsInKinesis = 4
+ */
+ private def testRDD(
+ numPartitions: Int,
+ numPartitionsInBM: Int,
+ numPartitionsInKinesis: Int,
+ testIsBlockValid: Boolean = false,
+ testBlockRemove: Boolean = false
+ ): Unit = {
+ require(shardIds.size > 1, "Need at least 2 shards to test")
+ require(numPartitionsInBM <= shardIds.size,
+ "Number of partitions in BlockManager cannot be more than the Kinesis test shards available")
+ require(numPartitionsInKinesis <= shardIds.size,
+ "Number of partitions in Kinesis cannot be more than the Kinesis test shards available")
+ require(numPartitionsInBM <= numPartitions,
+ "Number of partitions in BlockManager cannot be more than that in RDD")
+ require(numPartitionsInKinesis <= numPartitions,
+ "Number of partitions in Kinesis cannot be more than that in RDD")
+
+ // Put necessary blocks in the block manager
+ val blockIds = fakeBlockIds(numPartitions)
+ blockIds.foreach(blockManager.removeBlock(_))
+ (0 until numPartitionsInBM).foreach { i =>
+ val blockData = shardIdToData(shardIds(i)).iterator.map { _.toString.getBytes() }
+ blockManager.putIterator(blockIds(i), blockData, StorageLevel.MEMORY_ONLY)
+ }
+
+ // Create the necessary ranges to use in the RDD
+ val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)(
+ SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy")))
+ val realRanges = Array.tabulate(numPartitionsInKinesis) { i =>
+ val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis)))
+ SequenceNumberRanges(Array(range))
+ }
+    val ranges = fakeRanges ++ realRanges
+
+    // Make sure the first `numPartitionsInBM` blocks are in block manager,
+    // and the rest are not
+ require(
+ blockIds.take(numPartitionsInBM).forall(blockManager.get(_).nonEmpty),
+ "Expected blocks not in BlockManager"
+ )
+
+ require(
+ blockIds.drop(numPartitionsInBM).forall(blockManager.get(_).isEmpty),
+ "Unexpected blocks in BlockManager"
+ )
+
+    // Make sure only the last `numPartitionsInKinesis` ranges refer to the real test stream
+ require(
+ ranges.takeRight(numPartitionsInKinesis).forall {
+ _.ranges.forall { _.streamName == testUtils.streamName }
+ }, "Incorrect configuration of RDD, expected ranges not set: "
+ )
+
+ require(
+ ranges.dropRight(numPartitionsInKinesis).forall {
+ _.ranges.forall { _.streamName != testUtils.streamName }
+ }, "Incorrect configuration of RDD, unexpected ranges set"
+ )
+
+ val rdd = new KinesisBackedBlockRDD[Array[Byte]](
+ sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges)
+ val collectedData = rdd.map { bytes =>
+ new String(bytes).toInt
+ }.collect()
+ assert(collectedData.toSet === testData.toSet)
+
+ // Verify that the block fetching is skipped when isBlockValid is set to false.
+    // This is done by using an RDD whose data is only in memory but which skips block fetching.
+    // Using that RDD will throw an exception, as it skips block fetching even if the blocks
+    // are in BlockManager.
+ if (testIsBlockValid) {
+ require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager")
+ require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis")
+ val rdd2 = new KinesisBackedBlockRDD[Array[Byte]](
+ sc, testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges,
+ isBlockIdValid = Array.fill(blockIds.length)(false))
+ intercept[SparkException] {
+ rdd2.collect()
+ }
+ }
+
+    // Verify that the RDD is still valid after the blocks are removed and can still read data
+    // from Kinesis
+ if (testBlockRemove) {
+ require(numPartitions === numPartitionsInKinesis,
+ "All partitions must be in WAL for this test")
+ require(numPartitionsInBM > 0, "Some partitions must be in BlockManager for this test")
+ rdd.removeBlocks()
+ assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSet === testData.toSet)
+ }
+ }
+
+ /** Generate fake block ids */
+ private def fakeBlockIds(num: Int): Array[BlockId] = {
+ Array.tabulate(num) { i => new StreamBlockId(0, i) }
+ }
+}
+
+class WithAggregationKinesisBackedBlockRDDSuite
+ extends KinesisBackedBlockRDDTests(aggregateTestData = true)
+
+class WithoutAggregationKinesisBackedBlockRDDSuite
+ extends KinesisBackedBlockRDDTests(aggregateTestData = false)
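To make the 5/3/4 layout in the testRDD scaladoc above concrete, the corresponding call would look like the sketch below. This is hypothetical: the cases in this suite use two partitions, and a five-partition run would need a test stream with at least four shards.

    // Partitions 0-2 are written to the BlockManager and partitions 1-4 carry real Kinesis
    // ranges, so partition 0 is served from memory only and partitions 3-4 fall back to Kinesis.
    testRDD(numPartitions = 5, numPartitionsInBM = 3, numPartitionsInKinesis = 4)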
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
new file mode 100644
index 0000000000..e1499a8220
--- /dev/null
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import java.util.concurrent.{ExecutorService, TimeoutException}
+
+import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
+import org.scalatest.concurrent.Eventually
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.streaming.{Duration, TestSuiteBase}
+import org.apache.spark.util.ManualClock
+
+class KinesisCheckpointerSuite extends TestSuiteBase
+ with MockitoSugar
+ with BeforeAndAfterEach
+ with PrivateMethodTester
+ with Eventually {
+
+ private val workerId = "dummyWorkerId"
+ private val shardId = "dummyShardId"
+ private val seqNum = "123"
+ private val otherSeqNum = "245"
+ private val checkpointInterval = Duration(10)
+ private val someSeqNum = Some(seqNum)
+ private val someOtherSeqNum = Some(otherSeqNum)
+
+ private var receiverMock: KinesisReceiver[Array[Byte]] = _
+ private var checkpointerMock: IRecordProcessorCheckpointer = _
+ private var kinesisCheckpointer: KinesisCheckpointer = _
+ private var clock: ManualClock = _
+
+ private val checkpoint = PrivateMethod[Unit]('checkpoint)
+
+ override def beforeEach(): Unit = {
+ receiverMock = mock[KinesisReceiver[Array[Byte]]]
+ checkpointerMock = mock[IRecordProcessorCheckpointer]
+ clock = new ManualClock()
+ kinesisCheckpointer = new KinesisCheckpointer(receiverMock, checkpointInterval, workerId, clock)
+ }
+
+ test("checkpoint is not called twice for the same sequence number") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+ kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+ kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+
+ verify(checkpointerMock, times(1)).checkpoint(anyString())
+ }
+
+ test("checkpoint is called after sequence number increases") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
+ .thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
+ kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+ kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+
+ verify(checkpointerMock, times(1)).checkpoint(seqNum)
+ verify(checkpointerMock, times(1)).checkpoint(otherSeqNum)
+ }
+
+ test("should checkpoint if we have exceeded the checkpoint interval") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
+ .thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
+
+ kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+ clock.advance(5 * checkpointInterval.milliseconds)
+
+ eventually(timeout(1 second)) {
+ verify(checkpointerMock, times(1)).checkpoint(seqNum)
+ verify(checkpointerMock, times(1)).checkpoint(otherSeqNum)
+ }
+ }
+
+ test("shouldn't checkpoint if we have not exceeded the checkpoint interval") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+ kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+ clock.advance(checkpointInterval.milliseconds / 2)
+
+ verify(checkpointerMock, never()).checkpoint(anyString())
+ }
+
+ test("should not checkpoint for the same sequence number") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+ kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+
+ clock.advance(checkpointInterval.milliseconds * 5)
+ eventually(timeout(1 second)) {
+ verify(checkpointerMock, atMost(1)).checkpoint(anyString())
+ }
+ }
+
+ test("removing checkpointer checkpoints one last time") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+ kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock)
+ verify(checkpointerMock, times(1)).checkpoint(anyString())
+ }
+
+ test("if checkpointing is going on, wait until finished before removing and checkpointing") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
+ .thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
+ when(checkpointerMock.checkpoint(anyString)).thenAnswer(new Answer[Unit] {
+ override def answer(invocations: InvocationOnMock): Unit = {
+ clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2)
+ }
+ })
+
+ kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+ clock.advance(checkpointInterval.milliseconds)
+ eventually(timeout(1 second)) {
+ verify(checkpointerMock, times(1)).checkpoint(anyString())
+ }
+ // don't block test thread
+ val f = Future(kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock))(
+ ExecutionContext.global)
+
+ intercept[TimeoutException] {
+ Await.ready(f, 50 millis)
+ }
+
+ clock.advance(checkpointInterval.milliseconds / 2)
+ eventually(timeout(1 second)) {
+ verify(checkpointerMock, times(2)).checkpoint(anyString())
+ }
+ }
+}
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
new file mode 100644
index 0000000000..ee428f31d6
--- /dev/null
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Helper trait that runs real Kinesis data transfer tests, or ignores
+ * them depending on whether the enabling environment variable is set.
+ */
+trait KinesisFunSuite extends SparkFunSuite {
+ import KinesisTestUtils._
+
+  /** Run the test if the enabling environment variable is set, otherwise ignore the test */
+  def testIfEnabled(testName: String)(testBody: => Unit): Unit = {
+ if (shouldRunTests) {
+ test(testName)(testBody)
+ } else {
+ ignore(s"$testName [enable by setting env var $envVarNameForEnablingTests=1]")(testBody)
+ }
+ }
+
+  /** Run the given body of code only if Kinesis tests are enabled */
+ def runIfTestsEnabled(message: String)(body: => Unit): Unit = {
+ if (shouldRunTests) {
+ body
+ } else {
+ ignore(s"$message [enable by setting env var $envVarNameForEnablingTests=1]")()
+ }
+ }
+}
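A minimal sketch of how a suite might build on this trait; the class and test names are hypothetical, and only testIfEnabled and runIfTestsEnabled come from the trait above.

    package org.apache.spark.streaming.kinesis

    import org.scalatest.BeforeAndAfterAll

    class MyKinesisIntegrationSuite extends KinesisFunSuite with BeforeAndAfterAll {

      override def beforeAll(): Unit = {
        super.beforeAll()
        // Runs only when Kinesis tests are enabled; otherwise registers an ignored marker test.
        runIfTestsEnabled("Prepare Kinesis fixtures") {
          // ... expensive one-off setup that needs real AWS credentials ...
        }
      }

      // Registered as a normal test when the enabling env variable is set, otherwise
      // registered as ignored with a hint appended to the test name.
      testIfEnabled("reads records from a real Kinesis stream") {
        // ... test body that talks to Kinesis ...
      }
    }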
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
new file mode 100644
index 0000000000..fd15b6ccdc
--- /dev/null
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.streaming.kinesis
+
+import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
+import java.util.Arrays
+
+import com.amazonaws.services.kinesis.clientlibrary.exceptions._
+import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
+import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
+import com.amazonaws.services.kinesis.model.Record
+import org.mockito.Matchers._
+import org.mockito.Matchers.{eq => meq}
+import org.mockito.Mockito._
+import org.scalatest.{BeforeAndAfter, Matchers}
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.streaming.{Duration, TestSuiteBase}
+import org.apache.spark.util.Utils
+
+/**
+ * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor
+ */
+class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter
+ with MockitoSugar {
+
+ val app = "TestKinesisReceiver"
+ val stream = "mySparkStream"
+ val endpoint = "endpoint-url"
+ val workerId = "dummyWorkerId"
+ val shardId = "dummyShardId"
+ val seqNum = "dummySeqNum"
+ val checkpointInterval = Duration(10)
+ val someSeqNum = Some(seqNum)
+
+ val record1 = new Record()
+ record1.setData(ByteBuffer.wrap("Spark In Action".getBytes(StandardCharsets.UTF_8)))
+ val record2 = new Record()
+ record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8)))
+ val batch = Arrays.asList(record1, record2)
+
+ var receiverMock: KinesisReceiver[Array[Byte]] = _
+ var checkpointerMock: IRecordProcessorCheckpointer = _
+
+ override def beforeFunction(): Unit = {
+ receiverMock = mock[KinesisReceiver[Array[Byte]]]
+ checkpointerMock = mock[IRecordProcessorCheckpointer]
+ }
+
+ test("check serializability of SerializableAWSCredentials") {
+ Utils.deserialize[SerializableAWSCredentials](
+ Utils.serialize(new SerializableAWSCredentials("x", "y")))
+ }
+
+ test("process records including store and set checkpointer") {
+ when(receiverMock.isStopped()).thenReturn(false)
+
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
+ recordProcessor.initialize(shardId)
+ recordProcessor.processRecords(batch, checkpointerMock)
+
+ verify(receiverMock, times(1)).isStopped()
+ verify(receiverMock, times(1)).addRecords(shardId, batch)
+ verify(receiverMock, times(1)).setCheckpointer(shardId, checkpointerMock)
+ }
+
+ test("shouldn't store and update checkpointer when receiver is stopped") {
+ when(receiverMock.isStopped()).thenReturn(true)
+
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
+ recordProcessor.processRecords(batch, checkpointerMock)
+
+ verify(receiverMock, times(1)).isStopped()
+ verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record]))
+ verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock))
+ }
+
+ test("shouldn't update checkpointer when exception occurs during store") {
+ when(receiverMock.isStopped()).thenReturn(false)
+ when(
+ receiverMock.addRecords(anyString, anyListOf(classOf[Record]))
+ ).thenThrow(new RuntimeException())
+
+ intercept[RuntimeException] {
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
+ recordProcessor.initialize(shardId)
+ recordProcessor.processRecords(batch, checkpointerMock)
+ }
+
+ verify(receiverMock, times(1)).isStopped()
+ verify(receiverMock, times(1)).addRecords(shardId, batch)
+ verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock))
+ }
+
+ test("shutdown should checkpoint if the reason is TERMINATE") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
+ recordProcessor.initialize(shardId)
+ recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE)
+
+ verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock))
+ }
+
+
+ test("shutdown should not checkpoint if the reason is something other than TERMINATE") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
+ recordProcessor.initialize(shardId)
+ recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE)
+ recordProcessor.shutdown(checkpointerMock, null)
+
+ verify(receiverMock, times(2)).removeCheckpointer(meq(shardId),
+ meq[IRecordProcessorCheckpointer](null))
+ }
+
+ test("retry success on first attempt") {
+ val expectedIsStopped = false
+ when(receiverMock.isStopped()).thenReturn(expectedIsStopped)
+
+ val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100)
+ assert(actualVal == expectedIsStopped)
+
+ verify(receiverMock, times(1)).isStopped()
+ }
+
+ test("retry success on second attempt after a Kinesis throttling exception") {
+ val expectedIsStopped = false
+ when(receiverMock.isStopped())
+ .thenThrow(new ThrottlingException("error message"))
+ .thenReturn(expectedIsStopped)
+
+ val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100)
+ assert(actualVal == expectedIsStopped)
+
+ verify(receiverMock, times(2)).isStopped()
+ }
+
+ test("retry success on second attempt after a Kinesis dependency exception") {
+ val expectedIsStopped = false
+ when(receiverMock.isStopped())
+ .thenThrow(new KinesisClientLibDependencyException("error message"))
+ .thenReturn(expectedIsStopped)
+
+ val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100)
+ assert(actualVal == expectedIsStopped)
+
+ verify(receiverMock, times(2)).isStopped()
+ }
+
+ test("retry failed after a shutdown exception") {
+ when(checkpointerMock.checkpoint()).thenThrow(new ShutdownException("error message"))
+
+ intercept[ShutdownException] {
+ KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
+ }
+
+ verify(checkpointerMock, times(1)).checkpoint()
+ }
+
+ test("retry failed after an invalid state exception") {
+ when(checkpointerMock.checkpoint()).thenThrow(new InvalidStateException("error message"))
+
+ intercept[InvalidStateException] {
+ KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
+ }
+
+ verify(checkpointerMock, times(1)).checkpoint()
+ }
+
+ test("retry failed after unexpected exception") {
+ when(checkpointerMock.checkpoint()).thenThrow(new RuntimeException("error message"))
+
+ intercept[RuntimeException] {
+ KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
+ }
+
+ verify(checkpointerMock, times(1)).checkpoint()
+ }
+
+ test("retry failed after exhausing all retries") {
+ val expectedErrorMessage = "final try error message"
+ when(checkpointerMock.checkpoint())
+ .thenThrow(new ThrottlingException("error message"))
+ .thenThrow(new ThrottlingException(expectedErrorMessage))
+
+ val exception = intercept[RuntimeException] {
+ KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
+ }
+ exception.getMessage().shouldBe(expectedErrorMessage)
+
+ verify(checkpointerMock, times(2)).checkpoint()
+ }
+}
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
new file mode 100644
index 0000000000..ca5d13da46
--- /dev/null
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -0,0 +1,297 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+import scala.util.Random
+
+import com.amazonaws.regions.RegionUtils
+import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.scalatest.Matchers._
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.{StorageLevel, StreamBlockId}
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.dstream.ReceiverInputDStream
+import org.apache.spark.streaming.kinesis.KinesisTestUtils._
+import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
+import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
+import org.apache.spark.util.Utils
+
+abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite
+ with Eventually with BeforeAndAfter with BeforeAndAfterAll {
+
+ // This is the name that KCL will use to save metadata to DynamoDB
+ private val appName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}"
+ private val batchDuration = Seconds(1)
+
+ // Dummy parameters for API testing
+ private val dummyEndpointUrl = defaultEndpointUrl
+ private val dummyRegionName = RegionUtils.getRegionByEndpoint(dummyEndpointUrl).getName()
+ private val dummyAWSAccessKey = "dummyAccessKey"
+ private val dummyAWSSecretKey = "dummySecretKey"
+
+ private var testUtils: KinesisTestUtils = null
+ private var ssc: StreamingContext = null
+ private var sc: SparkContext = null
+
+ override def beforeAll(): Unit = {
+ val conf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName("KinesisStreamSuite") // Setting Spark app name to Kinesis app name
+ sc = new SparkContext(conf)
+
+ runIfTestsEnabled("Prepare KinesisTestUtils") {
+ testUtils = new KPLBasedKinesisTestUtils()
+ testUtils.createStream()
+ }
+ }
+
+ override def afterAll(): Unit = {
+ if (ssc != null) {
+ ssc.stop()
+ }
+ if (sc != null) {
+ sc.stop()
+ }
+ if (testUtils != null) {
+ // Delete the Kinesis stream as well as the DynamoDB table generated by
+ // Kinesis Client Library when consuming the stream
+ testUtils.deleteStream()
+ testUtils.deleteDynamoDBTable(appName)
+ }
+ }
+
+ before {
+ ssc = new StreamingContext(sc, batchDuration)
+ }
+
+ after {
+ if (ssc != null) {
+ ssc.stop(stopSparkContext = false)
+ ssc = null
+ }
+ if (testUtils != null) {
+ testUtils.deleteDynamoDBTable(appName)
+ }
+ }
+
+ test("KinesisUtils API") {
+ // Tests the API, does not actually test data receiving
+ val kinesisStream1 = KinesisUtils.createStream(ssc, "mySparkStream",
+ dummyEndpointUrl, Seconds(2),
+ InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2)
+ val kinesisStream2 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream",
+ dummyEndpointUrl, dummyRegionName,
+ InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2)
+ val kinesisStream3 = KinesisUtils.createStream(ssc, "myAppNam", "mySparkStream",
+ dummyEndpointUrl, dummyRegionName,
+ InitialPositionInStream.LATEST, Seconds(2), StorageLevel.MEMORY_AND_DISK_2,
+ dummyAWSAccessKey, dummyAWSSecretKey)
+ }
+
+ test("RDD generation") {
+ val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream",
+ dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2),
+ StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey)
+ assert(inputStream.isInstanceOf[KinesisInputDStream[Array[Byte]]])
+
+ val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream[Array[Byte]]]
+ val time = Time(1000)
+
+ // Generate block info data for testing
+ val seqNumRanges1 = SequenceNumberRanges(
+ SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))
+ val blockId1 = StreamBlockId(kinesisStream.id, 123)
+ val blockInfo1 = ReceivedBlockInfo(
+ 0, None, Some(seqNumRanges1), new BlockManagerBasedStoreResult(blockId1, None))
+
+ val seqNumRanges2 = SequenceNumberRanges(
+ SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb"))
+ val blockId2 = StreamBlockId(kinesisStream.id, 345)
+ val blockInfo2 = ReceivedBlockInfo(
+ 0, None, Some(seqNumRanges2), new BlockManagerBasedStoreResult(blockId2, None))
+
+    // Verify that the generated KinesisBackedBlockRDD has all the right information
+ val blockInfos = Seq(blockInfo1, blockInfo2)
+ val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos)
+ nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[_]]
+ val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[_]]
+ assert(kinesisRDD.regionName === dummyRegionName)
+ assert(kinesisRDD.endpointUrl === dummyEndpointUrl)
+ assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds)
+ assert(kinesisRDD.awsCredentialsOption ===
+ Some(SerializableAWSCredentials(dummyAWSAccessKey, dummyAWSSecretKey)))
+ assert(nonEmptyRDD.partitions.size === blockInfos.size)
+ nonEmptyRDD.partitions.foreach { _ shouldBe a [KinesisBackedBlockRDDPartition] }
+ val partitions = nonEmptyRDD.partitions.map {
+ _.asInstanceOf[KinesisBackedBlockRDDPartition] }.toSeq
+ assert(partitions.map { _.seqNumberRanges } === Seq(seqNumRanges1, seqNumRanges2))
+ assert(partitions.map { _.blockId } === Seq(blockId1, blockId2))
+ assert(partitions.forall { _.isBlockIdValid === true })
+
+ // Verify that KinesisBackedBlockRDD is generated even when there are no blocks
+ val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty)
+ emptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
+ emptyRDD.partitions shouldBe empty
+
+ // Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid
+ blockInfos.foreach { _.setBlockIdInvalid() }
+ kinesisStream.createBlockRDD(time, blockInfos).partitions.foreach { partition =>
+ assert(partition.asInstanceOf[KinesisBackedBlockRDDPartition].isBlockIdValid === false)
+ }
+ }
+
+
+ /**
+ * Test the stream by sending data to a Kinesis stream and receiving from it.
+   * This test is not run by default, as it requires AWS credentials that the test
+   * environment may not have. Even if AWS credentials are available, the user may not
+   * want to run these tests, to avoid incurring Kinesis costs. To enable this test,
+   * you must have AWS credentials available through the default AWS provider chain,
+   * and you have to set the system environment variable RUN_KINESIS_TESTS=1.
+ */
+ testIfEnabled("basic operation") {
+ val awsCredentials = KinesisTestUtils.getAWSCredentials()
+ val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
+ testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
+ Seconds(10), StorageLevel.MEMORY_ONLY,
+ awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+ val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int]
+ stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd =>
+ collected ++= rdd.collect()
+ logInfo("Collected = " + collected.mkString(", "))
+ }
+ ssc.start()
+
+ val testData = 1 to 10
+ eventually(timeout(120 seconds), interval(10 second)) {
+ testUtils.pushData(testData, aggregateTestData)
+ assert(collected === testData.toSet, "\nData received does not match data sent")
+ }
+ ssc.stop(stopSparkContext = false)
+ }
+
+ testIfEnabled("custom message handling") {
+ val awsCredentials = KinesisTestUtils.getAWSCredentials()
+ def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5
+ val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
+ testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
+ Seconds(10), StorageLevel.MEMORY_ONLY, addFive,
+ awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+ stream shouldBe a [ReceiverInputDStream[_]]
+
+ val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int]
+ stream.foreachRDD { rdd =>
+ collected ++= rdd.collect()
+ logInfo("Collected = " + collected.mkString(", "))
+ }
+ ssc.start()
+
+ val testData = 1 to 10
+ eventually(timeout(120 seconds), interval(10 second)) {
+ testUtils.pushData(testData, aggregateTestData)
+ val modData = testData.map(_ + 5)
+ assert(collected === modData.toSet, "\nData received does not match data sent")
+ }
+ ssc.stop(stopSparkContext = false)
+ }
+
+ testIfEnabled("failure recovery") {
+ val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
+ val checkpointDir = Utils.createTempDir().getAbsolutePath
+
+ ssc = new StreamingContext(sc, Milliseconds(1000))
+ ssc.checkpoint(checkpointDir)
+
+ val awsCredentials = KinesisTestUtils.getAWSCredentials()
+ val collectedData = new mutable.HashMap[Time, (Array[SequenceNumberRanges], Seq[Int])]
+
+ val kinesisStream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
+ testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
+ Seconds(10), StorageLevel.MEMORY_ONLY,
+ awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+ // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch
+ kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => {
+ val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
+ val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq
+ collectedData.synchronized {
+ collectedData(time) = (kRdd.arrayOfseqNumberRanges, data)
+ }
+ })
+
+ ssc.remember(Minutes(60)) // remember all the batches so that they are all saved in checkpoint
+ ssc.start()
+
+ def numBatchesWithData: Int =
+ collectedData.synchronized { collectedData.count(_._2._2.nonEmpty) }
+
+ def isCheckpointPresent: Boolean = Checkpoint.getCheckpointFiles(checkpointDir).nonEmpty
+
+    // Run until there are more than 10 batches with some data in them
+    // If this times out because numBatchesWithData stays at zero, then it's likely that the
+    // foreachRDD function failed with exceptions, and nothing got added to `collectedData`
+ eventually(timeout(2 minutes), interval(1 seconds)) {
+ testUtils.pushData(1 to 5, aggregateTestData)
+ assert(isCheckpointPresent && numBatchesWithData > 10)
+ }
+ ssc.stop(stopSparkContext = true) // stop the SparkContext so that the blocks are not reused
+
+    // Restart the context from the checkpoint and verify the recovered stream and its data
+ logInfo("Restarting from checkpoint")
+ ssc = new StreamingContext(checkpointDir)
+ ssc.start()
+ val recoveredKinesisStream = ssc.graph.getInputStreams().head
+
+ // Verify that the recomputed RDDs are KinesisBackedBlockRDDs with the same sequence ranges
+ // and return the same data
+ collectedData.synchronized {
+ val times = collectedData.keySet
+ times.foreach { time =>
+ val (arrayOfSeqNumRanges, data) = collectedData(time)
+ val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]]
+ rdd shouldBe a[KinesisBackedBlockRDD[_]]
+
+ // Verify the recovered sequence ranges
+ val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
+ assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size)
+ arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) =>
+ assert(expected.ranges.toSeq === found.ranges.toSeq)
+ }
+
+ // Verify the recovered data
+ assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSeq === data)
+ }
+ }
+ ssc.stop()
+ }
+}
+
+class WithAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = true)
+
+class WithoutAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = false)