author      Burak Yavuz <brkyvz@gmail.com>                 2015-10-25 21:18:35 -0700
committer   Tathagata Das <tathagata.das1565@gmail.com>    2015-10-25 21:18:35 -0700
commit      63accc79625d8a03d0624717af5e1d81b18a6da3 (patch)
tree        e043db0052621a19a6040f4e813a43875fa4f9f3 /extras
parent      80279ac1875d488f7000f352a958a35536bd4c2e (diff)
[SPARK-10891][STREAMING][KINESIS] Add MessageHandler to KinesisUtils.createStream similar to Direct Kafka
This PR allows users to map a Kinesis `Record` to a generic `T` when creating a Kinesis stream. This is particularly useful if you would like to do extra work with Kinesis metadata such as the sequence number and partition key.
TODO:
- [x] add tests
Author: Burak Yavuz <brkyvz@gmail.com>
Closes #8954 from brkyvz/kinesis-handler.
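To make the new API concrete, here is a minimal Scala sketch of the handler-based overload introduced below; the object, case class, stream name, endpoint, and region are illustrative assumptions, not part of this patch.

    import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
    import com.amazonaws.services.kinesis.model.Record
    import org.apache.spark.storage.StorageLevel
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.kinesis.KinesisUtils

    object KinesisHandlerExample {
      // Hypothetical record type: keep the partition key and sequence number with the payload.
      case class KeyedMessage(partitionKey: String, sequenceNumber: String, payload: Array[Byte])

      // Custom handler: copies the payload and attaches the record's metadata.
      def keyedHandler(record: Record): KeyedMessage = {
        val buf = record.getData()
        val bytes = new Array[Byte](buf.remaining())
        buf.get(bytes)
        KeyedMessage(record.getPartitionKey(), record.getSequenceNumber(), bytes)
      }

      // `ssc` is an existing StreamingContext; names and endpoint are placeholders.
      def keyedKinesisStream(ssc: StreamingContext) =
        KinesisUtils.createStream(
          ssc, "myKinesisApp", "myKinesisStream",
          "https://kinesis.us-west-2.amazonaws.com", "us-west-2",
          InitialPositionInStream.LATEST, Seconds(2),
          StorageLevel.MEMORY_AND_DISK_2, keyedHandler _)
    }

The returned stream is typed as ReceiverInputDStream[KeyedMessage], so downstream transformations can use the metadata directly.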
Diffstat (limited to 'extras')
-rw-r--r--  extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala      |  35
-rw-r--r--  extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala        |  15
-rw-r--r--  extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala            |  18
-rw-r--r--  extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala     |   4
-rw-r--r--  extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala               | 247
-rw-r--r--  extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java       |  29
-rw-r--r--  extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala |  16
-rw-r--r--  extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala       |   4
-rw-r--r--  extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala         |  44
9 files changed, 337 insertions, 75 deletions
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
index 5d32fa699a..000897a4e7 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming.kinesis
import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
import scala.util.control.NonFatal
import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
@@ -67,7 +68,7 @@ class KinesisBackedBlockRDDPartition(
* sequence numbers of the corresponding blocks.
*/
private[kinesis]
-class KinesisBackedBlockRDD(
+class KinesisBackedBlockRDD[T: ClassTag](
@transient sc: SparkContext,
val regionName: String,
val endpointUrl: String,
@@ -75,8 +76,9 @@ class KinesisBackedBlockRDD(
@transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges],
@transient isBlockIdValid: Array[Boolean] = Array.empty,
val retryTimeoutMs: Int = 10000,
+ val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _,
val awsCredentialsOption: Option[SerializableAWSCredentials] = None
- ) extends BlockRDD[Array[Byte]](sc, blockIds) {
+ ) extends BlockRDD[T](sc, blockIds) {
require(blockIds.length == arrayOfseqNumberRanges.length,
"Number of blockIds is not equal to the number of sequence number ranges")
@@ -90,23 +92,23 @@ class KinesisBackedBlockRDD(
}
}
- override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val blockManager = SparkEnv.get.blockManager
val partition = split.asInstanceOf[KinesisBackedBlockRDDPartition]
val blockId = partition.blockId
- def getBlockFromBlockManager(): Option[Iterator[Array[Byte]]] = {
+ def getBlockFromBlockManager(): Option[Iterator[T]] = {
logDebug(s"Read partition data of $this from block manager, block $blockId")
- blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[Array[Byte]]])
+ blockManager.get(blockId).map(_.data.asInstanceOf[Iterator[T]])
}
- def getBlockFromKinesis(): Iterator[Array[Byte]] = {
- val credenentials = awsCredentialsOption.getOrElse {
+ def getBlockFromKinesis(): Iterator[T] = {
+ val credentials = awsCredentialsOption.getOrElse {
new DefaultAWSCredentialsProviderChain().getCredentials()
}
partition.seqNumberRanges.ranges.iterator.flatMap { range =>
- new KinesisSequenceRangeIterator(
- credenentials, endpointUrl, regionName, range, retryTimeoutMs)
+ new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName,
+ range, retryTimeoutMs).map(messageHandler)
}
}
if (partition.isBlockIdValid) {
@@ -129,8 +131,7 @@ class KinesisSequenceRangeIterator(
endpointUrl: String,
regionId: String,
range: SequenceNumberRange,
- retryTimeoutMs: Int
- ) extends NextIterator[Array[Byte]] with Logging {
+ retryTimeoutMs: Int) extends NextIterator[Record] with Logging {
private val client = new AmazonKinesisClient(credentials)
private val streamName = range.streamName
@@ -142,8 +143,8 @@ class KinesisSequenceRangeIterator(
client.setEndpoint(endpointUrl, "kinesis", regionId)
- override protected def getNext(): Array[Byte] = {
- var nextBytes: Array[Byte] = null
+ override protected def getNext(): Record = {
+ var nextRecord: Record = null
if (toSeqNumberReceived) {
finished = true
} else {
@@ -170,10 +171,7 @@ class KinesisSequenceRangeIterator(
} else {
// Get the record, copy the data into a byte array and remember its sequence number
- val nextRecord: Record = internalIterator.next()
- val byteBuffer = nextRecord.getData()
- nextBytes = new Array[Byte](byteBuffer.remaining())
- byteBuffer.get(nextBytes)
+ nextRecord = internalIterator.next()
lastSeqNumber = nextRecord.getSequenceNumber()
// If the this record's sequence number matches the stopping sequence number, then make sure
@@ -182,9 +180,8 @@ class KinesisSequenceRangeIterator(
toSeqNumberReceived = true
}
}
-
}
- nextBytes
+ nextRecord
}
override protected def close(): Unit = {
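In short, the sequence-range iterator now yields raw Kinesis Record objects, and the generic RDD applies the user-supplied handler lazily while the partition is consumed. A simplified sketch of that read path (RecordReadSketch and readRange are assumed names, not members of this class):

    import com.amazonaws.services.kinesis.model.Record

    object RecordReadSketch {
      // The handler runs per record, only as the iterator is consumed.
      def readRange[T](records: Iterator[Record], messageHandler: Record => T): Iterator[T] =
        records.map(messageHandler)
    }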
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
index 2e4204dcb6..72ab6357a5 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala
@@ -17,7 +17,10 @@
package org.apache.spark.streaming.kinesis
+import scala.reflect.ClassTag
+
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{BlockId, StorageLevel}
@@ -26,7 +29,7 @@ import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
import org.apache.spark.streaming.{Duration, StreamingContext, Time}
-private[kinesis] class KinesisInputDStream(
+private[kinesis] class KinesisInputDStream[T: ClassTag](
@transient _ssc: StreamingContext,
streamName: String,
endpointUrl: String,
@@ -35,11 +38,12 @@ private[kinesis] class KinesisInputDStream(
checkpointAppName: String,
checkpointInterval: Duration,
storageLevel: StorageLevel,
+ messageHandler: Record => T,
awsCredentialsOption: Option[SerializableAWSCredentials]
- ) extends ReceiverInputDStream[Array[Byte]](_ssc) {
+ ) extends ReceiverInputDStream[T](_ssc) {
private[streaming]
- override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[Array[Byte]] = {
+ override def createBlockRDD(time: Time, blockInfos: Seq[ReceivedBlockInfo]): RDD[T] = {
// This returns true even for when blockInfos is empty
val allBlocksHaveRanges = blockInfos.map { _.metadataOption }.forall(_.nonEmpty)
@@ -56,6 +60,7 @@ private[kinesis] class KinesisInputDStream(
context.sc, regionName, endpointUrl, blockIds, seqNumRanges,
isBlockIdValid = isBlockIdValid,
retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
+ messageHandler = messageHandler,
awsCredentialsOption = awsCredentialsOption)
} else {
logWarning("Kinesis sequence number information was not present with some block metadata," +
@@ -64,8 +69,8 @@ private[kinesis] class KinesisInputDStream(
}
}
- override def getReceiver(): Receiver[Array[Byte]] = {
+ override def getReceiver(): Receiver[T] = {
new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream,
- checkpointAppName, checkpointInterval, storageLevel, awsCredentialsOption)
+ checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption)
}
}
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 6e0988c1af..134d627cda 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -80,7 +80,7 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
* @param awsCredentialsOption Optional AWS credentials, used when user directly specifies
* the credentials
*/
-private[kinesis] class KinesisReceiver(
+private[kinesis] class KinesisReceiver[T](
val streamName: String,
endpointUrl: String,
regionName: String,
@@ -88,8 +88,9 @@ private[kinesis] class KinesisReceiver(
checkpointAppName: String,
checkpointInterval: Duration,
storageLevel: StorageLevel,
- awsCredentialsOption: Option[SerializableAWSCredentials]
- ) extends Receiver[Array[Byte]](storageLevel) with Logging { receiver =>
+ messageHandler: Record => T,
+ awsCredentialsOption: Option[SerializableAWSCredentials])
+ extends Receiver[T](storageLevel) with Logging { receiver =>
/*
* =================================================================================
@@ -202,12 +203,7 @@ private[kinesis] class KinesisReceiver(
/** Add records of the given shard to the current block being generated */
private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = {
if (records.size > 0) {
- val dataIterator = records.iterator().asScala.map { record =>
- val byteBuffer = record.getData()
- val byteArray = new Array[Byte](byteBuffer.remaining())
- byteBuffer.get(byteArray)
- byteArray
- }
+ val dataIterator = records.iterator().asScala.map(messageHandler)
val metadata = SequenceNumberRange(streamName, shardId,
records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber())
blockGenerator.addMultipleDataWithCallback(dataIterator, metadata)
@@ -240,7 +236,7 @@ private[kinesis] class KinesisReceiver(
/** Store the block along with its associated ranges */
private def storeBlockWithRanges(
- blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[Array[Byte]]): Unit = {
+ blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[T]): Unit = {
val rangesToReportOption = blockIdToSeqNumRanges.remove(blockId)
if (rangesToReportOption.isEmpty) {
stop("Error while storing block into Spark, could not find sequence number ranges " +
@@ -325,7 +321,7 @@ private[kinesis] class KinesisReceiver(
/** Callback method called when a block is ready to be pushed / stored. */
def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = {
storeBlockWithRanges(blockId,
- arrayBuffer.asInstanceOf[mutable.ArrayBuffer[Array[Byte]]])
+ arrayBuffer.asInstanceOf[mutable.ArrayBuffer[T]])
}
/** Callback called in case of any error in internal of the BlockGenerator */
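The receiver-side change follows the same pattern: each batch delivered by the KCL is mapped through the handler before it is handed to the block generator, while the checkpoint metadata is still derived from the raw records. A condensed sketch under those assumptions (handledBatch is an invented helper, not the receiver's actual method):

    import scala.collection.JavaConverters._
    import com.amazonaws.services.kinesis.model.Record

    object ReceiverBatchSketch {
      // The handler output is what gets stored; the sequence-number range comes from the raw records.
      def handledBatch[T](
          records: java.util.List[Record],
          messageHandler: Record => T): (Iterator[T], (String, String)) = {
        val data = records.iterator().asScala.map(messageHandler)
        val range = (records.get(0).getSequenceNumber(),
          records.get(records.size() - 1).getSequenceNumber())
        (data, range)
      }
    }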
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index b240512332..1d5178790e 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -41,8 +41,8 @@ import org.apache.spark.Logging
* @param checkpointState represents the checkpoint state including the next checkpoint time.
* It's injected here for mocking purposes.
*/
-private[kinesis] class KinesisRecordProcessor(
- receiver: KinesisReceiver,
+private[kinesis] class KinesisRecordProcessor[T](
+ receiver: KinesisReceiver[T],
workerId: String,
checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging {
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
index c799fadf2d..2849fd8a82 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala
@@ -16,15 +16,18 @@
*/
package org.apache.spark.streaming.kinesis
+import scala.reflect.ClassTag
+
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
+import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.{Duration, StreamingContext}
-
object KinesisUtils {
/**
* Create an input stream that pulls messages from a Kinesis stream.
@@ -52,6 +55,107 @@ object KinesisUtils {
* details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects.
* StorageLevel.MEMORY_AND_DISK_2 is recommended.
+ * @param messageHandler A custom message handler that can generate a generic output from a
+ * Kinesis `Record`, which contains both message data, and metadata.
+ */
+ def createStream[T: ClassTag](
+ ssc: StreamingContext,
+ kinesisAppName: String,
+ streamName: String,
+ endpointUrl: String,
+ regionName: String,
+ initialPositionInStream: InitialPositionInStream,
+ checkpointInterval: Duration,
+ storageLevel: StorageLevel,
+ messageHandler: Record => T): ReceiverInputDStream[T] = {
+ val cleanedHandler = ssc.sc.clean(messageHandler)
+ // Setting scope to override receiver stream's scope of "receiver stream"
+ ssc.withNamedScope("kinesis stream") {
+ new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
+ initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+ cleanedHandler, None)
+ }
+ }
+
+ /**
+ * Create an input stream that pulls messages from a Kinesis stream.
+ * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+ *
+ * Note:
+ * The given AWS credentials will get saved in DStream checkpoints if checkpointing
+ * is enabled. Make sure that your checkpoint directory is secure.
+ *
+ * @param ssc StreamingContext object
+ * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
+ * (KCL) to update DynamoDB
+ * @param streamName Kinesis stream name
+ * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+ * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
+ * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+ * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
+ * worker's initial starting position in the stream.
+ * The values are either the beginning of the stream
+ * per Kinesis' limit of 24 hours
+ * (InitialPositionInStream.TRIM_HORIZON) or
+ * the tip of the stream (InitialPositionInStream.LATEST).
+ * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
+ * See the Kinesis Spark Streaming documentation for more
+ * details on the different types of checkpoints.
+ * @param storageLevel Storage level to use for storing the received objects.
+ * StorageLevel.MEMORY_AND_DISK_2 is recommended.
+ * @param messageHandler A custom message handler that can generate a generic output from a
+ * Kinesis `Record`, which contains both message data, and metadata.
+ * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
+ * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
+ */
+ // scalastyle:off
+ def createStream[T: ClassTag](
+ ssc: StreamingContext,
+ kinesisAppName: String,
+ streamName: String,
+ endpointUrl: String,
+ regionName: String,
+ initialPositionInStream: InitialPositionInStream,
+ checkpointInterval: Duration,
+ storageLevel: StorageLevel,
+ messageHandler: Record => T,
+ awsAccessKeyId: String,
+ awsSecretKey: String): ReceiverInputDStream[T] = {
+ // scalastyle:on
+ val cleanedHandler = ssc.sc.clean(messageHandler)
+ ssc.withNamedScope("kinesis stream") {
+ new KinesisInputDStream[T](ssc, streamName, endpointUrl, validateRegion(regionName),
+ initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+ cleanedHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
+ }
+ }
+
+ /**
+ * Create an input stream that pulls messages from a Kinesis stream.
+ * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+ *
+ * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
+ * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
+ * gets the AWS credentials.
+ *
+ * @param ssc StreamingContext object
+ * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
+ * (KCL) to update DynamoDB
+ * @param streamName Kinesis stream name
+ * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+ * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
+ * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+ * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
+ * worker's initial starting position in the stream.
+ * The values are either the beginning of the stream
+ * per Kinesis' limit of 24 hours
+ * (InitialPositionInStream.TRIM_HORIZON) or
+ * the tip of the stream (InitialPositionInStream.LATEST).
+ * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
+ * See the Kinesis Spark Streaming documentation for more
+ * details on the different types of checkpoints.
+ * @param storageLevel Storage level to use for storing the received objects.
+ * StorageLevel.MEMORY_AND_DISK_2 is recommended.
*/
def createStream(
ssc: StreamingContext,
@@ -61,12 +165,12 @@ object KinesisUtils {
regionName: String,
initialPositionInStream: InitialPositionInStream,
checkpointInterval: Duration,
- storageLevel: StorageLevel
- ): ReceiverInputDStream[Array[Byte]] = {
+ storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = {
// Setting scope to override receiver stream's scope of "receiver stream"
ssc.withNamedScope("kinesis stream") {
- new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName),
- initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel, None)
+ new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
+ initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
+ defaultMessageHandler, None)
}
}
@@ -109,12 +213,11 @@ object KinesisUtils {
checkpointInterval: Duration,
storageLevel: StorageLevel,
awsAccessKeyId: String,
- awsSecretKey: String
- ): ReceiverInputDStream[Array[Byte]] = {
+ awsSecretKey: String): ReceiverInputDStream[Array[Byte]] = {
ssc.withNamedScope("kinesis stream") {
- new KinesisInputDStream(ssc, streamName, endpointUrl, validateRegion(regionName),
+ new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl, validateRegion(regionName),
initialPositionInStream, kinesisAppName, checkpointInterval, storageLevel,
- Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
+ defaultMessageHandler, Some(SerializableAWSCredentials(awsAccessKeyId, awsSecretKey)))
}
}
@@ -156,8 +259,9 @@ object KinesisUtils {
storageLevel: StorageLevel
): ReceiverInputDStream[Array[Byte]] = {
ssc.withNamedScope("kinesis stream") {
- new KinesisInputDStream(ssc, streamName, endpointUrl, getRegionByEndpoint(endpointUrl),
- initialPositionInStream, ssc.sc.appName, checkpointInterval, storageLevel, None)
+ new KinesisInputDStream[Array[Byte]](ssc, streamName, endpointUrl,
+ getRegionByEndpoint(endpointUrl), initialPositionInStream, ssc.sc.appName,
+ checkpointInterval, storageLevel, defaultMessageHandler, None)
}
}
@@ -187,6 +291,107 @@ object KinesisUtils {
* details on the different types of checkpoints.
* @param storageLevel Storage level to use for storing the received objects.
* StorageLevel.MEMORY_AND_DISK_2 is recommended.
+ * @param messageHandler A custom message handler that can generate a generic output from a
+ * Kinesis `Record`, which contains both message data, and metadata.
+ * @param recordClass Class of the records in DStream
+ */
+ def createStream[T](
+ jssc: JavaStreamingContext,
+ kinesisAppName: String,
+ streamName: String,
+ endpointUrl: String,
+ regionName: String,
+ initialPositionInStream: InitialPositionInStream,
+ checkpointInterval: Duration,
+ storageLevel: StorageLevel,
+ messageHandler: JFunction[Record, T],
+ recordClass: Class[T]): JavaReceiverInputDStream[T] = {
+ implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
+ val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
+ createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+ initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler)
+ }
+
+ /**
+ * Create an input stream that pulls messages from a Kinesis stream.
+ * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+ *
+ * Note:
+ * The given AWS credentials will get saved in DStream checkpoints if checkpointing
+ * is enabled. Make sure that your checkpoint directory is secure.
+ *
+ * @param jssc Java StreamingContext object
+ * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
+ * (KCL) to update DynamoDB
+ * @param streamName Kinesis stream name
+ * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+ * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
+ * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+ * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
+ * worker's initial starting position in the stream.
+ * The values are either the beginning of the stream
+ * per Kinesis' limit of 24 hours
+ * (InitialPositionInStream.TRIM_HORIZON) or
+ * the tip of the stream (InitialPositionInStream.LATEST).
+ * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
+ * See the Kinesis Spark Streaming documentation for more
+ * details on the different types of checkpoints.
+ * @param storageLevel Storage level to use for storing the received objects.
+ * StorageLevel.MEMORY_AND_DISK_2 is recommended.
+ * @param messageHandler A custom message handler that can generate a generic output from a
+ * Kinesis `Record`, which contains both message data, and metadata.
+ * @param recordClass Class of the records in DStream
+ * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain)
+ * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain)
+ */
+ // scalastyle:off
+ def createStream[T](
+ jssc: JavaStreamingContext,
+ kinesisAppName: String,
+ streamName: String,
+ endpointUrl: String,
+ regionName: String,
+ initialPositionInStream: InitialPositionInStream,
+ checkpointInterval: Duration,
+ storageLevel: StorageLevel,
+ messageHandler: JFunction[Record, T],
+ recordClass: Class[T],
+ awsAccessKeyId: String,
+ awsSecretKey: String): JavaReceiverInputDStream[T] = {
+ // scalastyle:on
+ implicit val recordCmt: ClassTag[T] = ClassTag(recordClass)
+ val cleanedHandler = jssc.sparkContext.clean(messageHandler.call(_))
+ createStream[T](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+ initialPositionInStream, checkpointInterval, storageLevel, cleanedHandler,
+ awsAccessKeyId, awsSecretKey)
+ }
+
+ /**
+ * Create an input stream that pulls messages from a Kinesis stream.
+ * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis.
+ *
+ * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain
+ * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain
+ * gets the AWS credentials.
+ *
+ * @param jssc Java StreamingContext object
+ * @param kinesisAppName Kinesis application name used by the Kinesis Client Library
+ * (KCL) to update DynamoDB
+ * @param streamName Kinesis stream name
+ * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com)
+ * @param regionName Name of region used by the Kinesis Client Library (KCL) to update
+ * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics)
+ * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the
+ * worker's initial starting position in the stream.
+ * The values are either the beginning of the stream
+ * per Kinesis' limit of 24 hours
+ * (InitialPositionInStream.TRIM_HORIZON) or
+ * the tip of the stream (InitialPositionInStream.LATEST).
+ * @param checkpointInterval Checkpoint interval for Kinesis checkpointing.
+ * See the Kinesis Spark Streaming documentation for more
+ * details on the different types of checkpoints.
+ * @param storageLevel Storage level to use for storing the received objects.
+ * StorageLevel.MEMORY_AND_DISK_2 is recommended.
*/
def createStream(
jssc: JavaStreamingContext,
@@ -198,8 +403,8 @@ object KinesisUtils {
checkpointInterval: Duration,
storageLevel: StorageLevel
): JavaReceiverInputDStream[Array[Byte]] = {
- createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
- initialPositionInStream, checkpointInterval, storageLevel)
+ createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+ initialPositionInStream, checkpointInterval, storageLevel, defaultMessageHandler(_))
}
/**
@@ -241,10 +446,10 @@ object KinesisUtils {
checkpointInterval: Duration,
storageLevel: StorageLevel,
awsAccessKeyId: String,
- awsSecretKey: String
- ): JavaReceiverInputDStream[Array[Byte]] = {
- createStream(jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
- initialPositionInStream, checkpointInterval, storageLevel, awsAccessKeyId, awsSecretKey)
+ awsSecretKey: String): JavaReceiverInputDStream[Array[Byte]] = {
+ createStream[Array[Byte]](jssc.ssc, kinesisAppName, streamName, endpointUrl, regionName,
+ initialPositionInStream, checkpointInterval, storageLevel,
+ defaultMessageHandler(_), awsAccessKeyId, awsSecretKey)
}
/**
@@ -297,6 +502,14 @@ object KinesisUtils {
throw new IllegalArgumentException(s"Region name '$regionName' is not valid")
}
}
+
+ private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = {
+ if (record == null) return null
+ val byteBuffer = record.getData()
+ val byteArray = new Array[Byte](byteBuffer.remaining())
+ byteBuffer.get(byteArray)
+ byteArray
+ }
}
/**
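The defaultMessageHandler above preserves the previous behavior by copying each record's payload into a byte array; any other Record => T function can be supplied instead. For comparison with the Java handler used in the test below, a Scala handler that keeps only record metadata might look like this (a sketch; the object name and string format are assumptions):

    import com.amazonaws.services.kinesis.model.Record

    object MetadataHandlerSketch {
      // Mirrors the shape of the Java test handler: "partitionKey-sequenceNumber".
      val metadataHandler: Record => String =
        record => record.getPartitionKey + "-" + record.getSequenceNumber
    }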
diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
index 87954a31f6..3f0f6793d2 100644
--- a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
+++ b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java
@@ -17,14 +17,19 @@
package org.apache.spark.streaming.kinesis;
+import com.amazonaws.services.kinesis.model.Record;
+import org.junit.Test;
+
+import org.apache.spark.api.java.function.Function;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.LocalJavaStreamingContext;
import org.apache.spark.streaming.api.java.JavaDStream;
-import org.junit.Test;
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream;
+import java.nio.ByteBuffer;
+
/**
* Demonstrate the use of the KinesisUtils Java API
*/
@@ -33,9 +38,27 @@ public class JavaKinesisStreamSuite extends LocalJavaStreamingContext {
public void testKinesisStream() {
// Tests the API, does not actually test data receiving
JavaDStream<byte[]> kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream",
- "https://kinesis.us-west-2.amazonaws.com", new Duration(2000),
+ "https://kinesis.us-west-2.amazonaws.com", new Duration(2000),
InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2());
-
+
+ ssc.stop();
+ }
+
+
+ private static Function<Record, String> handler = new Function<Record, String>() {
+ @Override
+ public String call(Record record) {
+ return record.getPartitionKey() + "-" + record.getSequenceNumber();
+ }
+ };
+
+ @Test
+ public void testCustomHandler() {
+ // Tests the API, does not actually test data receiving
+ JavaDStream<String> kinesisStream = KinesisUtils.createStream(ssc, "testApp", "mySparkStream",
+ "https://kinesis.us-west-2.amazonaws.com", "us-west-2", InitialPositionInStream.LATEST,
+ new Duration(2000), StorageLevel.MEMORY_AND_DISK_2(), handler, String.class);
+
ssc.stop();
}
}
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
index a89e5627e0..9f9e146a08 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala
@@ -73,22 +73,22 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
testIfEnabled("Basic reading from Kinesis") {
// Verify all data using multiple ranges in a single RDD partition
- val receivedData1 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
- fakeBlockIds(1),
+ val receivedData1 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+ testUtils.endpointUrl, fakeBlockIds(1),
Array(SequenceNumberRanges(allRanges.toArray))
).map { bytes => new String(bytes).toInt }.collect()
assert(receivedData1.toSet === testData.toSet)
// Verify all data using one range in each of the multiple RDD partitions
- val receivedData2 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
- fakeBlockIds(allRanges.size),
+ val receivedData2 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+ testUtils.endpointUrl, fakeBlockIds(allRanges.size),
allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
).map { bytes => new String(bytes).toInt }.collect()
assert(receivedData2.toSet === testData.toSet)
// Verify ordering within each partition
- val receivedData3 = new KinesisBackedBlockRDD(sc, testUtils.regionName, testUtils.endpointUrl,
- fakeBlockIds(allRanges.size),
+ val receivedData3 = new KinesisBackedBlockRDD[Array[Byte]](sc, testUtils.regionName,
+ testUtils.endpointUrl, fakeBlockIds(allRanges.size),
allRanges.map { range => SequenceNumberRanges(Array(range)) }.toArray
).map { bytes => new String(bytes).toInt }.collectPartitions()
assert(receivedData3.length === allRanges.size)
@@ -209,7 +209,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
}, "Incorrect configuration of RDD, unexpected ranges set"
)
- val rdd = new KinesisBackedBlockRDD(
+ val rdd = new KinesisBackedBlockRDD[Array[Byte]](
sc, testUtils.regionName, testUtils.endpointUrl, blockIds, ranges)
val collectedData = rdd.map { bytes =>
new String(bytes).toInt
@@ -223,7 +223,7 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll
if (testIsBlockValid) {
require(numPartitionsInBM === numPartitions, "All partitions must be in BlockManager")
require(numPartitionsInKinesis === 0, "No partitions must be in Kinesis")
- val rdd2 = new KinesisBackedBlockRDD(
+ val rdd2 = new KinesisBackedBlockRDD[Array[Byte]](
sc, testUtils.regionName, testUtils.endpointUrl, blockIds.toArray, ranges,
isBlockIdValid = Array.fill(blockIds.length)(false))
intercept[SparkException] {
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 3d136aec2e..17ab444704 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -52,14 +52,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8)))
val batch = Arrays.asList(record1, record2)
- var receiverMock: KinesisReceiver = _
+ var receiverMock: KinesisReceiver[Array[Byte]] = _
var checkpointerMock: IRecordProcessorCheckpointer = _
var checkpointClockMock: ManualClock = _
var checkpointStateMock: KinesisCheckpointState = _
var currentClockMock: Clock = _
override def beforeFunction(): Unit = {
- receiverMock = mock[KinesisReceiver]
+ receiverMock = mock[KinesisReceiver[Array[Byte]]]
checkpointerMock = mock[IRecordProcessorCheckpointer]
checkpointClockMock = mock[ManualClock]
checkpointStateMock = mock[KinesisCheckpointState]
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 1177dc7581..ba84e557df 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -24,6 +24,7 @@ import scala.util.Random
import com.amazonaws.regions.RegionUtils
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
+import com.amazonaws.services.kinesis.model.Record
import org.scalatest.Matchers._
import org.scalatest.concurrent.Eventually
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
@@ -31,6 +32,7 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming._
+import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.kinesis.KinesisTestUtils._
import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
import org.apache.spark.streaming.scheduler.ReceivedBlockInfo
@@ -113,9 +115,9 @@ class KinesisStreamSuite extends KinesisFunSuite
val inputStream = KinesisUtils.createStream(ssc, appName, "dummyStream",
dummyEndpointUrl, dummyRegionName, InitialPositionInStream.LATEST, Seconds(2),
StorageLevel.MEMORY_AND_DISK_2, dummyAWSAccessKey, dummyAWSSecretKey)
- assert(inputStream.isInstanceOf[KinesisInputDStream])
+ assert(inputStream.isInstanceOf[KinesisInputDStream[Array[Byte]]])
- val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream]
+ val kinesisStream = inputStream.asInstanceOf[KinesisInputDStream[Array[Byte]]]
val time = Time(1000)
// Generate block info data for testing
@@ -134,8 +136,8 @@ class KinesisStreamSuite extends KinesisFunSuite
// Verify that the generated KinesisBackedBlockRDD has the all the right information
val blockInfos = Seq(blockInfo1, blockInfo2)
val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos)
- nonEmptyRDD shouldBe a [KinesisBackedBlockRDD]
- val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD]
+ nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
+ val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
assert(kinesisRDD.regionName === dummyRegionName)
assert(kinesisRDD.endpointUrl === dummyEndpointUrl)
assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds)
@@ -151,7 +153,7 @@ class KinesisStreamSuite extends KinesisFunSuite
// Verify that KinesisBackedBlockRDD is generated even when there are no blocks
val emptyRDD = kinesisStream.createBlockRDD(time, Seq.empty)
- emptyRDD shouldBe a [KinesisBackedBlockRDD]
+ emptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
emptyRDD.partitions shouldBe empty
// Verify that the KinesisBackedBlockRDD has isBlockValid = false when blocks are invalid
@@ -192,6 +194,32 @@ class KinesisStreamSuite extends KinesisFunSuite
ssc.stop(stopSparkContext = false)
}
+ testIfEnabled("custom message handling") {
+ val awsCredentials = KinesisTestUtils.getAWSCredentials()
+ def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5
+ val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
+ testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
+ Seconds(10), StorageLevel.MEMORY_ONLY, addFive,
+ awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+ stream shouldBe a [ReceiverInputDStream[Int]]
+
+ val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int]
+ stream.foreachRDD { rdd =>
+ collected ++= rdd.collect()
+ logInfo("Collected = " + rdd.collect().toSeq.mkString(", "))
+ }
+ ssc.start()
+
+ val testData = 1 to 10
+ eventually(timeout(120 seconds), interval(10 second)) {
+ testUtils.pushData(testData)
+ val modData = testData.map(_ + 5)
+ assert(collected === modData.toSet, "\nData received does not match data sent")
+ }
+ ssc.stop(stopSparkContext = false)
+ }
+
testIfEnabled("failure recovery") {
val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
val checkpointDir = Utils.createTempDir().getAbsolutePath
@@ -210,7 +238,7 @@ class KinesisStreamSuite extends KinesisFunSuite
// Verify that the generated RDDs are KinesisBackedBlockRDDs, and collect the data in each batch
kinesisStream.foreachRDD((rdd: RDD[Array[Byte]], time: Time) => {
- val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD]
+ val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
val data = rdd.map { bytes => new String(bytes).toInt }.collect().toSeq
collectedData(time) = (kRdd.arrayOfseqNumberRanges, data)
})
@@ -243,10 +271,10 @@ class KinesisStreamSuite extends KinesisFunSuite
times.foreach { time =>
val (arrayOfSeqNumRanges, data) = collectedData(time)
val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]]
- rdd shouldBe a [KinesisBackedBlockRDD]
+ rdd shouldBe a [KinesisBackedBlockRDD[Array[Byte]]]
// Verify the recovered sequence ranges
- val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD]
+ val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]]
assert(kRdd.arrayOfseqNumberRanges.size === arrayOfSeqNumRanges.size)
arrayOfSeqNumRanges.zip(kRdd.arrayOfseqNumberRanges).foreach { case (expected, found) =>
assert(expected.ranges.toSeq === found.ranges.toSeq)