From d249636e59fabd8ca57a47dc2cbad9c4a4e7a750 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 23 Jul 2015 20:06:54 -0700 Subject: [SPARK-9216] [STREAMING] Define KinesisBackedBlockRDDs For more information see master JIRA: https://issues.apache.org/jira/browse/SPARK-9215 Design Doc: https://docs.google.com/document/d/1k0dl270EnK7uExrsCE7jYw7PYx0YC935uBcxn3p0f58/edit Author: Tathagata Das Closes #7578 from tdas/kinesis-rdd and squashes the following commits: 543d208 [Tathagata Das] Fixed scala style 5082a30 [Tathagata Das] Fixed scala style 3f40c2d [Tathagata Das] Addressed comments c4f25d2 [Tathagata Das] Addressed comment d3d64d1 [Tathagata Das] Minor update f6e35c8 [Tathagata Das] Added retry logic to make it more robust 8874b70 [Tathagata Das] Updated Kinesis RDD 575bdbc [Tathagata Das] Fix scala style issues 4a36096 [Tathagata Das] Add license 5da3995 [Tathagata Das] Changed KinesisSuiteHelper to KinesisFunSuite 528e206 [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into kinesis-rdd 3ae0814 [Tathagata Das] Added KinesisBackedBlockRDD --- .../streaming/kinesis/KinesisBackedBlockRDD.scala | 285 +++++++++++++++++++++ .../spark/streaming/kinesis/KinesisTestUtils.scala | 2 +- .../kinesis/KinesisBackedBlockRDDSuite.scala | 246 ++++++++++++++++++ .../spark/streaming/kinesis/KinesisFunSuite.scala | 13 +- .../streaming/kinesis/KinesisStreamSuite.scala | 4 +- 5 files changed, 545 insertions(+), 5 deletions(-) create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala new file mode 100644 index 0000000000..8f144a4d97 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import scala.collection.JavaConversions._ +import scala.util.control.NonFatal + +import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.model._ + +import org.apache.spark._ +import org.apache.spark.rdd.{BlockRDD, BlockRDDPartition} +import org.apache.spark.storage.BlockId +import org.apache.spark.util.NextIterator + + +/** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. */ +private[kinesis] +case class SequenceNumberRange( + streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String) + +/** Class representing an array of Kinesis sequence number ranges */ +private[kinesis] +case class SequenceNumberRanges(ranges: Array[SequenceNumberRange]) { + def isEmpty(): Boolean = ranges.isEmpty + def nonEmpty(): Boolean = ranges.nonEmpty + override def toString(): String = ranges.mkString("SequenceNumberRanges(", ", ", ")") +} + +private[kinesis] +object SequenceNumberRanges { + def apply(range: SequenceNumberRange): SequenceNumberRanges = { + new SequenceNumberRanges(Array(range)) + } +} + + +/** Partition storing the information of the ranges of Kinesis sequence numbers to read */ +private[kinesis] +class KinesisBackedBlockRDDPartition( + idx: Int, + blockId: BlockId, + val isBlockIdValid: Boolean, + val seqNumberRanges: SequenceNumberRanges + ) extends BlockRDDPartition(blockId, idx) + +/** + * A BlockRDD where the block data is backed by Kinesis, which can accessed using the + * sequence numbers of the corresponding blocks. + */ +private[kinesis] +class KinesisBackedBlockRDD( + sc: SparkContext, + regionId: String, + endpointUrl: String, + @transient blockIds: Array[BlockId], + @transient arrayOfseqNumberRanges: Array[SequenceNumberRanges], + @transient isBlockIdValid: Array[Boolean] = Array.empty, + retryTimeoutMs: Int = 10000, + awsCredentialsOption: Option[SerializableAWSCredentials] = None + ) extends BlockRDD[Array[Byte]](sc, blockIds) { + + require(blockIds.length == arrayOfseqNumberRanges.length, + "Number of blockIds is not equal to the number of sequence number ranges") + + override def isValid(): Boolean = true + + override def getPartitions: Array[Partition] = { + Array.tabulate(blockIds.length) { i => + val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) + new KinesisBackedBlockRDDPartition(i, blockIds(i), isValid, arrayOfseqNumberRanges(i)) + } + } + + override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { + val blockManager = SparkEnv.get.blockManager + val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition] + val blockId = partition.blockId + + def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = { + logDebug(s"Read partition data of $this from block manager, block $blockId") + blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]]) + } + + def getBlockFromKinesis(): Iterator[Array[Byte]] = { + val credenentials = awsCredentialsOption.getOrElse { + new DefaultAWSCredentialsProviderChain().getCredentials() + } + partition.seqNumberRanges.ranges.iterator.flatMap { range => + new KinesisSequenceRangeIterator( + credenentials, endpointUrl, regionId, range, retryTimeoutMs) + } + } + if (partition.isBlockIdValid) { + getBlockFromBlockManager().getOrElse { getBlockFromKinesis() } + } else { + getBlockFromKinesis() + } + } +} + + +/** + * An iterator that return the Kinesis data based on the given range of sequence numbers. + * Internally, it repeatedly fetches sets of records starting from the fromSequenceNumber, + * until the endSequenceNumber is reached. + */ +private[kinesis] +class KinesisSequenceRangeIterator( + credentials: AWSCredentials, + endpointUrl: String, + regionId: String, + range: SequenceNumberRange, + retryTimeoutMs: Int + ) extends NextIterator[Array[Byte]] with Logging { + + private val client = new AmazonKinesisClient(credentials) + private val streamName = range.streamName + private val shardId = range.shardId + + private var toSeqNumberReceived = false + private var lastSeqNumber: String = null + private var internalIterator: Iterator[Record] = null + + client.setEndpoint(endpointUrl, "kinesis", regionId) + + override protected def getNext(): Array[Byte] = { + var nextBytes: Array[Byte] = null + if (toSeqNumberReceived) { + finished = true + } else { + + if (internalIterator == null) { + + // If the internal iterator has not been initialized, + // then fetch records from starting sequence number + internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber) + } else if (!internalIterator.hasNext) { + + // If the internal iterator does not have any more records, + // then fetch more records after the last consumed sequence number + internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) + } + + if (!internalIterator.hasNext) { + + // If the internal iterator still does not have any data, then throw exception + // and terminate this iterator + finished = true + throw new SparkException( + s"Could not read until the end sequence number of the range: $range") + } else { + + // Get the record, copy the data into a byte array and remember its sequence number + val nextRecord: Record = internalIterator.next() + val byteBuffer = nextRecord.getData() + nextBytes = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(nextBytes) + lastSeqNumber = nextRecord.getSequenceNumber() + + // If the this record's sequence number matches the stopping sequence number, then make sure + // the iterator is marked finished next time getNext() is called + if (nextRecord.getSequenceNumber == range.toSeqNumber) { + toSeqNumberReceived = true + } + } + + } + nextBytes + } + + override protected def close(): Unit = { + client.shutdown() + } + + /** + * Get records starting from or after the given sequence number. + */ + private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = { + val shardIterator = getKinesisIterator(iteratorType, seqNum) + val result = getRecordsAndNextKinesisIterator(shardIterator) + result._1 + } + + /** + * Get the records starting from using a Kinesis shard iterator (which is a progress handle + * to get records from Kinesis), and get the next shard iterator for next consumption. + */ + private def getRecordsAndNextKinesisIterator( + shardIterator: String): (Iterator[Record], String) = { + val getRecordsRequest = new GetRecordsRequest + getRecordsRequest.setRequestCredentials(credentials) + getRecordsRequest.setShardIterator(shardIterator) + val getRecordsResult = retryOrTimeout[GetRecordsResult]( + s"getting records using shard iterator") { + client.getRecords(getRecordsRequest) + } + (getRecordsResult.getRecords.iterator(), getRecordsResult.getNextShardIterator) + } + + /** + * Get the Kinesis shard iterator for getting records starting from or after the given + * sequence number. + */ + private def getKinesisIterator( + iteratorType: ShardIteratorType, + sequenceNumber: String): String = { + val getShardIteratorRequest = new GetShardIteratorRequest + getShardIteratorRequest.setRequestCredentials(credentials) + getShardIteratorRequest.setStreamName(streamName) + getShardIteratorRequest.setShardId(shardId) + getShardIteratorRequest.setShardIteratorType(iteratorType.toString) + getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber) + val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult]( + s"getting shard iterator from sequence number $sequenceNumber") { + client.getShardIterator(getShardIteratorRequest) + } + getShardIteratorResult.getShardIterator + } + + /** Helper method to retry Kinesis API request with exponential backoff and timeouts */ + private def retryOrTimeout[T](message: String)(body: => T): T = { + import KinesisSequenceRangeIterator._ + + var startTimeMs = System.currentTimeMillis() + var retryCount = 0 + var waitTimeMs = MIN_RETRY_WAIT_TIME_MS + var result: Option[T] = None + var lastError: Throwable = null + + def isTimedOut = (System.currentTimeMillis() - startTimeMs) >= retryTimeoutMs + def isMaxRetryDone = retryCount >= MAX_RETRIES + + while (result.isEmpty && !isTimedOut && !isMaxRetryDone) { + if (retryCount > 0) { // wait only if this is a retry + Thread.sleep(waitTimeMs) + waitTimeMs *= 2 // if you have waited, then double wait time for next round + } + try { + result = Some(body) + } catch { + case NonFatal(t) => + lastError = t + t match { + case ptee: ProvisionedThroughputExceededException => + logWarning(s"Error while $message [attempt = ${retryCount + 1}]", ptee) + case e: Throwable => + throw new SparkException(s"Error while $message", e) + } + } + retryCount += 1 + } + result.getOrElse { + if (isTimedOut) { + throw new SparkException( + s"Timed out after $retryTimeoutMs ms while $message, last exception: ", lastError) + } else { + throw new SparkException( + s"Gave up after $retryCount retries while $message, last exception: ", lastError) + } + } + } +} + +private[streaming] +object KinesisSequenceRangeIterator { + val MAX_RETRIES = 3 + val MIN_RETRY_WAIT_TIME_MS = 100 +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index f6bf552e6b..0ff1b7ed0f 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -177,7 +177,7 @@ private class KinesisTestUtils( private[kinesis] object KinesisTestUtils { - val envVarName = "RUN_KINESIS_TESTS" + val envVarName = "ENABLE_KINESIS_TESTS" val shouldRunTests = sys.env.get(envVarName) == Some("1") diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala new file mode 100644 index 0000000000..b2e2a4246d --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} + +class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { + + private val regionId = "us-east-1" + private val endpointUrl = "https://kinesis.us-east-1.amazonaws.com" + private val testData = 1 to 8 + + private var testUtils: KinesisTestUtils = null + private var shardIds: Seq[String] = null + private var shardIdToData: Map[String, Seq[Int]] = null + private var shardIdToSeqNumbers: Map[String, Seq[String]] = null + private var shardIdToDataAndSeqNumbers: Map[String, Seq[(Int, String)]] = null + private var shardIdToRange: Map[String, SequenceNumberRange] = null + private var allRanges: Seq[SequenceNumberRange] = null + + private var sc: SparkContext = null + private var blockManager: BlockManager = null + + + override def beforeAll(): Unit = { + runIfTestsEnabled("Prepare KinesisTestUtils") { + testUtils = new KinesisTestUtils(endpointUrl) + testUtils.createStream() + + shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") + + shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq + shardIdToData = shardIdToDataAndSeqNumbers.mapValues { _.map { _._1 }} + shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }} + shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) => + val seqNumRange = SequenceNumberRange( + testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last) + (shardId, seqNumRange) + } + allRanges = shardIdToRange.values.toSeq + + val conf = new SparkConf().setMaster("local[4]").setAppName("KinesisBackedBlockRDDSuite") + sc = new SparkContext(conf) + blockManager = sc.env.blockManager + } + } + + override def afterAll(): Unit = { + if (sc != null) { + sc.stop() + } + } + + testIfEnabled("Basic reading from Kinesis") { + // Verify all data using multiple ranges in a single RDD partition + val receivedData1 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(1), + Array(SequenceNumberRanges(allRanges.toArray)) + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData1.toSet === testData.toSet) + + // Verify all data using one range in each of the multiple RDD partitions + val receivedData2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collect() + assert(receivedData2.toSet === testData.toSet) + + // Verify ordering within each partition + val receivedData3 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, + fakeBlockIds(allRanges.size), + allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray + ).map { bytes => new String(bytes).toInt }.collectPartitions() + assert(receivedData3.length === allRanges.size) + for (i <- 0 until allRanges.size) { + assert(receivedData3(i).toSeq === shardIdToData(allRanges(i).shardId)) + } + } + + testIfEnabled("Read data available in both block manager and Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2) + } + + testIfEnabled("Read data available only in block manager, not in Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0) + } + + testIfEnabled("Read data available only in Kinesis, not in block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 0, numPartitionsInKinesis = 2) + } + + testIfEnabled("Read data available partially in block manager, rest in Kinesis") { + testRDD(numPartitions = 2, numPartitionsInBM = 1, numPartitionsInKinesis = 1) + } + + testIfEnabled("Test isBlockValid skips block fetching from block manager") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 0, + testIsBlockValid = true) + } + + testIfEnabled("Test whether RDD is valid after removing blocks from block anager") { + testRDD(numPartitions = 2, numPartitionsInBM = 2, numPartitionsInKinesis = 2, + testBlockRemove = true) + } + + /** + * Test the WriteAheadLogBackedRDD, by writing some partitions of the data to block manager + * and the rest to a write ahead log, and then reading reading it all back using the RDD. + * It can also test if the partitions that were read from the log were again stored in + * block manager. + * + * + * + * @param numPartitions Number of partitions in RDD + * @param numPartitionsInBM Number of partitions to write to the BlockManager. + * Partitions 0 to (numPartitionsInBM-1) will be written to BlockManager + * @param numPartitionsInKinesis Number of partitions to write to the Kinesis. + * Partitions (numPartitions - 1 - numPartitionsInKinesis) to + * (numPartitions - 1) will be written to Kinesis + * @param testIsBlockValid Test whether setting isBlockValid to false skips block fetching + * @param testBlockRemove Test whether calling rdd.removeBlock() makes the RDD still usable with + * reads falling back to the WAL + * Example with numPartitions = 5, numPartitionsInBM = 3, and numPartitionsInWAL = 4 + * + * numPartitionsInBM = 3 + * |------------------| + * | | + * 0 1 2 3 4 + * | | + * |-------------------------| + * numPartitionsInKinesis = 4 + */ + private def testRDD( + numPartitions: Int, + numPartitionsInBM: Int, + numPartitionsInKinesis: Int, + testIsBlockValid: Boolean = false, + testBlockRemove: Boolean = false + ): Unit = { + require(shardIds.size > 1, "Need at least 2 shards to test") + require(numPartitionsInBM <= shardIds.size , + "Number of partitions in BlockManager cannot be more than the Kinesis test shards available") + require(numPartitionsInKinesis <= shardIds.size , + "Number of partitions in Kinesis cannot be more than the Kinesis test shards available") + require(numPartitionsInBM <= numPartitions, + "Number of partitions in BlockManager cannot be more than that in RDD") + require(numPartitionsInKinesis <= numPartitions, + "Number of partitions in Kinesis cannot be more than that in RDD") + + // Put necessary blocks in the block manager + val blockIds = fakeBlockIds(numPartitions) + blockIds.foreach(blockManager.removeBlock(_)) + (0 until numPartitionsInBM).foreach { i => + val blockData = shardIdToData(shardIds(i)).iterator.map { _.toString.getBytes() } + blockManager.putIterator(blockIds(i), blockData, StorageLevel.MEMORY_ONLY) + } + + // Create the necessary ranges to use in the RDD + val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)( + SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))) + val realRanges = Array.tabulate(numPartitionsInKinesis) { i => + val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis))) + SequenceNumberRanges(Array(range)) + } + val ranges = (fakeRanges ++ realRanges) + + + // Make sure that the left `numPartitionsInBM` blocks are in block manager, and others are not + require( + blockIds.take(numPartitionsInBM).forall(blockManager.get(_).nonEmpty), + "Expected blocks not in BlockManager" + ) + + require( + blockIds.drop(numPartitionsInBM).forall(blockManager.get(_).isEmpty), + "Unexpected blocks in BlockManager" + ) + + // Make sure that the right sequence `numPartitionsInKinesis` are configured, and others are not + require( + ranges.takeRight(numPartitionsInKinesis).forall { + _.ranges.forall { _.streamName == testUtils.streamName } + }, "Incorrect configuration of RDD, expected ranges not set: " + ) + + require( + ranges.dropRight(numPartitionsInKinesis).forall { + _.ranges.forall { _.streamName != testUtils.streamName } + }, "Incorrect configuration of RDD, unexpected ranges set" + ) + + val rdd = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds, ranges) + val collectedData = rdd.map { bytes => + new String(bytes).toInt + }.collect() + assert(collectedData.toSet === testData.toSet) + + // Verify that the block fetching is skipped when isBlockValid is set to false. + // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // Using that RDD will throw exception, as it skips block fetching even if the blocks are in + // in BlockManager. + if (testIsBlockValid) { + require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager") + require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis") + val rdd2 = new KinesisBackedBlockRDD(sc, regionId, endpointUrl, blockIds.toArray, + ranges, isBlockIdValid = Array.fill(blockIds.length)(false)) + intercept[SparkException] { + rdd2.collect() + } + } + + // Verify that the RDD is not invalid after the blocks are removed and can still read data + // from write ahead log + if (testBlockRemove) { + require(numPartitions === numPartitionsInKinesis, + "All partitions must be in WAL for this test") + require(numPartitionsInBM > 0, "Some partitions must be in BlockManager for this test") + rdd.removeBlocks() + assert(rdd.map { bytes => new String(bytes).toInt }.collect().toSet === testData.toSet) + } + } + + /** Generate fake block ids */ + private def fakeBlockIds(num: Int): Array[BlockId] = { + Array.tabulate(num) { i => new StreamBlockId(0, i) } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala index 6d011f295e..8373138785 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisFunSuite.scala @@ -23,15 +23,24 @@ import org.apache.spark.SparkFunSuite * Helper class that runs Kinesis real data transfer tests or * ignores them based on env variable is set or not. */ -trait KinesisSuiteHelper { self: SparkFunSuite => +trait KinesisFunSuite extends SparkFunSuite { import KinesisTestUtils._ /** Run the test if environment variable is set or ignore the test */ - def testOrIgnore(testName: String)(testBody: => Unit) { + def testIfEnabled(testName: String)(testBody: => Unit) { if (shouldRunTests) { test(testName)(testBody) } else { ignore(s"$testName [enable by setting env var $envVarName=1]")(testBody) } } + + /** Run the give body of code only if Kinesis tests are enabled */ + def runIfTestsEnabled(message: String)(body: => Unit): Unit = { + if (shouldRunTests) { + body + } else { + ignore(s"$message [enable by setting env var $envVarName=1]")() + } + } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 50f71413ab..f9c952b946 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -class KinesisStreamSuite extends SparkFunSuite with KinesisSuiteHelper +class KinesisStreamSuite extends KinesisFunSuite with Eventually with BeforeAndAfter with BeforeAndAfterAll { // This is the name that KCL uses to save metadata to DynamoDB @@ -83,7 +83,7 @@ class KinesisStreamSuite extends SparkFunSuite with KinesisSuiteHelper * you must have AWS credentials available through the default AWS provider chain, * and you have to set the system environment variable RUN_KINESIS_TESTS=1 . */ - testOrIgnore("basic operation") { + testIfEnabled("basic operation") { val kinesisTestUtils = new KinesisTestUtils() try { kinesisTestUtils.createStream() -- cgit v1.2.3