author    Takeshi YAMAMURO <linguin.m.s@gmail.com>  2017-01-25 17:38:48 -0800
committer Burak Yavuz <brkyvz@gmail.com>            2017-01-25 17:38:48 -0800
commit    256a3a801366ab9f705e50690114e49fdb49b38e (patch)
tree      d1d6eacbc69e23c8cfe27957172e38dba983939f /external/kinesis-asl
parent    2338451266d37b4c952827325cdee53b3e8fbc78 (diff)
[SPARK-18020][STREAMING][KINESIS] Checkpoint SHARD_END to finish reading closed shards
## What changes were proposed in this pull request?

This PR fixes an issue that occurs when resharding Kinesis streams: resharding makes the KCL throw an exception because Spark does not checkpoint `SHARD_END` when it finishes reading closed shards in `KinesisRecordProcessor#shutdown`. This bug ultimately stops the receiver from subscribing to new split (or merged) shards.

## How was this patch tested?

Added a test to `KinesisStreamSuite` that checks reading continues correctly when shards are split/merged.

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #16213 from maropu/SPARK-18020.
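For background on the fix: the KCL signals a closed shard by shutting down its record processor with reason `TERMINATE`, and only a parameterless `checkpoint()` call at that point records `SHARD_END` so the child shards can be consumed. A minimal sketch of that contract (independent of this patch; the method name `onShutdown` is illustrative):

```scala
import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason

// Illustrative sketch of the KCL shutdown contract this patch relies on.
def onShutdown(checkpointer: IRecordProcessorCheckpointer, reason: ShutdownReason): Unit = {
  reason match {
    case ShutdownReason.TERMINATE =>
      // The shard was closed by a split/merge. Checkpointing with no
      // sequence number records SHARD_END, letting the KCL move on to
      // the child shards; skipping this makes the KCL throw.
      checkpointer.checkpoint()
    case _ =>
      // ZOMBIE: the lease moved to another worker; do not checkpoint.
  }
}
```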
Diffstat (limited to 'external/kinesis-asl')
-rw-r--r-- external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala      | 15
-rw-r--r-- external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala         | 30
-rw-r--r-- external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala |  3
-rw-r--r-- external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala |  5
-rw-r--r-- external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala       | 70
5 files changed, 116 insertions(+), 7 deletions(-)
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
index 3e697f36a4..c445c15a5f 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
@@ -64,7 +64,20 @@ private[kinesis] class KinesisCheckpointer(
def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = {
synchronized {
checkpointers.remove(shardId)
- checkpoint(shardId, checkpointer)
+ }
+ if (checkpointer != null) {
+ try {
+ // We must call `checkpoint()` with no parameter to finish reading shards.
+ // See the URL below for details:
+ // https://forums.aws.amazon.com/thread.jspa?threadID=244218
+ KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)
+ } catch {
+ case NonFatal(e) =>
+ logError(s"Exception: WorkerId $workerId encountered an exception while checkpointing" +
+ s" to finish reading a shard of $shardId.", e)
+ // Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor
+ throw e
+ }
}
}
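The `KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100)` call above retries the checkpoint with random backoff. A simplified sketch of what such a helper does (the real implementation lives in `KinesisRecordProcessor`; this version is illustrative only):

```scala
import scala.util.Random
import scala.util.control.NonFatal

// Illustrative retry-with-random-backoff: evaluate `expression`; on failure,
// sleep a random interval of up to `maxBackOffMillis` and retry, allowing
// `numRetriesLeft` attempts in total before the last failure propagates.
def retryRandom[T](expression: => T, numRetriesLeft: Int, maxBackOffMillis: Int): T = {
  try {
    expression
  } catch {
    case NonFatal(_) if numRetriesLeft > 1 =>
      Thread.sleep(Random.nextInt(maxBackOffMillis).toLong)
      retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis)
  }
}
```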
diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
index 0fe66254e9..f183ef00b3 100644
--- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
+++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala
@@ -40,11 +40,10 @@ import org.apache.spark.internal.Logging
*
* PLEASE KEEP THIS FILE UNDER src/main AS PYTHON TESTS NEED ACCESS TO THIS FILE!
*/
-private[kinesis] class KinesisTestUtils extends Logging {
+private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging {
val endpointUrl = KinesisTestUtils.endpointUrl
val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
- val streamShardCount = 2
private val createStreamTimeoutSeconds = 300
private val describeStreamPollTimeSeconds = 1
@@ -88,7 +87,7 @@ private[kinesis] class KinesisTestUtils extends Logging {
logInfo(s"Creating stream ${_streamName}")
val createStreamRequest = new CreateStreamRequest()
createStreamRequest.setStreamName(_streamName)
- createStreamRequest.setShardCount(2)
+ createStreamRequest.setShardCount(streamShardCount)
kinesisClient.createStream(createStreamRequest)
// The stream is now being created. Wait for it to become active.
@@ -97,6 +96,31 @@ private[kinesis] class KinesisTestUtils extends Logging {
logInfo(s"Created stream ${_streamName}")
}
+ def getShards(): Seq[Shard] = {
+ kinesisClient.describeStream(_streamName).getStreamDescription.getShards.asScala
+ }
+
+ def splitShard(shardId: String): Unit = {
+ val splitShardRequest = new SplitShardRequest()
+ splitShardRequest.withStreamName(_streamName)
+ splitShardRequest.withShardToSplit(shardId)
+ // Split at the midpoint of the hash key range (half of the max hash value, 2^127)
+ splitShardRequest.withNewStartingHashKey("170141183460469231731687303715884105728")
+ kinesisClient.splitShard(splitShardRequest)
+ // Wait for the shards to become active
+ waitForStreamToBeActive(_streamName)
+ }
+
+ def mergeShard(shardToMerge: String, adjacentShardToMerge: String): Unit = {
+ val mergeShardRequest = new MergeShardsRequest
+ mergeShardRequest.withStreamName(_streamName)
+ mergeShardRequest.withShardToMerge(shardToMerge)
+ mergeShardRequest.withAdjacentShardToMerge(adjacentShardToMerge)
+ kinesisClient.mergeShards(mergeShardRequest)
+ // Wait for the shards to become active
+ waitForStreamToBeActive(_streamName)
+ }
+
/**
* Push data to Kinesis stream and return a map of
* shardId -> seq of (data, seq number) pushed to corresponding shard
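As a sanity check on the hard-coded starting hash key in `splitShard` above: Kinesis distributes records over the 128-bit hash key range [0, 2^128 - 1], so an even split of a full-range shard starts the new shard at the midpoint, 2^127. A quick verification (illustrative, not part of the patch):

```scala
// The hard-coded split point equals 2^127, the midpoint of the
// 128-bit Kinesis hash key space [0, 2^128 - 1].
val midpoint = BigInt(2).pow(127)
assert(midpoint.toString == "170141183460469231731687303715884105728")
```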
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
index 0b455e574e..2ee3224b3c 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KPLBasedKinesisTestUtils.scala
@@ -25,7 +25,8 @@ import scala.collection.mutable.ArrayBuffer
import com.amazonaws.services.kinesis.producer.{KinesisProducer => KPLProducer, KinesisProducerConfiguration, UserRecordResult}
import com.google.common.util.concurrent.{FutureCallback, Futures}
-private[kinesis] class KPLBasedKinesisTestUtils extends KinesisTestUtils {
+private[kinesis] class KPLBasedKinesisTestUtils(streamShardCount: Int = 2)
+ extends KinesisTestUtils(streamShardCount) {
override protected def getProducer(aggregate: Boolean): KinesisDataGenerator = {
if (!aggregate) {
new SimpleDataGenerator(kinesisClient)
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
index bcaed628a8..fef24ed4c5 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
@@ -118,7 +118,7 @@ class KinesisCheckpointerSuite extends TestSuiteBase
when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock)
- verify(checkpointerMock, times(1)).checkpoint(anyString())
+ verify(checkpointerMock, times(1)).checkpoint()
}
test("if checkpointing is going on, wait until finished before removing and checkpointing") {
@@ -145,7 +145,8 @@ class KinesisCheckpointerSuite extends TestSuiteBase
clock.advance(checkpointInterval.milliseconds / 2)
eventually(timeout(1 second)) {
- verify(checkpointerMock, times(2)).checkpoint(anyString())
+ verify(checkpointerMock, times(1)).checkpoint(anyString)
+ verify(checkpointerMock, times(1)).checkpoint()
}
}
}
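The verification change above distinguishes the periodic, sequence-numbered checkpoint from the new parameterless one that marks `SHARD_END`. A condensed, illustrative sketch of that Mockito pattern (Mockito 1.x matchers assumed; the real suite also wires a manual clock and the checkpointer thread):

```scala
import org.mockito.Matchers.anyString
import org.mockito.Mockito.{mock, times, verify}

import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer

val checkpointerMock = mock(classOf[IRecordProcessorCheckpointer])
// ... exercise the code under test: one periodic checkpoint, then
// KinesisCheckpointer#removeCheckpointer(shardId, checkpointerMock) ...
verify(checkpointerMock, times(1)).checkpoint(anyString)  // periodic, at a sequence number
verify(checkpointerMock, times(1)).checkpoint()           // parameterless, marks SHARD_END
```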
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
index 0e71bf9b84..404b673c01 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala
@@ -225,6 +225,76 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun
ssc.stop(stopSparkContext = false)
}
+ testIfEnabled("split and merge shards in a stream") {
+ // Since this test tries to split and merge shards in a stream, we create another
+ // temporary stream and then remove it when finished.
+ val localAppName = s"KinesisStreamSuite-${math.abs(Random.nextLong())}"
+ val localTestUtils = new KPLBasedKinesisTestUtils(1)
+ localTestUtils.createStream()
+ try {
+ val awsCredentials = KinesisTestUtils.getAWSCredentials()
+ val stream = KinesisUtils.createStream(ssc, localAppName, localTestUtils.streamName,
+ localTestUtils.endpointUrl, localTestUtils.regionName, InitialPositionInStream.LATEST,
+ Seconds(10), StorageLevel.MEMORY_ONLY,
+ awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey)
+
+ val collected = new mutable.HashSet[Int]
+ stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd =>
+ collected.synchronized {
+ collected ++= rdd.collect()
+ logInfo("Collected = " + collected.mkString(", "))
+ }
+ }
+ ssc.start()
+
+ val testData1 = 1 to 10
+ val testData2 = 11 to 20
+ val testData3 = 21 to 30
+
+ eventually(timeout(60 seconds), interval(10 second)) {
+ localTestUtils.pushData(testData1, aggregateTestData)
+ assert(collected.synchronized { collected === testData1.toSet },
+ "\nData received does not match data sent")
+ }
+
+ val shardToSplit = localTestUtils.getShards().head
+ localTestUtils.splitShard(shardToSplit.getShardId)
+ val (splitOpenShards, splitCloseShards) = localTestUtils.getShards().partition { shard =>
+ shard.getSequenceNumberRange.getEndingSequenceNumber == null
+ }
+
+ // We should have one closed shard and two open shards
+ assert(splitCloseShards.size == 1)
+ assert(splitOpenShards.size == 2)
+
+ eventually(timeout(60 seconds), interval(10 second)) {
+ localTestUtils.pushData(testData2, aggregateTestData)
+ assert(collected.synchronized { collected === (testData1 ++ testData2).toSet },
+ "\nData received does not match data sent after splitting a shard")
+ }
+
+ val Seq(shardToMerge, adjShard) = splitOpenShards
+ localTestUtils.mergeShard(shardToMerge.getShardId, adjShard.getShardId)
+ val (mergedOpenShards, mergedCloseShards) = localTestUtils.getShards().partition { shard =>
+ shard.getSequenceNumberRange.getEndingSequenceNumber == null
+ }
+
+ // We should have three closed shards and one open shard
+ assert(mergedCloseShards.size == 3)
+ assert(mergedOpenShards.size == 1)
+
+ eventually(timeout(60 seconds), interval(10 second)) {
+ localTestUtils.pushData(testData3, aggregateTestData)
+ assert(collected.synchronized { collected === (testData1 ++ testData2 ++ testData3).toSet },
+ "\nData received does not match data sent after merging shards")
+ }
+ } finally {
+ ssc.stop(stopSparkContext = false)
+ localTestUtils.deleteStream()
+ localTestUtils.deleteDynamoDBTable(localAppName)
+ }
+ }
+
testIfEnabled("failure recovery") {
val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
val checkpointDir = Utils.createTempDir().getAbsolutePath
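A detail worth noting from the split/merge test above: a Kinesis shard is open exactly when its sequence number range has no ending sequence number, which is what both `partition` calls check. As a standalone illustration using the AWS SDK's `Shard` model:

```scala
import com.amazonaws.services.kinesis.model.Shard

// A shard is open iff its ending sequence number is unset; shards closed
// by a split or merge carry an ending sequence number.
def isOpen(shard: Shard): Boolean =
  shard.getSequenceNumberRange.getEndingSequenceNumber == null
```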