diff options
Diffstat (limited to 'external/kinesis-asl/src/main')
2 files changed, 41 insertions, 4 deletions
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala index 3e697f36a4..c445c15a5f 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala @@ -64,7 +64,20 @@ private[kinesis] class KinesisCheckpointer( def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { synchronized { checkpointers.remove(shardId) - checkpoint(shardId, checkpointer) + } + if (checkpointer != null) { + try { + // We must call `checkpoint()` with no parameter to finish reading shards. + // See an URL below for details: + // https://forums.aws.amazon.com/thread.jspa?threadID=244218 + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + } catch { + case NonFatal(e) => + logError(s"Exception: WorkerId $workerId encountered an exception while checkpointing" + + s"to finish reading a shard of $shardId.", e) + // Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor + throw e + } } } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 0fe66254e9..f183ef00b3 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -40,11 +40,10 @@ import org.apache.spark.internal.Logging * * PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE! 
*/ -private[kinesis] class KinesisTestUtils extends Logging { +private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging { val endpointUrl = KinesisTestUtils.endpointUrl val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName() - val streamShardCount = 2 private val createStreamTimeoutSeconds = 300 private val describeStreamPollTimeSeconds = 1 @@ -88,7 +87,7 @@ private[kinesis] class KinesisTestUtils extends Logging { logInfo(s"Creating stream ${_streamName}") val createStreamRequest = new CreateStreamRequest() createStreamRequest.setStreamName(_streamName) - createStreamRequest.setShardCount(2) + createStreamRequest.setShardCount(streamShardCount) kinesisClient.createStream(createStreamRequest) // The stream is now being created. Wait for it to become active. @@ -97,6 +96,31 @@ private[kinesis] class KinesisTestUtils extends Logging { logInfo(s"Created stream ${_streamName}") } + def getShards(): Seq[Shard] = { + kinesisClient.describeStream(_streamName).getStreamDescription.getShards.asScala + } + + def splitShard(shardId: String): Unit = { + val splitShardRequest = new SplitShardRequest() + splitShardRequest.withStreamName(_streamName) + splitShardRequest.withShardToSplit(shardId) + // Set a half of the max hash value + splitShardRequest.withNewStartingHashKey("170141183460469231731687303715884105728") + kinesisClient.splitShard(splitShardRequest) + // Wait for the shards to become active + waitForStreamToBeActive(_streamName) + } + + def mergeShard(shardToMerge: String, adjacentShardToMerge: String): Unit = { + val mergeShardRequest = new MergeShardsRequest + mergeShardRequest.withStreamName(_streamName) + mergeShardRequest.withShardToMerge(shardToMerge) + mergeShardRequest.withAdjacentShardToMerge(adjacentShardToMerge) + kinesisClient.mergeShards(mergeShardRequest) + // Wait for the shards to become active + waitForStreamToBeActive(_streamName) + } + /** * Push data to Kinesis stream and return a map of * shardId -> seq 
of (data, seq number) pushed to corresponding shard