author    Burak Yavuz <brkyvz@gmail.com>    2015-11-09 14:39:18 -0800
committer Tathagata Das <tathagata.das1565@gmail.com>    2015-11-09 14:39:18 -0800
commit    a3a7c9103e136035d65a5564f9eb0fa04727c4f3 (patch)
tree      afa494adb7a10cbf78bbc82b502a33cffb3de9aa /extras/kinesis-asl
parent    150f6a89b79f0e5bc31aa83731429dc7ac5ea76b (diff)
[SPARK-11359][STREAMING][KINESIS] Checkpoint to DynamoDB even when new data doesn't come in
Currently, checkpoints to DynamoDB occur only when new data comes in, because that is when we update the clock for the checkpointState. This PR makes the checkpoint a scheduled execution based on the `checkpointInterval`.

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #9421 from brkyvz/kinesis-checkpoint.
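The gist of the change, seen in isolation: checkpointing moves from a data-driven check to a fixed schedule. Below is a minimal, hypothetical sketch of that idea using a plain ScheduledExecutorService; the patch itself uses Spark's internal RecurringTimer, shown in the diff that follows.

import java.util.concurrent.{Executors, TimeUnit}

object ScheduledCheckpointSketch {
  def main(args: Array[String]): Unit = {
    val checkpointIntervalMillis = 100L
    val scheduler = Executors.newSingleThreadScheduledExecutor()

    // Fires every interval, whether or not new records arrived; that is the
    // behavioral change this patch makes to Kinesis checkpointing.
    scheduler.scheduleAtFixedRate(new Runnable {
      def run(): Unit = println(s"checkpoint tick at ${System.currentTimeMillis()}")
    }, checkpointIntervalMillis, checkpointIntervalMillis, TimeUnit.MILLISECONDS)

    Thread.sleep(550) // let a few ticks fire, then shut down
    scheduler.shutdown()
  }
}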
Diffstat (limited to 'extras/kinesis-asl')
-rw-r--r--  extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala   |  54
-rw-r--r--  extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala      | 133
-rw-r--r--  extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala          |  38
-rw-r--r--  extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala   |  59
-rw-r--r--  extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala | 152
-rw-r--r--  extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala     |  96
6 files changed, 349 insertions(+), 183 deletions(-)
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
deleted file mode 100644
index 83a4537559..0000000000
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.streaming.kinesis
-
-import org.apache.spark.Logging
-import org.apache.spark.streaming.Duration
-import org.apache.spark.util.{Clock, ManualClock, SystemClock}
-
-/**
- * This is a helper class for managing checkpoint clocks.
- *
- * @param checkpointInterval
- * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes)
- */
-private[kinesis] class KinesisCheckpointState(
- checkpointInterval: Duration,
- currentClock: Clock = new SystemClock())
- extends Logging {
-
- /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */
- val checkpointClock = new ManualClock()
- checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds)
-
- /**
- * Check if it's time to checkpoint based on the current time and the derived time
- * for the next checkpoint
- *
- * @return true if it's time to checkpoint
- */
- def shouldCheckpoint(): Boolean = {
- new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis()
- }
-
- /**
- * Advance the checkpoint clock by the checkpoint interval.
- */
- def advanceCheckpoint(): Unit = {
- checkpointClock.advance(checkpointInterval.milliseconds)
- }
-}
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
new file mode 100644
index 0000000000..1ca6d4302c
--- /dev/null
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.streaming.kinesis
+
+import java.util.concurrent._
+
+import scala.util.control.NonFatal
+
+import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
+import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
+
+import org.apache.spark.Logging
+import org.apache.spark.streaming.Duration
+import org.apache.spark.streaming.util.RecurringTimer
+import org.apache.spark.util.{Clock, SystemClock, ThreadUtils}
+
+/**
+ * This is a helper class for managing Kinesis checkpointing.
+ *
+ * @param receiver The receiver that keeps track of which sequence numbers we can checkpoint
+ * @param checkpointInterval How frequently we will checkpoint to DynamoDB
+ * @param workerId Worker Id of KCL worker for logging purposes
+ * @param clock In order to use ManualClocks for the purpose of testing
+ */
+private[kinesis] class KinesisCheckpointer(
+ receiver: KinesisReceiver[_],
+ checkpointInterval: Duration,
+ workerId: String,
+ clock: Clock = new SystemClock) extends Logging {
+
+ // a map from shardId's to checkpointers
+ private val checkpointers = new ConcurrentHashMap[String, IRecordProcessorCheckpointer]()
+
+ private val lastCheckpointedSeqNums = new ConcurrentHashMap[String, String]()
+
+ private val checkpointerThread: RecurringTimer = startCheckpointerThread()
+
+ /** Update the checkpointer instance to the most recent one for the given shardId. */
+ def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = {
+ checkpointers.put(shardId, checkpointer)
+ }
+
+ /**
+ * Stop tracking the specified shardId.
+ *
+ * If a checkpointer is provided, e.g. on IRecordProcessor.shutdown [[ShutdownReason.TERMINATE]],
+ * we will use that to make the final checkpoint. If `null` is provided, we will not make the
+ * checkpoint, e.g. in case of [[ShutdownReason.ZOMBIE]].
+ */
+ def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = {
+ synchronized {
+ checkpointers.remove(shardId)
+ checkpoint(shardId, checkpointer)
+ }
+ }
+
+ /** Perform the checkpoint. */
+ private def checkpoint(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = {
+ try {
+ if (checkpointer != null) {
+ receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum =>
+ val lastSeqNum = lastCheckpointedSeqNums.get(shardId)
+ // Kinesis sequence numbers are monotonically increasing strings, therefore we can
+ // safely do the string comparison
+ if (lastSeqNum == null || latestSeqNum > lastSeqNum) {
+ /* Perform the checkpoint */
+ KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100)
+ logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint at sequence number" +
+ s" $latestSeqNum for shardId $shardId")
+ lastCheckpointedSeqNums.put(shardId, latestSeqNum)
+ }
+ }
+ } else {
+ logDebug(s"Checkpointing skipped for shardId $shardId. Checkpointer not set.")
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(s"Failed to checkpoint shardId $shardId to DynamoDB.", e)
+ }
+ }
+
+ /** Checkpoint the latest saved sequence numbers for all active shardId's. */
+ private def checkpointAll(): Unit = synchronized {
+ // if this method throws an exception, then the scheduled task will not run again
+ try {
+ val shardIds = checkpointers.keys()
+ while (shardIds.hasMoreElements) {
+ val shardId = shardIds.nextElement()
+ checkpoint(shardId, checkpointers.get(shardId))
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Failed to checkpoint to DynamoDB.", e)
+ }
+ }
+
+ /**
+ * Start the checkpointer thread with the given checkpoint duration.
+ */
+ private def startCheckpointerThread(): RecurringTimer = {
+ val period = checkpointInterval.milliseconds
+ val threadName = s"Kinesis Checkpointer - Worker $workerId"
+ val timer = new RecurringTimer(clock, period, _ => checkpointAll(), threadName)
+ timer.start()
+ logDebug(s"Started checkpointer thread: $threadName")
+ timer
+ }
+
+ /**
+ * Shut down the checkpointer. Should be called from the Receiver's onStop().
+ */
+ def shutdown(): Unit = {
+ // the recurring timer checkpoints for us one last time.
+ checkpointerThread.stop(interruptTimer = false)
+ checkpointers.clear()
+ lastCheckpointedSeqNums.clear()
+ logInfo("Successfully shutdown Kinesis Checkpointer.")
+ }
+}
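A self-contained sketch of the dedup logic at the heart of KinesisCheckpointer above: a shard is checkpointed only when its latest sequence number has advanced past the last one checkpointed. The object and helper names here are illustrative, not part of the patch.

import java.util.concurrent.ConcurrentHashMap

object SeqNumDedupSketch {
  private val lastCheckpointedSeqNums = new ConcurrentHashMap[String, String]()

  // Mirrors the `lastSeqNum == null || latestSeqNum > lastSeqNum` guard above;
  // Kinesis sequence numbers are monotonically increasing strings, so the
  // lexicographic comparison stands in for a numeric one.
  def maybeCheckpoint(shardId: String, latestSeqNum: String)(doCheckpoint: String => Unit): Unit = {
    val lastSeqNum = lastCheckpointedSeqNums.get(shardId)
    if (lastSeqNum == null || latestSeqNum > lastSeqNum) {
      doCheckpoint(latestSeqNum)
      lastCheckpointedSeqNums.put(shardId, latestSeqNum)
    }
  }

  def main(args: Array[String]): Unit = {
    maybeCheckpoint("shard-0", "123")(s => println(s"checkpointed at $s")) // fires
    maybeCheckpoint("shard-0", "123")(s => println(s"checkpointed at $s")) // skipped: unchanged
    maybeCheckpoint("shard-0", "245")(s => println(s"checkpointed at $s")) // fires again
  }
}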
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 134d627cda..50993f157c 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
import scala.util.control.NonFatal
import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain}
-import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory}
+import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessorCheckpointer, IRecordProcessor, IRecordProcessorFactory}
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker}
import com.amazonaws.services.kinesis.model.Record
@@ -31,8 +31,7 @@ import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver}
import org.apache.spark.util.Utils
-import org.apache.spark.{Logging, SparkEnv}
-
+import org.apache.spark.Logging
private[kinesis]
case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
@@ -128,6 +127,11 @@ private[kinesis] class KinesisReceiver[T](
with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges]
/**
+ * The centralized kinesisCheckpointer that checkpoints based on the given checkpointInterval.
+ */
+ @volatile private var kinesisCheckpointer: KinesisCheckpointer = null
+
+ /**
* Latest sequence number ranges that have been stored successfully.
* This is used for checkpointing through KCL */
private val shardIdToLatestStoredSeqNum = new mutable.HashMap[String, String]
@@ -141,6 +145,7 @@ private[kinesis] class KinesisReceiver[T](
workerId = Utils.localHostName() + ":" + UUID.randomUUID()
+ kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId)
// KCL config instance
val awsCredProvider = resolveAWSCredentialsProvider()
val kinesisClientLibConfiguration =
@@ -157,8 +162,8 @@ private[kinesis] class KinesisReceiver[T](
* We're using our custom KinesisRecordProcessor in this case.
*/
val recordProcessorFactory = new IRecordProcessorFactory {
- override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver,
- workerId, new KinesisCheckpointState(checkpointInterval))
+ override def createProcessor: IRecordProcessor =
+ new KinesisRecordProcessor(receiver, workerId)
}
worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration)
@@ -198,6 +203,10 @@ private[kinesis] class KinesisReceiver[T](
logInfo(s"Stopped receiver for workerId $workerId")
}
workerId = null
+ if (kinesisCheckpointer != null) {
+ kinesisCheckpointer.shutdown()
+ kinesisCheckpointer = null
+ }
}
/** Add records of the given shard to the current block being generated */
@@ -217,6 +226,25 @@ private[kinesis] class KinesisReceiver[T](
}
/**
+ * Set the checkpointer that will be used to checkpoint sequence numbers to DynamoDB for the
+ * given shardId.
+ */
+ def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = {
+ assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!")
+ kinesisCheckpointer.setCheckpointer(shardId, checkpointer)
+ }
+
+ /**
+ * Remove the checkpointer for the given shardId. The provided checkpointer will be used to
+ * checkpoint one last time for the given shard. If `checkpointer` is `null`, then we will not
+ * checkpoint.
+ */
+ def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = {
+ assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!")
+ kinesisCheckpointer.removeCheckpointer(shardId, checkpointer)
+ }
+
+ /**
* Remember the range of sequence numbers that was added to the currently active block.
* Internally, this is synchronized with `finalizeRangesForCurrentBlock()`.
*/
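The receiver changes above follow a common guarded-lifecycle pattern: a @volatile reference created in onStart, asserted before use, and null-checked on teardown. A simplified, self-contained sketch of that pattern (the class and member names are illustrative, not Spark's):

class GuardedLifecycleSketch {
  @volatile private var checkpointer: AutoCloseable = null

  def onStart(): Unit = {
    checkpointer = new AutoCloseable {
      override def close(): Unit = println("checkpointer shut down")
    }
  }

  def delegate(): Unit = {
    // Fail fast if used before onStart, as the receiver's assert does.
    assert(checkpointer != null, "checkpointer not initialized!")
  }

  def onStop(): Unit = {
    // Null guard: onStop may run even if onStart never completed.
    if (checkpointer != null) {
      checkpointer.close()
      checkpointer = null
    }
  }
}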
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index 1d5178790e..e381ffa0cb 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -27,26 +27,23 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
import com.amazonaws.services.kinesis.model.Record
import org.apache.spark.Logging
+import org.apache.spark.streaming.Duration
/**
* Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor.
* This implementation operates on the Array[Byte] from the KinesisReceiver.
* The Kinesis Worker creates an instance of this KinesisRecordProcessor for each
- * shard in the Kinesis stream upon startup. This is normally done in separate threads,
- * but the KCLs within the KinesisReceivers will balance themselves out if you create
- * multiple Receivers.
+ * shard in the Kinesis stream upon startup. This is normally done in separate threads,
+ * but the KCLs within the KinesisReceivers will balance themselves out if you create
+ * multiple Receivers.
*
* @param receiver Kinesis receiver
* @param workerId for logging purposes
- * @param checkpointState represents the checkpoint state including the next checkpoint time.
- * It's injected here for mocking purposes.
*/
-private[kinesis] class KinesisRecordProcessor[T](
- receiver: KinesisReceiver[T],
- workerId: String,
- checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging {
+private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], workerId: String)
+ extends IRecordProcessor with Logging {
- // shardId to be populated during initialize()
+ // shardId populated during initialize()
@volatile
private var shardId: String = _
@@ -74,34 +71,7 @@ private[kinesis] class KinesisRecordProcessor[T](
try {
receiver.addRecords(shardId, batch)
logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId")
-
- /*
- *
- * Checkpoint the sequence number of the last record successfully stored.
- * Note that in this current implementation, the checkpointing occurs only when after
- * checkpointIntervalMillis from the last checkpoint, AND when there is new record
- * to process. This leads to the checkpointing lagging behind what records have been
- * stored by the receiver. Ofcourse, this can lead records processed more than once,
- * under failures and restarts.
- *
- * TODO: Instead of checkpointing here, run a separate timer task to perform
- * checkpointing so that it checkpoints in a timely manner independent of whether
- * new records are available or not.
- */
- if (checkpointState.shouldCheckpoint()) {
- receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum =>
- /* Perform the checkpoint */
- KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100)
-
- /* Update the next checkpoint time */
- checkpointState.advanceCheckpoint()
-
- logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" +
- s" records for shardId $shardId")
- logDebug(s"Checkpoint: Next checkpoint is at " +
- s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId")
- }
- }
+ receiver.setCheckpointer(shardId, checkpointer)
} catch {
case NonFatal(e) => {
/*
@@ -142,23 +112,18 @@ private[kinesis] class KinesisRecordProcessor[T](
* It's now OK to read from the new shards that resulted from a resharding event.
*/
case ShutdownReason.TERMINATE =>
- val latestSeqNumToCheckpointOption = receiver.getLatestSeqNumToCheckpoint(shardId)
- if (latestSeqNumToCheckpointOption.nonEmpty) {
- KinesisRecordProcessor.retryRandom(
- checkpointer.checkpoint(latestSeqNumToCheckpointOption.get), 4, 100)
- }
+ receiver.removeCheckpointer(shardId, checkpointer)
/*
- * ZOMBIE Use Case. NoOp.
+ * ZOMBIE Use Case or Unknown reason. NoOp.
* No checkpoint because other workers may have taken over and already started processing
* the same records.
* This may lead to records being processed more than once.
*/
- case ShutdownReason.ZOMBIE =>
-
- /* Unknown reason. NoOp */
case _ =>
+ receiver.removeCheckpointer(shardId, null) // pass null so that we don't checkpoint
}
+
}
}
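A self-contained analogue of the shutdown dispatch above: TERMINATE hands the checkpointer over for one final checkpoint, while ZOMBIE (or any unknown reason) passes null so no checkpoint is attempted. The enum and parameter types are simplified stand-ins for the KCL's, not the real interfaces.

object ShutdownDispatchSketch {
  sealed trait Reason
  case object Terminate extends Reason
  case object Zombie extends Reason

  def removeCheckpointer(shardId: String, checkpointer: AnyRef): Unit =
    if (checkpointer != null) println(s"final checkpoint for $shardId")
    else println(s"no final checkpoint for $shardId")

  def shutdown(shardId: String, checkpointer: AnyRef, reason: Reason): Unit = reason match {
    case Terminate => removeCheckpointer(shardId, checkpointer) // resharding: safe to checkpoint
    case _         => removeCheckpointer(shardId, null)         // zombie/unknown: another worker may own the shard
  }

  def main(args: Array[String]): Unit = {
    shutdown("shard-0", new Object, Terminate)
    shutdown("shard-1", new Object, Zombie)
  }
}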
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
new file mode 100644
index 0000000000..645e64a0bc
--- /dev/null
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import java.util.concurrent.{TimeoutException, ExecutorService}
+
+import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach}
+import org.scalatest.concurrent.Eventually
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.streaming.{Duration, TestSuiteBase}
+import org.apache.spark.util.ManualClock
+
+class KinesisCheckpointerSuite extends TestSuiteBase
+ with MockitoSugar
+ with BeforeAndAfterEach
+ with PrivateMethodTester
+ with Eventually {
+
+ private val workerId = "dummyWorkerId"
+ private val shardId = "dummyShardId"
+ private val seqNum = "123"
+ private val otherSeqNum = "245"
+ private val checkpointInterval = Duration(10)
+ private val someSeqNum = Some(seqNum)
+ private val someOtherSeqNum = Some(otherSeqNum)
+
+ private var receiverMock: KinesisReceiver[Array[Byte]] = _
+ private var checkpointerMock: IRecordProcessorCheckpointer = _
+ private var kinesisCheckpointer: KinesisCheckpointer = _
+ private var clock: ManualClock = _
+
+ private val checkpoint = PrivateMethod[Unit]('checkpoint)
+
+ override def beforeEach(): Unit = {
+ receiverMock = mock[KinesisReceiver[Array[Byte]]]
+ checkpointerMock = mock[IRecordProcessorCheckpointer]
+ clock = new ManualClock()
+ kinesisCheckpointer = new KinesisCheckpointer(receiverMock, checkpointInterval, workerId, clock)
+ }
+
+ test("checkpoint is not called twice for the same sequence number") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+ kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+ kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+
+ verify(checkpointerMock, times(1)).checkpoint(anyString())
+ }
+
+ test("checkpoint is called after sequence number increases") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
+ .thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
+ kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+ kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+
+ verify(checkpointerMock, times(1)).checkpoint(seqNum)
+ verify(checkpointerMock, times(1)).checkpoint(otherSeqNum)
+ }
+
+ test("should checkpoint if we have exceeded the checkpoint interval") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
+ .thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
+
+ kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+ clock.advance(5 * checkpointInterval.milliseconds)
+
+ eventually(timeout(1 second)) {
+ verify(checkpointerMock, times(1)).checkpoint(seqNum)
+ verify(checkpointerMock, times(1)).checkpoint(otherSeqNum)
+ }
+ }
+
+ test("shouldn't checkpoint if we have not exceeded the checkpoint interval") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+ kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+ clock.advance(checkpointInterval.milliseconds / 2)
+
+ verify(checkpointerMock, never()).checkpoint(anyString())
+ }
+
+ test("should not checkpoint for the same sequence number") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+ kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+
+ clock.advance(checkpointInterval.milliseconds * 5)
+ eventually(timeout(1 second)) {
+ verify(checkpointerMock, atMost(1)).checkpoint(anyString())
+ }
+ }
+
+ test("removing checkpointer checkpoints one last time") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+ kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock)
+ verify(checkpointerMock, times(1)).checkpoint(anyString())
+ }
+
+ test("if checkpointing is going on, wait until finished before removing and checkpointing") {
+ when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
+ .thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
+ when(checkpointerMock.checkpoint(anyString)).thenAnswer(new Answer[Unit] {
+ override def answer(invocations: InvocationOnMock): Unit = {
+ clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2)
+ }
+ })
+
+ kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+ clock.advance(checkpointInterval.milliseconds)
+ eventually(timeout(1 second)) {
+ verify(checkpointerMock, times(1)).checkpoint(anyString())
+ }
+ // don't block test thread
+ val f = Future(kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock))(
+ ExecutionContext.global)
+
+ intercept[TimeoutException] {
+ Await.ready(f, 50 millis)
+ }
+
+ clock.advance(checkpointInterval.milliseconds / 2)
+ eventually(timeout(1 second)) {
+ verify(checkpointerMock, times(2)).checkpoint(anyString())
+ }
+ }
+}
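The suite above drives the checkpointer deterministically with a ManualClock: the timer thread blocks in waitTillTime until the test advances the clock. A minimal, self-contained sketch of that mechanism (not Spark's ManualClock, just the same idea):

class ManualClockSketch {
  private var timeMillis = 0L

  def getTimeMillis(): Long = synchronized { timeMillis }

  def advance(ms: Long): Unit = synchronized {
    timeMillis += ms
    notifyAll() // wake any thread blocked in waitTillTime
  }

  def waitTillTime(targetTime: Long): Long = synchronized {
    while (timeMillis < targetTime) wait()
    timeMillis
  }
}

object ManualClockDemo {
  def main(args: Array[String]): Unit = {
    val clock = new ManualClockSketch
    val timerThread = new Thread(new Runnable {
      def run(): Unit = {
        clock.waitTillTime(100) // blocks until the test advances the clock
        println("checkpoint tick at " + clock.getTimeMillis())
      }
    })
    timerThread.start()
    clock.advance(100) // deterministically releases the timer thread
    timerThread.join()
  }
}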
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 17ab444704..e5c70db554 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -25,12 +25,13 @@ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
import com.amazonaws.services.kinesis.model.Record
import org.mockito.Matchers._
+import org.mockito.Matchers.{eq => meq}
import org.mockito.Mockito._
import org.scalatest.mock.MockitoSugar
import org.scalatest.{BeforeAndAfter, Matchers}
-import org.apache.spark.streaming.{Milliseconds, TestSuiteBase}
-import org.apache.spark.util.{Clock, ManualClock, Utils}
+import org.apache.spark.streaming.{Duration, TestSuiteBase}
+import org.apache.spark.util.Utils
/**
* Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor
@@ -44,6 +45,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
val workerId = "dummyWorkerId"
val shardId = "dummyShardId"
val seqNum = "dummySeqNum"
+ val checkpointInterval = Duration(10)
val someSeqNum = Some(seqNum)
val record1 = new Record()
@@ -54,24 +56,10 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
var receiverMock: KinesisReceiver[Array[Byte]] = _
var checkpointerMock: IRecordProcessorCheckpointer = _
- var checkpointClockMock: ManualClock = _
- var checkpointStateMock: KinesisCheckpointState = _
- var currentClockMock: Clock = _
override def beforeFunction(): Unit = {
receiverMock = mock[KinesisReceiver[Array[Byte]]]
checkpointerMock = mock[IRecordProcessorCheckpointer]
- checkpointClockMock = mock[ManualClock]
- checkpointStateMock = mock[KinesisCheckpointState]
- currentClockMock = mock[Clock]
- }
-
- override def afterFunction(): Unit = {
- super.afterFunction()
- // Since this suite was originally written using EasyMock, add this to preserve the old
- // mocking semantics (see SPARK-5735 for more details)
- verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock,
- checkpointStateMock, currentClockMock)
}
test("check serializability of SerializableAWSCredentials") {
@@ -79,113 +67,67 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
Utils.serialize(new SerializableAWSCredentials("x", "y")))
}
- test("process records including store and checkpoint") {
+ test("process records including store and set checkpointer") {
when(receiverMock.isStopped()).thenReturn(false)
- when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
- when(checkpointStateMock.shouldCheckpoint()).thenReturn(true)
- val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
recordProcessor.initialize(shardId)
recordProcessor.processRecords(batch, checkpointerMock)
verify(receiverMock, times(1)).isStopped()
verify(receiverMock, times(1)).addRecords(shardId, batch)
- verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId)
- verify(checkpointStateMock, times(1)).shouldCheckpoint()
- verify(checkpointerMock, times(1)).checkpoint(anyString)
- verify(checkpointStateMock, times(1)).advanceCheckpoint()
+ verify(receiverMock, times(1)).setCheckpointer(shardId, checkpointerMock)
}
- test("shouldn't store and checkpoint when receiver is stopped") {
+ test("shouldn't store and update checkpointer when receiver is stopped") {
when(receiverMock.isStopped()).thenReturn(true)
- val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
recordProcessor.processRecords(batch, checkpointerMock)
verify(receiverMock, times(1)).isStopped()
verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record]))
- verify(checkpointerMock, never).checkpoint(anyString)
+ verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock))
}
- test("shouldn't checkpoint when exception occurs during store") {
+ test("shouldn't update checkpointer when exception occurs during store") {
when(receiverMock.isStopped()).thenReturn(false)
when(
receiverMock.addRecords(anyString, anyListOf(classOf[Record]))
).thenThrow(new RuntimeException())
intercept[RuntimeException] {
- val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
recordProcessor.initialize(shardId)
recordProcessor.processRecords(batch, checkpointerMock)
}
verify(receiverMock, times(1)).isStopped()
verify(receiverMock, times(1)).addRecords(shardId, batch)
- verify(checkpointerMock, never).checkpoint(anyString)
- }
-
- test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") {
- when(currentClockMock.getTimeMillis()).thenReturn(0)
-
- val checkpointIntervalMillis = 10
- val checkpointState =
- new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock)
- assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis)
-
- verify(currentClockMock, times(1)).getTimeMillis()
- }
-
- test("should checkpoint if we have exceeded the checkpoint interval") {
- when(currentClockMock.getTimeMillis()).thenReturn(0)
-
- val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock)
- assert(checkpointState.shouldCheckpoint())
-
- verify(currentClockMock, times(1)).getTimeMillis()
- }
-
- test("shouldn't checkpoint if we have not exceeded the checkpoint interval") {
- when(currentClockMock.getTimeMillis()).thenReturn(0)
-
- val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock)
- assert(!checkpointState.shouldCheckpoint())
-
- verify(currentClockMock, times(1)).getTimeMillis()
- }
-
- test("should add to time when advancing checkpoint") {
- when(currentClockMock.getTimeMillis()).thenReturn(0)
-
- val checkpointIntervalMillis = 10
- val checkpointState =
- new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock)
- assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis)
- checkpointState.advanceCheckpoint()
- assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis))
-
- verify(currentClockMock, times(1)).getTimeMillis()
+ verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock))
}
test("shutdown should checkpoint if the reason is TERMINATE") {
when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
- val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
recordProcessor.initialize(shardId)
recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE)
- verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId)
- verify(checkpointerMock, times(1)).checkpoint(anyString)
+ verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock))
}
+
test("shutdown should not checkpoint if the reason is something other than TERMINATE") {
when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
- val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
recordProcessor.initialize(shardId)
recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE)
recordProcessor.shutdown(checkpointerMock, null)
- verify(checkpointerMock, never).checkpoint(anyString)
+ verify(receiverMock, times(2)).removeCheckpointer(meq(shardId),
+ meq[IRecordProcessorCheckpointer](null))
}
test("retry success on first attempt") {