From 63accc79625d8a03d0624717af5e1d81b18a6da3 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Sun, 25 Oct 2015 21:18:35 -0700
Subject: [SPARK-10891][STREAMING][KINESIS] Add MessageHandler to
 KinesisUtils.createStream similar to Direct Kafka

This PR allows users to map a Kinesis `Record` to a generic `T` when creating
a Kinesis stream. This is particularly useful if you would like to do extra
work with Kinesis metadata, such as the sequence number and partition key.

TODO:
- [x] add tests

Author: Burak Yavuz

Closes #8954 from brkyvz/kinesis-handler.
---
 .../streaming/kinesis/KinesisBackedBlockRDD.scala  |  35 ++-
 .../streaming/kinesis/KinesisInputDStream.scala    |  15 +-
 .../spark/streaming/kinesis/KinesisReceiver.scala  |  18 +-
 .../streaming/kinesis/KinesisRecordProcessor.scala |   4 +-
 .../spark/streaming/kinesis/KinesisUtils.scala     | 247 +++++++++++++++++++--
 .../streaming/kinesis/JavaKinesisStreamSuite.java  |  29 ++-
 .../kinesis/KinesisBackedBlockRDDSuite.scala       |  16 +-
 .../streaming/kinesis/KinesisReceiverSuite.scala   |   4 +-
 .../streaming/kinesis/KinesisStreamSuite.scala     |  44 +++-
 9 files changed, 337 insertions(+), 75 deletions(-)

diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
index 5d32fa699a..000897a4e7 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.streaming.kinesis
 
 import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
@@ -67,7 +68,7 @@ class KinesisBackedBlockRDDPartition(
  * sequence numbers of the corresponding blocks.
  */
 private[kinesis]
-class KinesisBackedBlockRDD(
+class KinesisBackedBlockRDD[T: ClassTag](
     @transient sc: SparkContext,
     val regionName: String,
     val endpointUrl: String,
@@ -75,8 +76,9 @@ class KinesisBackedBlockRDD(
     @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges],
     @transient isBlockIdValid: Array[Boolean] = Array.empty,
     val retryTimeoutMs: Int = 10000,
+    val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _,
     val awsCredentialsOption: Option[SerializableAWSCredentials] = None
-  ) extends BlockRDD[Array[Byte]](sc, blockIds) {
+  ) extends BlockRDD[T](sc, blockIds) {
 
   require(blockIds.length == arrayOfseqNumberRanges.length,
     "Number of blockIds is not equal to the number of sequence number ranges")
@@ -90,23 +92,23 @@ class KinesisBackedBlockRDD(
     }
   }
 
-  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
     val blockManager = SparkEnv.get.blockManager
     val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition]
     val blockId = partition.blockId
 
-    def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = {
+    def getBlockFromBlockManager(): Option[Iterator[T]] = {
       logDebug(s"Read partition data of $this from block manager, block $blockId")
-      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]])
+      blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
     }
 
-    def getBlockFromKinesis(): Iterator[Array[Byte]] = {
-      val credenentials = awsCredentialsOption.getOrElse {
+    def getBlockFromKinesis(): Iterator[T] = {
+      val credentials = awsCredentialsOption.getOrElse {
        new DefaultAWSCredentialsProviderChain().getCredentials()
       }
       partition.seqNumberRanges.ranges.iterator.flatMap { range =>
-        new KinesisSequenceRangeIterator(
-          credenentials, endpointUrl, regionName, range, retryTimeoutMs)
+        new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName,
+          range, retryTimeoutMs).map(messageHandler)
       }
     }
     if (partition.isBlockIdValid) {
@@ -129,8 +131,7 @@ class KinesisSequenceRangeIterator(
     endpointUrl: String,
     regionId: String,
     range: SequenceNumberRange,
-    retryTimeoutMs: Int
-  ) extends NextIterator[Array[Byte]] with Logging {
+    retryTimeoutMs: Int) extends NextIterator[Record] with Logging {
 
   private val client = new AmazonKinesisClient(credentials)
   private val streamName = range.streamName
@@ -142,8 +143,8 @@ class KinesisSequenceRangeIterator(
 
   client.setEndpoint(endpointUrl, "kinesis", regionId)
 
-  override protected def getNext(): Array[Byte] = {
-    var nextBytes: Array[Byte] = null
+  override protected def getNext(): Record = {
+    var nextRecord: Record = null
     if (toSeqNumberReceived) {
       finished = true
     } else {
@@ -170,10 +171,7 @@ class KinesisSequenceRangeIterator(
 
       } else {
 
         // Get the record, copy the data into a byte array and remember its sequence number
-        val nextRecord: Record = internalIterator.next()
-        val byteBuffer = nextRecord.getData()
-        nextBytes = new Array[Byte](byteBuffer.remaining())
-        byteBuffer.get(nextBytes)
+        nextRecord = internalIterator.next()
         lastSeqNumber = nextRecord.getSequenceNumber()
 
         // If the this record's sequence number matches the stopping sequence number, then make sure
@@ -182,9 +180,8 @@ class KinesisSequenceRangeIterator(
           toSeqNumberReceived = true
         }
       }
-
     }
-    nextBytes
+    nextRecord
   }
 
   override protected def close(): Unit = {
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
index 2e4204dcb6..72ab6357a5 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.streaming.kinesis
 
+import scala.reflect.ClassTag
+
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.{BlockId, StorageLevel}
@@ -26,7 +29,7 @@ import org.apache.spark.streaming.receiver.Receiver
 import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
 import org.apache.spark.streaming.{Duration, StreamingContext, Time}
 
-private[kinesis] class KinesisInputDStream(
+private[kinesis] class KinesisInputDStream[T: ClassTag](
     @transient _ssc: StreamingContext,
     streamName: String,
     endpointUrl: String,
@@ -35,11 +38,12 @@ private[kinesis] class KinesisInputDStream(
     checkpointAppName: String,
     checkpointInterval: Duration,
     storageLevel: StorageLevel,
+    messageHandler: Record => T,
     awsCredentialsOption: Option[SerializableAWSCredentials]
-  ) extends ReceiverInputDStream[Array[Byte]](_ssc) {
+  ) extends ReceiverInputDStream[T](_ssc) {
 
   private[streaming]
-  override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[Array[Byte]] = {
+  override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[T] = {
 
     // This returns true even for when blockInfos is empty
     val allBlocksHaveRanges = blockInfos.map { _.metadataOption }.forall(_.nonEmpty)
@@ -56,6 +60,7 @@ private[kinesis] class KinesisInputDStream(
         context.sc, regionName, endpointUrl, blockIds, seqNumRanges,
         isBlockIdValid = isBlockIdValid,
         retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
+        messageHandler = messageHandler,
         awsCredentialsOption = awsCredentialsOption)
     } else {
       logWarning("Kinesis sequence number information was not present with some block metadata," +
@@ -64,8 +69,8 @@ private[kinesis] class KinesisInputDStream(
     }
   }
 
-  override def getReceiver(): Receiver[Array[Byte]] = {
+  override def getReceiver(): Receiver[T] = {
     new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream,
-      checkpointAppName, checkpointInterval, storageLevel, awsCredentialsOption)
+      checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption)
   }
 }
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 6e0988c1af..134d627cda 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -80,7 +80,7 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
  * @param awsCredentialsOption Optional AWS credentials, used when user directly specifies
  *                             the credentials
  */
-private[kinesis] class KinesisReceiver(
+private[kinesis] class KinesisReceiver[T](
     val streamName: String,
     endpointUrl: String,
     regionName: String,
@@ -88,8 +88,9 @@ private[kinesis] class KinesisReceiver(
     checkpointAppName: String,
     checkpointInterval: Duration,
     storageLevel: StorageLevel,
-    awsCredentialsOption: Option[SerializableAWSCredentials]
-  ) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver =>
+    messageHandler: Record => T,
+    awsCredentialsOption: Option[SerializableAWSCredentials])
+  extends Receiver[T](storageLevel) with Logging { receiver =>
 
   /*
    * =================================================================================
@@ -202,12 +203,7 @@ private[kinesis] class KinesisReceiver(
   /** Add records of the given shard to the current block being generated */
   private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = {
     if (records.size > 0) {
-      val dataIterator = records.iterator().asScala.map { record =>
-        val byteBuffer = record.getData()
-        val byteArray = new Array[Byte](byteBuffer.remaining())
-        byteBuffer.get(byteArray)
-        byteArray
-      }
+      val dataIterator = records.iterator().asScala.map(messageHandler)
       val metadata = SequenceNumberRange(streamName, shardId,
         records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber())
       blockGenerator.addMultipleDataWithCallback(dataIterator, metadata)
@@ -240,7 +236,7 @@ private[kinesis] class KinesisReceiver(
 
   /** Store the block along with its associated ranges */
   private def storeBlockWithRanges(
-      blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[Array[Byte]]): Unit = {
+      blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = {
     val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId)
     if (rangesToReportOption.isEmpty) {
       stop("Error while storing block into Spark, could not find sequence number ranges " +
@@ -325,7 +321,7 @@ private[kinesis] class KinesisReceiver(
     /** Callback method called when a block is ready to be pushed / stored. */
     def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
       storeBlockWithRanges(blockId,
-        arrayBuffer.asInstanceOf[mutable.ArrayBuffer[Array[Byte]]])
+        arrayBuffer.asInstanceOf[mutable.ArrayBuffer[T]])
     }
 
     /** Callback called in case of any error in internal of the BlockGenerator */
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index b240512332..1d5178790e 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -41,8 +41,8 @@ import org.apache.spark.Logging
  * @param checkpointState represents the checkpoint state including the next checkpoint time.
  *                        It's injected here for mocking purposes.
  */
-private[kinesis] class KinesisRecordProcessor(
-    receiver: KinesisReceiver,
+private[kinesis] class KinesisRecordProcessor[T](
+    receiver: KinesisReceiver[T],
     workerId: String,
     checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging {
 
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
index c799fadf2d..2849fd8a82 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
@@ -16,16 +16,120 @@
  */
 package org.apache.spark.streaming.kinesis
 
+import scala.reflect.ClassTag
+
 import com.amazonaws.regions.RegionUtils
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
 
+import org.apache.spark.api.java.function.{Function => JFunction}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext}
 import org.apache.spark.streaming.dstream.ReceiverInputDStream
 import org.apache.spark.streaming.{Duration, StreamingContext}
 
-
 object KinesisUtils {
+  /**
+   * Create an input stream that pulls messages from a Kinesis stream.
+   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+   *
+   * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
+   * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
+   * gets the AWS credentials.
+   *
+   * @param ssc StreamingContext object
+   * @param kinesisAppName  Kinesis application name used by the Kinesis Client Library
+   *                        (KCL) to update DynamoDB
+   * @param streamName   Kinesis stream name
+   * @param endpointUrl  URL of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+   * @param regionName   Name of region used by the Kinesis Client Library (KCL) to update
+   *                     DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
+   *                                 worker's initial starting position in the stream.
+   *                                 The values are either the beginning of the stream
+   *                                 per Kinesis' limit of 24 hours
+   *                                 (InitialPositionInStream.TRIM_HORIZON) or
+   *                                 the tip of the stream (InitialPositionInStream.LATEST).
+   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
+   *                            See the Kinesis Spark Streaming documentation for more
+   *                            details on the different types of checkpoints.
+   * @param storageLevel Storage level to use for storing the received objects.
+   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
+   * @param messageHandler A custom message handler that can generate a generic output from a
+   *                       Kinesis `Record`, which contains both message data and metadata.
+   */
+  def createStream[T: ClassTag](
+      ssc: StreamingContext,
+      kinesisAppName: String,
+      streamName: String,
+      endpointUrl: String,
+      regionName: String,
+      initialPositionInStream: InitialPositionInStream,
+      checkpointInterval: Duration,
+      storageLevel: StorageLevel,
+      messageHandler: Record => T): ReceiverInputDStream[T] = {
+    val cleanedHandler = ssc.sc.clean(messageHandler)
+    // Setting scope to override receiver stream's scope of "receiver stream"
+    ssc.withNamedScope("kinesis stream") {
+      new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
+        initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+        cleanedHandler, None)
+    }
+  }
+
+  /**
+   * Create an input stream that pulls messages from a Kinesis stream.
+   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+   *
+   * Note:
+   *  The given AWS credentials will get saved in DStream checkpoints if checkpointing
+   *  is enabled. Make sure that your checkpoint directory is secure.
+   *
+   * @param ssc StreamingContext object
+   * @param kinesisAppName  Kinesis application name used by the Kinesis Client Library
+   *                        (KCL) to update DynamoDB
+   * @param streamName   Kinesis stream name
+   * @param endpointUrl  URL of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+   * @param regionName   Name of region used by the Kinesis Client Library (KCL) to update
+   *                     DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
+   *                                 worker's initial starting position in the stream.
+   *                                 The values are either the beginning of the stream
+   *                                 per Kinesis' limit of 24 hours
+   *                                 (InitialPositionInStream.TRIM_HORIZON) or
+   *                                 the tip of the stream (InitialPositionInStream.LATEST).
+   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
+   *                            See the Kinesis Spark Streaming documentation for more
+   *                            details on the different types of checkpoints.
+   * @param storageLevel Storage level to use for storing the received objects.
+   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
+   * @param messageHandler A custom message handler that can generate a generic output from a
+   *                       Kinesis `Record`, which contains both message data and metadata.
+   * @param awsAccessKeyId  AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
+   * @param awsSecretKey  AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
+   */
+  // scalastyle:off
+  def createStream[T: ClassTag](
+      ssc: StreamingContext,
+      kinesisAppName: String,
+      streamName: String,
+      endpointUrl: String,
+      regionName: String,
+      initialPositionInStream: InitialPositionInStream,
+      checkpointInterval: Duration,
+      storageLevel: StorageLevel,
+      messageHandler: Record => T,
+      awsAccessKeyId: String,
+      awsSecretKey: String): ReceiverInputDStream[T] = {
+    // scalastyle:on
+    val cleanedHandler = ssc.sc.clean(messageHandler)
+    ssc.withNamedScope("kinesis stream") {
+      new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
+        initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+        cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
+    }
+  }
+
   /**
    * Create an input stream that pulls messages from a Kinesis stream.
    * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
@@ -61,12 +165,12 @@ object KinesisUtils {
       regionName: String,
       initialPositionInStream: InitialPositionInStream,
       checkpointInterval: Duration,
-      storageLevel: StorageLevel
-    ): ReceiverInputDStream[Array[Byte]] = {
+      storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = {
     // Setting scope to override receiver stream's scope of "receiver stream"
     ssc.withNamedScope("kinesis stream") {
-      new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName),
-        initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, None)
+      new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
+        initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+        defaultMessageHandler, None)
     }
   }
 
@@ -109,12 +213,11 @@ object KinesisUtils {
       checkpointInterval: Duration,
       storageLevel: StorageLevel,
       awsAccessKeyId: String,
-      awsSecretKey: String
-    ): ReceiverInputDStream[Array[Byte]] = {
+      awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = {
     ssc.withNamedScope("kinesis stream") {
-      new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName),
+      new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
         initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
-        Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
+        defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
     }
   }
 
@@ -156,11 +259,113 @@ object KinesisUtils {
       storageLevel: StorageLevel
     ): ReceiverInputDStream[Array[Byte]] = {
     ssc.withNamedScope("kinesis stream") {
-      new KinesisInputDStream(ssc, streamName, endpointUrl, getRegionByEndpoint(endpointUrl),
-        initialPositionInStream, ssc.sc.appName, checkpointInterval, storageLevel, None)
+      new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl,
+        getRegionByEndpoint(endpointUrl), initialPositionInStream, ssc.sc.appName,
+        checkpointInterval, storageLevel, defaultMessageHandler, None)
     }
   }
 
+  /**
+   * Create an input stream that pulls messages from a Kinesis stream.
+   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+   *
+   * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
+   * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
+   * gets the AWS credentials.
+   *
+   * @param jssc Java StreamingContext object
+   * @param kinesisAppName  Kinesis application name used by the Kinesis Client Library
+   *                        (KCL) to update DynamoDB
+   * @param streamName   Kinesis stream name
+   * @param endpointUrl  URL of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+   * @param regionName   Name of region used by the Kinesis Client Library (KCL) to update
+   *                     DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
+   *                                 worker's initial starting position in the stream.
+   *                                 The values are either the beginning of the stream
+   *                                 per Kinesis' limit of 24 hours
+   *                                 (InitialPositionInStream.TRIM_HORIZON) or
+   *                                 the tip of the stream (InitialPositionInStream.LATEST).
+   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
+   *                            See the Kinesis Spark Streaming documentation for more
+   *                            details on the different types of checkpoints.
+   * @param storageLevel Storage level to use for storing the received objects.
+   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
+   * @param messageHandler A custom message handler that can generate a generic output from a
+   *                       Kinesis `Record`, which contains both message data and metadata.
+   * @param recordClass Class of the records in DStream
+   */
+  def createStream[T](
+      jssc: JavaStreamingContext,
+      kinesisAppName: String,
+      streamName: String,
+      endpointUrl: String,
+      regionName: String,
+      initialPositionInStream: InitialPositionInStream,
+      checkpointInterval: Duration,
+      storageLevel: StorageLevel,
+      messageHandler: JFunction[Record, T],
+      recordClass: Class[T]): JavaReceiverInputDStream[T] = {
+    implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
+    val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
+    createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+      initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler)
+  }
+
+  /**
+   * Create an input stream that pulls messages from a Kinesis stream.
+   * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+   *
+   * Note:
+   *  The given AWS credentials will get saved in DStream checkpoints if checkpointing
+   *  is enabled. Make sure that your checkpoint directory is secure.
+   *
+   * @param jssc Java StreamingContext object
+   * @param kinesisAppName  Kinesis application name used by the Kinesis Client Library
+   *                        (KCL) to update DynamoDB
+   * @param streamName   Kinesis stream name
+   * @param endpointUrl  URL of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+   * @param regionName   Name of region used by the Kinesis Client Library (KCL) to update
+   *                     DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+   * @param initialPositionInStream  In the absence of Kinesis checkpoint info, this is the
+   *                                 worker's initial starting position in the stream.
+   *                                 The values are either the beginning of the stream
+   *                                 per Kinesis' limit of 24 hours
+   *                                 (InitialPositionInStream.TRIM_HORIZON) or
+   *                                 the tip of the stream (InitialPositionInStream.LATEST).
+   * @param checkpointInterval  Checkpoint interval for Kinesis checkpointing.
+   *                            See the Kinesis Spark Streaming documentation for more
+   *                            details on the different types of checkpoints.
+   * @param storageLevel Storage level to use for storing the received objects.
+   *                     StorageLevel.MEMORY_AND_DISK_2 is recommended.
+   * @param messageHandler A custom message handler that can generate a generic output from a
+   *                       Kinesis `Record`, which contains both message data and metadata.
+   * @param recordClass Class of the records in DStream
+   * @param awsAccessKeyId  AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
+   * @param awsSecretKey  AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
+   */
+  // scalastyle:off
+  def createStream[T](
+      jssc: JavaStreamingContext,
+      kinesisAppName: String,
+      streamName: String,
+      endpointUrl: String,
+      regionName: String,
+      initialPositionInStream: InitialPositionInStream,
+      checkpointInterval: Duration,
+      storageLevel: StorageLevel,
+      messageHandler: JFunction[Record, T],
+      recordClass: Class[T],
+      awsAccessKeyId: String,
+      awsSecretKey: String): JavaReceiverInputDStream[T] = {
+    // scalastyle:on
+    implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
+    val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
+    createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+      initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler,
+      awsAccessKeyId, awsSecretKey)
+  }
+
   /**
    * Create an input stream that pulls messages from a Kinesis stream.
    * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
@@ -198,8 +403,8 @@ object KinesisUtils {
       checkpointInterval: Duration,
       storageLevel: StorageLevel
     ): JavaReceiverInputDStream[Array[Byte]] = {
-    createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
-      initialPositionInStream, checkpointInterval, storageLevel)
+    createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+      initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_))
   }
 
   /**
@@ -241,10 +446,10 @@ object KinesisUtils {
       checkpointInterval: Duration,
       storageLevel: StorageLevel,
       awsAccessKeyId: String,
-      awsSecretKey: String
-    ): JavaReceiverInputDStream[Array[Byte]] = {
-    createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
-      initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey)
+      awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = {
+    createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+      initialPositionInStream, checkpointInterval, storageLevel,
+      defaultMessageHandler(_), awsAccessKeyId, awsSecretKey)
   }
 
   /**
@@ -297,6 +502,14 @@ object KinesisUtils {
       throw new IllegalArgumentException(s"Region name '$regionName' is not valid")
     }
   }
+
+  private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = {
+    if (record == null) return null
+    val byteBuffer = record.getData()
+    val byteArray = new Array[Byte](byteBuffer.remaining())
+    byteBuffer.get(byteArray)
+    byteArray
+  }
 }
 
 /**
diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
index 87954a31f6..3f0f6793d2 100644
--- a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
+++ b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
@@ -17,14 +17,19 @@
 
 package org.apache.spark.streaming.kinesis;
 
+import com.amazonaws.services.kinesis.model.Record;
+import org.junit.Test;
+
+import org.apache.spark.api.java.function.Function;
 import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.streaming.Duration;
 import org.apache.spark.streaming.LocalJavaStreamingContext;
 import org.apache.spark.streaming.api.java.JavaDStream;
-import org.junit.Test;
 
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
 
+import java.nio.ByteBuffer;
+
 /**
  * Demonstrate the use of the KinesisUtils Java API
 */
@@ -33,9 +38,27 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext {
   public void testKinesisStream() {
     // Tests the API, does not actually test data receiving
     JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream",
-      "https://kinesis.us-west-2.amazonaws.com", new Duration(2000),
+        "https://kinesis.us-west-2.amazonaws.com", new Duration(2000),
         InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2());
-
+
     ssc.stop();
   }
+
+
+  private static Function<Record, String> handler = new Function<Record, String>() {
+    @Override
+    public String call(Record record) {
+      return record.getPartitionKey() + "-" + record.getSequenceNumber();
+    }
+  };
+
+  @Test
+  public void testCustomHandler() {
+    // Tests the API, does not actually test data receiving
+    JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream",
+        "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST,
+        new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class);
+
+    ssc.stop();
+  }
 }
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
index a89e5627e0..9f9e146a08 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
@@ -73,22 +73,22 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
 
   testIfEnabled("Basic reading from Kinesis") {
     // Verify all data using multiple ranges in a single RDD partition
-    val receivedData1 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
-      fakeBlockIds(1),
+    val receivedData1 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+      testUtils.endpointUrl, fakeBlockIds(1),
       Array(SequenceNumberRanges(allRanges.toArray))
     ).map { bytes => new String(bytes).toInt }.collect()
     assert(receivedData1.toSet === testData.toSet)
 
     // Verify all data using one range in each of the multiple RDD partitions
-    val receivedData2 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
-      fakeBlockIds(allRanges.size),
+    val receivedData2 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+      testUtils.endpointUrl, fakeBlockIds(allRanges.size),
       allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
     ).map { bytes => new String(bytes).toInt }.collect()
     assert(receivedData2.toSet === testData.toSet)
 
     // Verify ordering within each partition
-    val receivedData3 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
-      fakeBlockIds(allRanges.size),
+    val receivedData3 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+      testUtils.endpointUrl, fakeBlockIds(allRanges.size),
       allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
     ).map { bytes => new String(bytes).toInt }.collectPartitions()
     assert(receivedData3.length === allRanges.size)
@@ -209,7 +209,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
       }, "Incorrect configuration of RDD, unexpected ranges set"
     )
 
-    val rdd = new KinesisBackedBlockRDD(
+    val rdd = new KinesisBackedBlockRDD[Array[Byte]](
       sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges)
     val collectedData = rdd.map { bytes =>
       new String(bytes).toInt
@@ -223,7 +223,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
     if (testIsBlockValid) {
       require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager")
       require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis")
-      val rdd2 = new KinesisBackedBlockRDD(
+      val rdd2 = new KinesisBackedBlockRDD[Array[Byte]](
         sc, testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges,
         isBlockIdValid = Array.fill(blockIds.length)(false))
       intercept[SparkException] {
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 3d136aec2e..17ab444704 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -52,14 +52,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
   record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8)))
   val batch = Arrays.asList(record1, record2)
 
-  var receiverMock: KinesisReceiver = _
+  var receiverMock: KinesisReceiver[Array[Byte]] = _
   var checkpointerMock: IRecordProcessorCheckpointer = _
   var checkpointClockMock: ManualClock = _
   var checkpointStateMock: KinesisCheckpointState = _
   var currentClockMock: Clock = _
 
   override def beforeFunction(): Unit = {
-    receiverMock = mock[KinesisReceiver]
+    receiverMock = mock[KinesisReceiver[Array[Byte]]]
     checkpointerMock = mock[IRecordProcessorCheckpointer]
     checkpointClockMock = mock[ManualClock]
     checkpointStateMock = mock[KinesisCheckpointState]
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 1177dc7581..ba84e557df 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -24,6 +24,7 @@ import scala.util.Random
 
 import com.amazonaws.regions.RegionUtils
 import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
 import org.scalatest.Matchers._
 import org.scalatest.concurrent.Eventually
 import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
@@ -31,6 +32,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.{StorageLevel, StreamBlockId}
 import org.apache.spark.streaming._
+import org.apache.spark.streaming.dstream.ReceiverInputDStream
 import org.apache.spark.streaming.kinesis.KinesisTestUtils._
 import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
 import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
@@ -113,9 +115,9 @@ class KinesisStreamSuite extends KinesisFunSuite
     val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream",
       dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2),
       StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey)
-    assert(inputStream.isInstanceOf[KinesisInputDStream])
+    assert(inputStream.isInstanceOf[KinesisInputDStream[Array[Byte]]])
 
-    val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream]
+    val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream[Array[Byte]]]
     val time = Time(1000)
 
     // Generate block info data for testing
@@ -134,8 +136,8 @@ class KinesisStreamSuite extends KinesisFunSuite
     // Verify that the generated KinesisBackedBlockRDD has the all the right information
     val blockInfos = Seq(blockInfo1, blockInfo2)
     val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos)
-    nonEmptyRDD shouldBe a [KinesisBackedBlockRDD]
-    val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD]
+    nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
+    val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
     assert(kinesisRDD.regionName === dummyRegionName)
     assert(kinesisRDD.endpointUrl === dummyEndpointUrl)
     assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds)
@@ -151,7 +153,7 @@ class KinesisStreamSuite extends KinesisFunSuite
 
     // Verify that KinesisBackedBlockRDD is generated even when there are no blocks
     val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty)
-    emptyRDD shouldBe a [KinesisBackedBlockRDD]
+    emptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
     emptyRDD.partitions shouldBe empty
 
     // Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid
@@ -192,6 +194,32 @@ class KinesisStreamSuite extends KinesisFunSuite
     ssc.stop(stopSparkContext = false)
   }
 
+  testIfEnabled("custom message handling") {
+    val awsCredentials = KinesisTestUtils.getAWSCredentials()
+    def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5
+    val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
+      testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
+      Seconds(10), StorageLevel.MEMORY_ONLY, addFive,
+      awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+    stream shouldBe a [ReceiverInputDStream[Int]]
+
+    val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int]
+    stream.foreachRDD { rdd =>
+      collected ++= rdd.collect()
+      logInfo("Collected = " + rdd.collect().toSeq.mkString(", "))
+    }
+    ssc.start()
+
+    val testData = 1 to 10
+    eventually(timeout(120 seconds), interval(10 second)) {
+      testUtils.pushData(testData)
+      val modData = testData.map(_ + 5)
+      assert(collected === modData.toSet, "\nData received does not match data sent")
+    }
+    ssc.stop(stopSparkContext = false)
+  }
+
   testIfEnabled("failure recovery") {
     val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
     val checkpointDir = Utils.createTempDir().getAbsolutePath
@@ -210,7 +238,7 @@ class KinesisStreamSuite extends KinesisFunSuite
 
     // Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch
     kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => {
-      val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD]
+      val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
       val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq
       collectedData(time) = (kRdd.arrayOfseqNumberRanges, data)
     })
@@ -243,10 +271,10 @@ class KinesisStreamSuite extends KinesisFunSuite
     times.foreach { time =>
       val (arrayOfSeqNumRanges, data) = collectedData(time)
       val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]]
-      rdd shouldBe a [KinesisBackedBlockRDD]
+      rdd shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
 
       // Verify the recovered sequence ranges
-      val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD]
+      val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
       assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size)
       arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) =>
         assert(expected.ranges.toSeq === found.ranges.toSeq)
-- 
cgit v1.2.3
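
Usage note (editorial sketch, not part of the patch): the example below shows how the new
handler-based Scala overload of KinesisUtils.createStream can be used. The app name, stream
name, endpoint, region, and the (partitionKey, payload) output type are illustrative
placeholders; only the createStream signature itself comes from this change.

import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import com.amazonaws.services.kinesis.model.Record

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils

object KinesisHandlerExample {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext(
      new SparkConf().setAppName("KinesisHandlerExample"), Seconds(2))

    // The handler runs inside the receiver, once per Kinesis Record, and is
    // closure-cleaned by createStream, so it must be serializable. It keeps the
    // partition key next to the decoded payload instead of a bare Array[Byte].
    def keyAndPayload(r: Record): (String, String) = {
      val bytes = new Array[Byte](r.getData.remaining())
      r.getData.get(bytes)
      (r.getPartitionKey, new String(bytes, "UTF-8"))
    }

    // "myKinesisApp", "myStream", and the us-west-2 endpoint are placeholders;
    // credentials are resolved via the DefaultAWSCredentialsProviderChain.
    val stream = KinesisUtils.createStream(
      ssc, "myKinesisApp", "myStream", "https://kinesis.us-west-2.amazonaws.com",
      "us-west-2", InitialPositionInStream.LATEST, Seconds(2),
      StorageLevel.MEMORY_AND_DISK_2, keyAndPayload _)

    stream.map { case (key, msg) => s"$key -> $msg" }.print()
    ssc.start()
    ssc.awaitTermination()
  }
}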