author     Tathagata Das <tathagata.das1565@gmail.com>    2015-12-07 11:03:59 -0800
committer  Tathagata Das <tathagata.das1565@gmail.com>    2015-12-07 11:03:59 -0800
commit     5d80d8c6a54b2113022eff31187e6d97521bd2cf (patch)
tree       68bb6f6eb391c2caeca90c040c2a83d73ce323ab /streaming
parent     ef3f047c07ef0ac4a3a97e6bc11e1c28c6c8f9a0 (diff)
[SPARK-11932][STREAMING] Partition previous TrackStateRDD if partitioner not present
The reason is that TrackStateRDDs generated by trackStateByKey expect the previous batch's TrackStateRDDs to have a partitioner. However, when recovering from DStream checkpoints, the RDDs recovered from RDD checkpoints do not have a partitioner attached to them. This is because RDD checkpoints do not preserve the partitioner (SPARK-12004). While #9983 solves SPARK-12004 by preserving the partitioner through RDD checkpoints, there is still a non-zero chance that saving and recovering the partitioner fails. To be resilient, this PR repartitions the previous state RDD if no partitioner is detected.

Author: Tathagata Das <tathagata.das1565@gmail.com>

Closes #9988 from tdas/SPARK-11932.
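To make the defensive check concrete, here is a minimal standalone sketch of the idea (not the patch itself; ensurePartitioned is a hypothetical helper): reuse the previous state RDD only when it already carries the expected partitioner, and shuffle it back into shape otherwise.

    import scala.reflect.ClassTag
    import org.apache.spark.Partitioner
    import org.apache.spark.rdd.RDD

    // Hypothetical helper illustrating the fix: RDDs recovered from RDD
    // checkpoints may have lost their partitioner (SPARK-12004), so
    // repartition them before chaining another state RDD off of them.
    def ensurePartitioned[K: ClassTag, V: ClassTag](
        rdd: RDD[(K, V)], partitioner: Partitioner): RDD[(K, V)] = {
      if (rdd.partitioner == Some(partitioner)) {
        rdd                           // already partitioned as expected; reuse as-is
      } else {
        rdd.partitionBy(partitioner)  // e.g. right after checkpoint recovery
      }
    }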
Diffstat (limited to 'streaming')
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala                |   2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala |  39
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala         |  29
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala           | 189
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala             |   6
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala      |  77
6 files changed, 258 insertions(+), 84 deletions(-)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index fd0e8d5d69..d0046afdeb 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -277,7 +277,7 @@ class CheckpointWriter(
val bytes = Checkpoint.serialize(checkpoint, conf)
executor.execute(new CheckpointWriteHandler(
checkpoint.checkpointTime, bytes, clearCheckpointDataLater))
- logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
+ logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
} catch {
case rej: RejectedExecutionException =>
logError("Could not submit checkpoint task to the thread pool executor", rej)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
index 0ada1111ce..ea6213420e 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala
@@ -132,22 +132,37 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
/** Method that generates an RDD for the given time */
override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
// Get the previous state or create a new empty state RDD
- val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
- TrackStateRDD.createFromPairRDD[K, V, S, E](
- spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
- partitioner, validTime
- )
+ val prevStateRDD = getOrCompute(validTime - slideDuration) match {
+ case Some(rdd) =>
+ if (rdd.partitioner != Some(partitioner)) {
+          // If the RDD is not partitioned the right way, repartition it by key so
+          // that the state RDD is always correctly partitioned before it is used
+          // to create the next state RDD
+ TrackStateRDD.createFromRDD[K, V, S, E](
+ rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
+ } else {
+ rdd
+ }
+ case None =>
+ TrackStateRDD.createFromPairRDD[K, V, S, E](
+ spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
+ partitioner,
+ validTime
+ )
}
+
// Compute the new state RDD with previous state RDD and partitioned data RDD
- parent.getOrCompute(validTime).map { dataRDD =>
- val partitionedDataRDD = dataRDD.partitionBy(partitioner)
- val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
- (validTime - interval).milliseconds
- }
- new TrackStateRDD(
- prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)
+ // Even if there is no data RDD, use an empty one to create a new state RDD
+ val dataRDD = parent.getOrCompute(validTime).getOrElse {
+ context.sparkContext.emptyRDD[(K, V)]
+ }
+ val partitionedDataRDD = dataRDD.partitionBy(partitioner)
+ val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
+ (validTime - interval).milliseconds
}
+ Some(new TrackStateRDD(
+ prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
}
}
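For context, a hedged sketch of how the API behind this code path is consumed (Spark 1.6-era trackStateByKey, using the two-argument StateSpec.function form that also appears in the tests below; wordStream is a placeholder name):

    import org.apache.spark.streaming.{State, StateSpec}

    // Running count per key: State carries values across batches; each batch,
    // compute() pairs the previous TrackStateRDD with the new data RDD.
    val runningCount = (value: Option[Int], state: State[Int]) => {
      val sum = state.getOption().getOrElse(0) + value.getOrElse(0)
      state.update(sum)   // stored in this batch's TrackStateRDD
      sum                 // emitted record for this batch
    }
    // val counts = wordStream.map(_ -> 1).trackStateByKey(StateSpec.function(runningCount))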
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
index 7050378d0f..30aafcf146 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala
@@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
private[streaming] object TrackStateRDD {
- def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+ def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
pairRDD: RDD[(K, S)],
partitioner: Partitioner,
- updateTime: Time): TrackStateRDD[K, V, S, T] = {
+ updateTime: Time): TrackStateRDD[K, V, S, E] = {
val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
- Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
+ Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)
val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
- new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+ new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+ }
+
+ def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+ rdd: RDD[(K, S, Long)],
+ partitioner: Partitioner,
+ updateTime: Time): TrackStateRDD[K, V, S, E] = {
+
+ val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
+ val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
+ val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
+ iterator.foreach { case (key, (state, updateTime)) =>
+ stateMap.put(key, state, updateTime)
+ }
+ Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+ }, preservesPartitioning = true)
+
+ val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
+
+ val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
+
+ new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
}
}
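The new createFromRDD receives the (key, state, updateTime) triples produced by stateMap.getAll() and re-keys them before the shuffle. A plain-Scala sketch of that shape conversion, with hypothetical sample data:

    // Triples recovered from the state map, re-keyed so that partitionBy can
    // shuffle by key while the update timestamp rides along with the state.
    val recovered = Seq(("a", 3, 1000L), ("b", 1, 2000L))        // hypothetical entries
    val keyed = recovered.map { case (k, s, t) => (k, (s, t)) }  // (K, (S, Long))
    // keyed is what gets partitioned and folded back into per-partition StateMaps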
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index b1cbc7163b..cd28d3cf40 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -33,17 +33,149 @@ import org.mockito.Mockito.mock
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
-import org.apache.spark.TestUtils
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
/**
+ * A trait that can be mixed in to get methods for testing DStream operations under
+ * DStream checkpointing. Note that implementations of this trait have to implement
+ * the `setupCheckpointOperation` method.
+ */
+trait DStreamCheckpointTester { self: SparkFunSuite =>
+
+ /**
+   * Tests a streaming operation under checkpointing, by restarting the operation
+   * from the checkpoint file and verifying whether the final output is correct.
+   * The output is assumed to have come from a reliable queue which can replay
+   * data as required.
+ *
+ * NOTE: This takes into consideration that the last batch processed before
+ * master failure will be re-processed after restart/recovery.
+ */
+ protected def testCheckpointedOperation[U: ClassTag, V: ClassTag](
+ input: Seq[Seq[U]],
+ operation: DStream[U] => DStream[V],
+ expectedOutput: Seq[Seq[V]],
+ numBatchesBeforeRestart: Int,
+ batchDuration: Duration = Milliseconds(500),
+ stopSparkContextAfterTest: Boolean = true
+ ) {
+ require(numBatchesBeforeRestart < expectedOutput.size,
+ "Number of batches before context restart less than number of expected output " +
+ "(i.e. number of total batches to run)")
+ require(StreamingContext.getActive().isEmpty,
+ "Cannot run test with already active streaming context")
+
+ // Current code assumes that number of batches to be run = number of inputs
+ val totalNumBatches = input.size
+ val batchDurationMillis = batchDuration.milliseconds
+
+ // Setup the stream computation
+ val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString
+ logDebug(s"Using checkpoint directory $checkpointDir")
+ val ssc = createContextForCheckpointOperation(batchDuration)
+ require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName,
+ "Cannot run test without manual clock in the conf")
+
+ val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
+ val operatedStream = operation(inputStream)
+ operatedStream.print()
+ val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+ new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
+ outputStream.register()
+ ssc.checkpoint(checkpointDir)
+
+    // Do the computation for the initial number of batches, create the checkpoint files, and stop
+ val beforeRestartOutput = generateOutput[V](ssc,
+ Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest)
+ assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true)
+ // Restart and complete the computation from checkpoint file
+ logInfo(
+ "\n-------------------------------------------\n" +
+ " Restarting stream computation " +
+ "\n-------------------------------------------\n"
+ )
+
+ val restartedSsc = new StreamingContext(checkpointDir)
+ val afterRestartOutput = generateOutput[V](restartedSsc,
+ Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest)
+ assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false)
+ }
+
+ protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = {
+ val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
+ conf.set("spark.streaming.clock", classOf[ManualClock].getName())
+ new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
+ }
+
+ private def generateOutput[V: ClassTag](
+ ssc: StreamingContext,
+ targetBatchTime: Time,
+ checkpointDir: String,
+ stopSparkContext: Boolean
+ ): Seq[Seq[V]] = {
+ try {
+ val batchDuration = ssc.graph.batchDuration
+ val batchCounter = new BatchCounter(ssc)
+ ssc.start()
+ val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+ val currentTime = clock.getTimeMillis()
+
+ logInfo("Manual clock before advancing = " + clock.getTimeMillis())
+ clock.setTime(targetBatchTime.milliseconds)
+ logInfo("Manual clock after advancing = " + clock.getTimeMillis())
+
+ val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
+ dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
+ }.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+
+ eventually(timeout(10 seconds)) {
+ ssc.awaitTerminationOrTimeout(10)
+ assert(batchCounter.getLastCompletedBatchTime === targetBatchTime)
+ }
+
+ eventually(timeout(10 seconds)) {
+ val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter {
+ _.toString.contains(clock.getTimeMillis.toString)
+ }
+        // Checkpoint files are written twice for every batch interval, so assert
+        // that both of them have been written.
+ assert(checkpointFilesOfLatestTime.size === 2)
+ }
+ outputStream.output.map(_.flatten)
+
+ } finally {
+ ssc.stop(stopSparkContext = stopSparkContext)
+ }
+ }
+
+ private def assertOutput[V: ClassTag](
+ output: Seq[Seq[V]],
+ expectedOutput: Seq[Seq[V]],
+ beforeRestart: Boolean): Unit = {
+ val expectedPartialOutput = if (beforeRestart) {
+ expectedOutput.take(output.size)
+ } else {
+ expectedOutput.takeRight(output.size)
+ }
+ val setComparison = output.zip(expectedPartialOutput).forall {
+ case (o, e) => o.toSet === e.toSet
+ }
+ assert(setComparison, s"set comparison failed\n" +
+ s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" +
+ s"Generated output items: ${output.mkString("\n")}"
+ )
+ }
+}
+
+/**
 * This test suite tests the checkpointing functionality of DStreams -
* the checkpointing of a DStream's RDDs as well as the checkpointing of
* the whole DStream graph.
*/
-class CheckpointSuite extends TestSuiteBase {
+class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {
var ssc: StreamingContext = null
@@ -56,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase {
override def afterFunction() {
super.afterFunction()
- if (ssc != null) ssc.stop()
+ if (ssc != null) { ssc.stop() }
Utils.deleteRecursively(new File(checkpointDir))
}
@@ -251,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase {
Seq(("", 2)),
Seq(),
Seq(("a", 2), ("b", 1)),
- Seq(("", 2)), Seq() ),
+ Seq(("", 2)),
+ Seq()
+ ),
3
)
}
@@ -635,53 +769,6 @@ class CheckpointSuite extends TestSuiteBase {
}
/**
- * Tests a streaming operation under checkpointing, by restarting the operation
- * from checkpoint file and verifying whether the final output is correct.
- * The output is assumed to have come from a reliable queue which an replay
- * data as required.
- *
- * NOTE: This takes into consideration that the last batch processed before
- * master failure will be re-processed after restart/recovery.
- */
- def testCheckpointedOperation[U: ClassTag, V: ClassTag](
- input: Seq[Seq[U]],
- operation: DStream[U] => DStream[V],
- expectedOutput: Seq[Seq[V]],
- initialNumBatches: Int
- ) {
-
- // Current code assumes that:
- // number of inputs = number of outputs = number of batches to be run
- val totalNumBatches = input.size
- val nextNumBatches = totalNumBatches - initialNumBatches
- val initialNumExpectedOutputs = initialNumBatches
- val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
- // because the last batch will be processed again
-
- // Do the computation for initial number of batches, create checkpoint file and quit
- ssc = setupStreams[U, V](input, operation)
- ssc.start()
- val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches)
- ssc.stop()
- verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
- Thread.sleep(1000)
-
- // Restart and complete the computation from checkpoint file
- logInfo(
- "\n-------------------------------------------\n" +
- " Restarting stream computation " +
- "\n-------------------------------------------\n"
- )
- ssc = new StreamingContext(checkpointDir)
- ssc.start()
- val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
- // the first element will be re-processed data of the last batch before restart
- verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
- ssc.stop()
- ssc = null
- }
-
- /**
* Advances the manual clock on the streaming scheduler by given number of batches.
* It also waits for the expected amount of time for each batch.
*/
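A hedged example of how a suite mixing in DStreamCheckpointTester would invoke the extracted helper (the word-count operation here is illustrative, not part of this patch):

    // Run three batches, restart from the checkpoint after two of them, and
    // verify the post-restart output against the tail of the expectation.
    testCheckpointedOperation(
      input = Seq(Seq("a"), Seq("a", "b"), Seq("b")),
      operation = (s: DStream[String]) => s.map(_ -> 1).reduceByKey(_ + _),
      expectedOutput = Seq(Seq(("a", 1)), Seq(("a", 1), ("b", 1)), Seq(("b", 1))),
      numBatchesBeforeRestart = 2
    )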
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index a45c92d9c7..be0f4636a6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -142,6 +142,7 @@ class BatchCounter(ssc: StreamingContext) {
// All access to this state should be guarded by `BatchCounter.this.synchronized`
private var numCompletedBatches = 0
private var numStartedBatches = 0
+ private var lastCompletedBatchTime: Time = null
private val listener = new StreamingListener {
override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit =
@@ -152,6 +153,7 @@ class BatchCounter(ssc: StreamingContext) {
override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit =
BatchCounter.this.synchronized {
numCompletedBatches += 1
+ lastCompletedBatchTime = batchCompleted.batchInfo.batchTime
BatchCounter.this.notifyAll()
}
}
@@ -165,6 +167,10 @@ class BatchCounter(ssc: StreamingContext) {
numStartedBatches
}
+ def getLastCompletedBatchTime: Time = this.synchronized {
+ lastCompletedBatchTime
+ }
+
/**
* Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if
* `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's
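A sketch of the consumption pattern for the new getLastCompletedBatchTime accessor, mirroring generateOutput in the checkpoint tester above (assumes a ManualClock-driven StreamingContext named ssc, a target Time named targetBatchTime, and the suite's Eventually/SpanSugar imports):

    // Advance the manual clock to the target batch time, then block until the
    // listener has actually observed that batch complete.
    val batchCounter = new BatchCounter(ssc)
    ssc.start()
    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
    clock.setTime(targetBatchTime.milliseconds)
    eventually(timeout(10 seconds)) {
      assert(batchCounter.getLastCompletedBatchTime === targetBatchTime)
    }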
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
index 58aef74c00..1fc320d31b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TrackStateByKeySuite.scala
@@ -25,31 +25,27 @@ import scala.reflect.ClassTag
import org.scalatest.PrivateMethodTester._
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-import org.apache.spark.streaming.dstream.{InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
+import org.apache.spark.streaming.dstream.{DStream, InternalTrackStateDStream, TrackStateDStream, TrackStateDStreamImpl}
import org.apache.spark.util.{ManualClock, Utils}
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
-class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
+class TrackStateByKeySuite extends SparkFunSuite
+ with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter {
private var sc: SparkContext = null
- private var ssc: StreamingContext = null
- private var checkpointDir: File = null
- private val batchDuration = Seconds(1)
+ protected var checkpointDir: File = null
+ protected val batchDuration = Seconds(1)
before {
- StreamingContext.getActive().foreach {
- _.stop(stopSparkContext = false)
- }
+ StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
checkpointDir = Utils.createTempDir("checkpoint")
-
- ssc = new StreamingContext(sc, batchDuration)
- ssc.checkpoint(checkpointDir.toString)
}
after {
- StreamingContext.getActive().foreach {
- _.stop(stopSparkContext = false)
+ if (checkpointDir != null) {
+ Utils.deleteRecursively(checkpointDir)
}
+ StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
}
override def beforeAll(): Unit = {
@@ -242,7 +238,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
assert(dstreamImpl.stateClass === classOf[Double])
assert(dstreamImpl.emittedClass === classOf[Long])
}
-
+ val ssc = new StreamingContext(sc, batchDuration)
val inputStream = new TestInputStream[(String, Int)](ssc, Seq.empty, numPartitions = 2)
// Defining StateSpec inline with trackStateByKey and simple function implicitly gets the types
@@ -451,8 +447,9 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
expectedCheckpointDuration: Duration,
explicitCheckpointDuration: Option[Duration] = None
): Unit = {
+ val ssc = new StreamingContext(sc, batchDuration)
+
try {
- ssc = new StreamingContext(sc, batchDuration)
val inputStream = new TestInputStream(ssc, Seq.empty[Seq[Int]], 2).map(_ -> 1)
val dummyFunc = (value: Option[Int], state: State[Int]) => 0
val trackStateStream = inputStream.trackStateByKey(StateSpec.function(dummyFunc))
@@ -462,11 +459,12 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
trackStateStream.checkpoint(d)
}
trackStateStream.register()
+ ssc.checkpoint(checkpointDir.toString)
ssc.start() // should initialize all the checkpoint durations
assert(trackStateStream.checkpointDuration === null)
assert(internalTrackStateStream.checkpointDuration === expectedCheckpointDuration)
} finally {
- StreamingContext.getActive().foreach { _.stop(stopSparkContext = false) }
+ ssc.stop(stopSparkContext = false)
}
}
@@ -479,6 +477,50 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
testCheckpointDuration(Seconds(10), Seconds(20), Some(Seconds(20)))
}
+
+ test("trackStateByKey - driver failure recovery") {
+ val inputData =
+ Seq(
+ Seq(),
+ Seq("a"),
+ Seq("a", "b"),
+ Seq("a", "b", "c"),
+ Seq("a", "b"),
+ Seq("a"),
+ Seq()
+ )
+
+ val stateData =
+ Seq(
+ Seq(),
+ Seq(("a", 1)),
+ Seq(("a", 2), ("b", 1)),
+ Seq(("a", 3), ("b", 2), ("c", 1)),
+ Seq(("a", 4), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1)),
+ Seq(("a", 5), ("b", 3), ("c", 1))
+ )
+
+ def operation(dstream: DStream[String]): DStream[(String, Int)] = {
+
+ val checkpointDuration = batchDuration * (stateData.size / 2)
+
+ val runningCount = (value: Option[Int], state: State[Int]) => {
+ state.update(state.getOption().getOrElse(0) + value.getOrElse(0))
+ state.get()
+ }
+
+ val trackStateStream = dstream.map { _ -> 1 }.trackStateByKey(
+ StateSpec.function(runningCount))
+      // Set the interval to make sure there is at least one RDD checkpoint
+ trackStateStream.checkpoint(checkpointDuration)
+ trackStateStream.stateSnapshots()
+ }
+
+ testCheckpointedOperation(inputData, operation, stateData, inputData.size / 2,
+ batchDuration = batchDuration, stopSparkContextAfterTest = false)
+ }
+
private def testOperation[K: ClassTag, S: ClassTag, T: ClassTag](
input: Seq[Seq[K]],
trackStateSpec: StateSpec[K, Int, S, T],
@@ -500,6 +542,7 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
): (Seq[Seq[T]], Seq[Seq[(K, S)]]) = {
// Setup the stream computation
+ val ssc = new StreamingContext(sc, Seconds(1))
val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
val trackeStateStream = inputStream.map(x => (x, 1)).trackStateByKey(trackStateSpec)
val collectedOutputs = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]
@@ -511,12 +554,14 @@ class TrackStateByKeySuite extends SparkFunSuite with BeforeAndAfterAll with Bef
stateSnapshotStream.register()
val batchCounter = new BatchCounter(ssc)
+ ssc.checkpoint(checkpointDir.toString)
ssc.start()
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
clock.advance(batchDuration.milliseconds * numBatches)
batchCounter.waitUntilBatchesCompleted(numBatches, 10000)
+ ssc.stop(stopSparkContext = false)
(collectedOutputs, collectedStateSnapshots)
}
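Finally, a plain-Scala model (not Spark code) of the running-count semantics that testOperation verifies: fold each input batch into a per-key state map and snapshot it after every batch.

    // Reference semantics for the state snapshots checked by the suite.
    val batches = Seq(Seq("a"), Seq("a", "b"), Seq.empty[String])
    val snapshots = batches.scanLeft(Map.empty[String, Int]) { (state, batch) =>
      batch.foldLeft(state) { (s, k) => s.updated(k, s.getOrElse(k, 0) + 1) }
    }.tail
    // snapshots == Seq(Map("a" -> 1), Map("a" -> 2, "b" -> 1), Map("a" -> 2, "b" -> 1))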