aboutsummaryrefslogtreecommitdiff
path: root/streaming
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2014-04-08 00:00:17 -0700
committerPatrick Wendell <pwendell@gmail.com>2014-04-08 00:00:17 -0700
commit83ac9a4bbf272028d0c4639cbd1e12022b9ae77a (patch)
tree00a8c8f581e2f681463126a7e8a34993a790c692 /streaming
parent11eabbe125b2ee572fad359c33c93f5e6fdf0b2d (diff)
downloadspark-83ac9a4bbf272028d0c4639cbd1e12022b9ae77a.tar.gz
spark-83ac9a4bbf272028d0c4639cbd1e12022b9ae77a.tar.bz2
spark-83ac9a4bbf272028d0c4639cbd1e12022b9ae77a.zip
[SPARK-1331] Added graceful shutdown to Spark Streaming
Current version of StreamingContext.stop() directly kills all the data receivers (NetworkReceiver) without waiting for the data already received to be persisted and processed. This PR provides the fix. Now, when the StreamingContext.stop() is called, the following sequence of steps will happen. 1. The driver will send a stop signal to all the active receivers. 2. Each receiver, when it gets a stop signal from the driver, first stop receiving more data, then waits for the thread that persists data blocks to BlockManager to finish persisting all receive data, and finally quits. 3. After all the receivers have stopped, the driver will wait for the Job Generator and Job Scheduler to finish processing all the received data. It also fixes the semantics of StreamingContext.start and stop. It will throw appropriate errors and warnings if stop() is called before start(), stop() is called twice, etc. Author: Tathagata Das <tathagata.das1565@gmail.com> Closes #247 from tdas/graceful-shutdown and squashes the following commits: 61c0016 [Tathagata Das] Updated MIMA binary check excludes. ae1d39b [Tathagata Das] Merge remote-tracking branch 'apache-github/master' into graceful-shutdown 6b59cfc [Tathagata Das] Minor changes based on Andrew's comment on PR. d0b8d65 [Tathagata Das] Reduced time taken by graceful shutdown unit test. f55bc67 [Tathagata Das] Fix scalastyle c69b3a7 [Tathagata Das] Updates based on Patrick's comments. c43b8ae [Tathagata Das] Added graceful shutdown to Spark Streaming.
Diffstat (limited to 'streaming')
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala14
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala48
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala12
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala151
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala1
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala2
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala124
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala56
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala154
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala5
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala62
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala4
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala108
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala2
14 files changed, 539 insertions, 204 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index baf80fe2a9..93023e8dce 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -194,19 +194,19 @@ class CheckpointWriter(
}
}
- def stop() {
- synchronized {
- if (stopped) {
- return
- }
- stopped = true
- }
+ def stop(): Unit = synchronized {
+ if (stopped) return
+
executor.shutdown()
val startTime = System.currentTimeMillis()
val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS)
+ if (!terminated) {
+ executor.shutdownNow()
+ }
val endTime = System.currentTimeMillis()
logInfo("CheckpointWriter executor terminated ? " + terminated +
", waited for " + (endTime - startTime) + " ms.")
+ stopped = true
}
private def fs = synchronized {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index e198c69470..a4e236c65f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -158,6 +158,15 @@ class StreamingContext private[streaming] (
private[streaming] val waiter = new ContextWaiter
+ /** Enumeration to identify current state of the StreamingContext */
+ private[streaming] object StreamingContextState extends Enumeration {
+ type CheckpointState = Value
+ val Initialized, Started, Stopped = Value
+ }
+
+ import StreamingContextState._
+ private[streaming] var state = Initialized
+
/**
* Return the associated Spark context
*/
@@ -405,9 +414,18 @@ class StreamingContext private[streaming] (
/**
* Start the execution of the streams.
*/
- def start() = synchronized {
+ def start(): Unit = synchronized {
+ // Throw exception if the context has already been started once
+ // or if a stopped context is being started again
+ if (state == Started) {
+ throw new SparkException("StreamingContext has already been started")
+ }
+ if (state == Stopped) {
+ throw new SparkException("StreamingContext has already been stopped")
+ }
validate()
scheduler.start()
+ state = Started
}
/**
@@ -428,14 +446,38 @@ class StreamingContext private[streaming] (
}
/**
- * Stop the execution of the streams.
+ * Stop the execution of the streams immediately (does not wait for all received data
+ * to be processed).
* @param stopSparkContext Stop the associated SparkContext or not
+ *
*/
def stop(stopSparkContext: Boolean = true): Unit = synchronized {
- scheduler.stop()
+ stop(stopSparkContext, false)
+ }
+
+ /**
+ * Stop the execution of the streams, with option of ensuring all received data
+ * has been processed.
+ * @param stopSparkContext Stop the associated SparkContext or not
+ * @param stopGracefully Stop gracefully by waiting for the processing of all
+ * received data to be completed
+ */
+ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized {
+ // Warn (but not fail) if context is stopped twice,
+ // or context is stopped before starting
+ if (state == Initialized) {
+ logWarning("StreamingContext has not been started yet")
+ return
+ }
+ if (state == Stopped) {
+ logWarning("StreamingContext has already been stopped")
+ return
+ } // no need to throw an exception as its okay to stop twice
+ scheduler.stop(stopGracefully)
logInfo("StreamingContext stopped successfully")
waiter.notifyStop()
if (stopSparkContext) sc.stop()
+ state = Stopped
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index b705d2ec9a..c800602d09 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -509,8 +509,16 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* Stop the execution of the streams.
* @param stopSparkContext Stop the associated SparkContext or not
*/
- def stop(stopSparkContext: Boolean): Unit = {
- ssc.stop(stopSparkContext)
+ def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext)
+
+ /**
+ * Stop the execution of the streams.
+ * @param stopSparkContext Stop the associated SparkContext or not
+ * @param stopGracefully Stop gracefully by waiting for the processing of all
+ * received data to be completed
+ */
+ def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = {
+ ssc.stop(stopSparkContext, stopGracefully)
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 72ad0bae75..d19a635fe8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.dstream
-import java.util.concurrent.ArrayBlockingQueue
+import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
@@ -34,6 +34,7 @@ import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.rdd.{RDD, BlockRDD}
import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver}
+import org.apache.spark.util.AkkaUtils
/**
* Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]]
@@ -69,7 +70,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
// then this returns an empty RDD. This may happen when recovering from a
// master failure
if (validTime >= graph.startTime) {
- val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime)
+ val blockIds = ssc.scheduler.networkInputTracker.getBlocks(id, validTime)
Some(new BlockRDD[T](ssc.sc, blockIds))
} else {
Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
@@ -79,7 +80,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
private[streaming] sealed trait NetworkReceiverMessage
-private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage
+private[streaming] case class StopReceiver() extends NetworkReceiverMessage
private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any)
extends NetworkReceiverMessage
private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage
@@ -90,13 +91,31 @@ private[streaming] case class ReportError(msg: String) extends NetworkReceiverMe
*/
abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging {
+ /** Local SparkEnv */
lazy protected val env = SparkEnv.get
+ /** Remote Akka actor for the NetworkInputTracker */
+ lazy protected val trackerActor = {
+ val ip = env.conf.get("spark.driver.host", "localhost")
+ val port = env.conf.getInt("spark.driver.port", 7077)
+ val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
+ env.actorSystem.actorSelection(url)
+ }
+
+ /** Akka actor for receiving messages from the NetworkInputTracker in the driver */
lazy protected val actor = env.actorSystem.actorOf(
Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId)
+ /** Timeout for Akka actor messages */
+ lazy protected val askTimeout = AkkaUtils.askTimeout(env.conf)
+
+ /** Thread that starts the receiver and stays blocked while data is being received */
lazy protected val receivingThread = Thread.currentThread()
+ /** Exceptions that occurs while receiving data */
+ protected lazy val exceptions = new ArrayBuffer[Exception]
+
+ /** Identifier of the stream this receiver is associated with */
protected var streamId: Int = -1
/**
@@ -112,7 +131,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
def getLocationPreference() : Option[String] = None
/**
- * Starts the receiver. First is accesses all the lazy members to
+ * Start the receiver. First is accesses all the lazy members to
* materialize them. Then it calls the user-defined onStart() method to start
* other threads, etc required to receiver the data.
*/
@@ -124,83 +143,107 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
receivingThread
// Call user-defined onStart()
+ logInfo("Starting receiver")
onStart()
+
+ // Wait until interrupt is called on this thread
+ while(true) Thread.sleep(100000)
} catch {
case ie: InterruptedException =>
- logInfo("Receiving thread interrupted")
+ logInfo("Receiving thread has been interrupted, receiver " + streamId + " stopped")
case e: Exception =>
- stopOnError(e)
+ logError("Error receiving data in receiver " + streamId, e)
+ exceptions += e
+ }
+
+ // Call user-defined onStop()
+ logInfo("Stopping receiver")
+ try {
+ onStop()
+ } catch {
+ case e: Exception =>
+ logError("Error stopping receiver " + streamId, e)
+ exceptions += e
+ }
+
+ val message = if (exceptions.isEmpty) {
+ null
+ } else if (exceptions.size == 1) {
+ val e = exceptions.head
+ "Exception in receiver " + streamId + ": " + e.getMessage + "\n" + e.getStackTraceString
+ } else {
+ "Multiple exceptions in receiver " + streamId + "(" + exceptions.size + "):\n"
+ exceptions.zipWithIndex.map {
+ case (e, i) => "Exception " + i + ": " + e.getMessage + "\n" + e.getStackTraceString
+ }.mkString("\n")
}
+ logInfo("Deregistering receiver " + streamId)
+ val future = trackerActor.ask(DeregisterReceiver(streamId, message))(askTimeout)
+ Await.result(future, askTimeout)
+ logInfo("Deregistered receiver " + streamId)
+ env.actorSystem.stop(actor)
+ logInfo("Stopped receiver " + streamId)
}
/**
- * Stops the receiver. First it interrupts the main receiving thread,
- * that is, the thread that called receiver.start(). Then it calls the user-defined
- * onStop() method to stop other threads and/or do cleanup.
+ * Stop the receiver. First it interrupts the main receiving thread,
+ * that is, the thread that called receiver.start().
*/
def stop() {
+ // Stop receiving by interrupting the receiving thread
receivingThread.interrupt()
- onStop()
- // TODO: terminate the actor
+ logInfo("Interrupted receiving thread " + receivingThread + " for stopping")
}
/**
- * Stops the receiver and reports exception to the tracker.
+ * Stop the receiver and reports exception to the tracker.
* This should be called whenever an exception is to be handled on any thread
* of the receiver.
*/
protected def stopOnError(e: Exception) {
logError("Error receiving data", e)
+ exceptions += e
stop()
- actor ! ReportError(e.toString)
}
-
/**
- * Pushes a block (as an ArrayBuffer filled with data) into the block manager.
+ * Push a block (as an ArrayBuffer filled with data) into the block manager.
*/
def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level)
- actor ! ReportBlock(blockId, metadata)
+ trackerActor ! AddBlocks(streamId, Array(blockId), metadata)
+ logDebug("Pushed block " + blockId)
}
/**
- * Pushes a block (as bytes) into the block manager.
+ * Push a block (as bytes) into the block manager.
*/
def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
env.blockManager.putBytes(blockId, bytes, level)
- actor ! ReportBlock(blockId, metadata)
+ trackerActor ! AddBlocks(streamId, Array(blockId), metadata)
+ }
+
+ /** Set the ID of the DStream that this receiver is associated with */
+ protected[streaming] def setStreamId(id: Int) {
+ streamId = id
}
/** A helper actor that communicates with the NetworkInputTracker */
private class NetworkReceiverActor extends Actor {
- logInfo("Attempting to register with tracker")
- val ip = env.conf.get("spark.driver.host", "localhost")
- val port = env.conf.getInt("spark.driver.port", 7077)
- val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
- val tracker = env.actorSystem.actorSelection(url)
- val timeout = 5.seconds
override def preStart() {
- val future = tracker.ask(RegisterReceiver(streamId, self))(timeout)
- Await.result(future, timeout)
+ logInfo("Registered receiver " + streamId)
+ val future = trackerActor.ask(RegisterReceiver(streamId, self))(askTimeout)
+ Await.result(future, askTimeout)
}
override def receive() = {
- case ReportBlock(blockId, metadata) =>
- tracker ! AddBlocks(streamId, Array(blockId), metadata)
- case ReportError(msg) =>
- tracker ! DeregisterReceiver(streamId, msg)
- case StopReceiver(msg) =>
+ case StopReceiver =>
+ logInfo("Received stop signal")
stop()
- tracker ! DeregisterReceiver(streamId, msg)
}
}
- protected[streaming] def setStreamId(id: Int) {
- streamId = id
- }
-
/**
* Batches objects created by a [[org.apache.spark.streaming.dstream.NetworkReceiver]] and puts
* them into appropriately named blocks at regular intervals. This class starts two threads,
@@ -214,23 +257,26 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
val clock = new SystemClock()
val blockInterval = env.conf.getLong("spark.streaming.blockInterval", 200)
- val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
+ val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer,
+ "BlockGenerator")
val blockStorageLevel = storageLevel
val blocksForPushing = new ArrayBlockingQueue[Block](1000)
val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
var currentBuffer = new ArrayBuffer[T]
+ var stopped = false
def start() {
blockIntervalTimer.start()
blockPushingThread.start()
- logInfo("Data handler started")
+ logInfo("Started BlockGenerator")
}
def stop() {
- blockIntervalTimer.stop()
- blockPushingThread.interrupt()
- logInfo("Data handler stopped")
+ blockIntervalTimer.stop(false)
+ stopped = true
+ blockPushingThread.join()
+ logInfo("Stopped BlockGenerator")
}
def += (obj: T): Unit = synchronized {
@@ -248,24 +294,35 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
}
} catch {
case ie: InterruptedException =>
- logInfo("Block interval timer thread interrupted")
+ logInfo("Block updating timer thread was interrupted")
case e: Exception =>
- NetworkReceiver.this.stop()
+ NetworkReceiver.this.stopOnError(e)
}
}
private def keepPushingBlocks() {
- logInfo("Block pushing thread started")
+ logInfo("Started block pushing thread")
try {
- while(true) {
+ while(!stopped) {
+ Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match {
+ case Some(block) =>
+ NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel)
+ case None =>
+ }
+ }
+ // Push out the blocks that are still left
+ logInfo("Pushing out the last " + blocksForPushing.size() + " blocks")
+ while (!blocksForPushing.isEmpty) {
val block = blocksForPushing.take()
NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel)
+ logInfo("Blocks left to push " + blocksForPushing.size())
}
+ logInfo("Stopped blocks pushing thread")
} catch {
case ie: InterruptedException =>
- logInfo("Block pushing thread interrupted")
+ logInfo("Block pushing thread was interrupted")
case e: Exception =>
- NetworkReceiver.this.stop()
+ NetworkReceiver.this.stopOnError(e)
}
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
index 2cdd13f205..63d94d1cc6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
@@ -67,7 +67,6 @@ class SocketReceiver[T: ClassTag](
protected def onStop() {
blockGenerator.stop()
}
-
}
private[streaming]
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
index bd78bae8a5..44eb2750c6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
@@ -174,10 +174,10 @@ private[streaming] class ActorReceiver[T: ClassTag](
blocksGenerator.start()
supervisor
logInfo("Supervision tree for receivers initialized at:" + supervisor.path)
+
}
protected def onStop() = {
supervisor ! PoisonPill
}
-
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index c7306248b1..92d885c4bc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -39,16 +39,22 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
private val ssc = jobScheduler.ssc
private val graph = ssc.graph
+
val clock = {
val clockClass = ssc.sc.conf.get(
"spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
Class.forName(clockClass).newInstance().asInstanceOf[Clock]
}
+
private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
- longTime => eventActor ! GenerateJobs(new Time(longTime)))
- private lazy val checkpointWriter =
- if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
- new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
+ longTime => eventActor ! GenerateJobs(new Time(longTime)), "JobGenerator")
+
+ // This is marked lazy so that this is initialized after checkpoint duration has been set
+ // in the context and the generator has been started.
+ private lazy val shouldCheckpoint = ssc.checkpointDuration != null && ssc.checkpointDir != null
+
+ private lazy val checkpointWriter = if (shouldCheckpoint) {
+ new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
} else {
null
}
@@ -57,17 +63,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
// This not being null means the scheduler has been started and not stopped
private var eventActor: ActorRef = null
+ // last batch whose completion,checkpointing and metadata cleanup has been completed
+ private var lastProcessedBatch: Time = null
+
/** Start generation of jobs */
- def start() = synchronized {
- if (eventActor != null) {
- throw new SparkException("JobGenerator already started")
- }
+ def start(): Unit = synchronized {
+ if (eventActor != null) return // generator has already been started
eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
def receive = {
- case event: JobGeneratorEvent =>
- logDebug("Got event of type " + event.getClass.getName)
- processEvent(event)
+ case event: JobGeneratorEvent => processEvent(event)
}
}), "JobGenerator")
if (ssc.isCheckpointPresent) {
@@ -77,30 +82,79 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
}
}
- /** Stop generation of jobs */
- def stop() = synchronized {
- if (eventActor != null) {
- timer.stop()
- ssc.env.actorSystem.stop(eventActor)
- if (checkpointWriter != null) checkpointWriter.stop()
- ssc.graph.stop()
- logInfo("JobGenerator stopped")
+ /**
+ * Stop generation of jobs. processReceivedData = true makes this wait until jobs
+ * of current ongoing time interval has been generated, processed and corresponding
+ * checkpoints written.
+ */
+ def stop(processReceivedData: Boolean): Unit = synchronized {
+ if (eventActor == null) return // generator has already been stopped
+
+ if (processReceivedData) {
+ logInfo("Stopping JobGenerator gracefully")
+ val timeWhenStopStarted = System.currentTimeMillis()
+ val stopTimeout = 10 * ssc.graph.batchDuration.milliseconds
+ val pollTime = 100
+
+ // To prevent graceful stop to get stuck permanently
+ def hasTimedOut = {
+ val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout
+ if (timedOut) logWarning("Timed out while stopping the job generator")
+ timedOut
+ }
+
+ // Wait until all the received blocks in the network input tracker has
+ // been consumed by network input DStreams, and jobs have been generated with them
+ logInfo("Waiting for all received blocks to be consumed for job generation")
+ while(!hasTimedOut && jobScheduler.networkInputTracker.hasMoreReceivedBlockIds) {
+ Thread.sleep(pollTime)
+ }
+ logInfo("Waited for all received blocks to be consumed for job generation")
+
+ // Stop generating jobs
+ val stopTime = timer.stop(false)
+ graph.stop()
+ logInfo("Stopped generation timer")
+
+ // Wait for the jobs to complete and checkpoints to be written
+ def haveAllBatchesBeenProcessed = {
+ lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime
+ }
+ logInfo("Waiting for jobs to be processed and checkpoints to be written")
+ while (!hasTimedOut && !haveAllBatchesBeenProcessed) {
+ Thread.sleep(pollTime)
+ }
+ logInfo("Waited for jobs to be processed and checkpoints to be written")
+ } else {
+ logInfo("Stopping JobGenerator immediately")
+ // Stop timer and graph immediately, ignore unprocessed data and pending jobs
+ timer.stop(true)
+ graph.stop()
}
+
+ // Stop the actor and checkpoint writer
+ if (shouldCheckpoint) checkpointWriter.stop()
+ ssc.env.actorSystem.stop(eventActor)
+ logInfo("Stopped JobGenerator")
}
/**
- * On batch completion, clear old metadata and checkpoint computation.
+ * Callback called when a batch has been completely processed.
*/
def onBatchCompletion(time: Time) {
eventActor ! ClearMetadata(time)
}
-
+
+ /**
+ * Callback called when the checkpoint of a batch has been written.
+ */
def onCheckpointCompletion(time: Time) {
eventActor ! ClearCheckpointData(time)
}
/** Processes all events */
private def processEvent(event: JobGeneratorEvent) {
+ logDebug("Got event " + event)
event match {
case GenerateJobs(time) => generateJobs(time)
case ClearMetadata(time) => clearMetadata(time)
@@ -114,7 +168,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val startTime = new Time(timer.getStartTime())
graph.start(startTime - graph.batchDuration)
timer.start(startTime.milliseconds)
- logInfo("JobGenerator started at " + startTime)
+ logInfo("Started JobGenerator at " + startTime)
}
/** Restarts the generator based on the information in checkpoint */
@@ -152,15 +206,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
// Restart the timer
timer.start(restartTime.milliseconds)
- logInfo("JobGenerator restarted at " + restartTime)
+ logInfo("Restarted JobGenerator at " + restartTime)
}
/** Generate jobs and perform checkpoint for the given `time`. */
private def generateJobs(time: Time) {
SparkEnv.set(ssc.env)
Try(graph.generateJobs(time)) match {
- case Success(jobs) => jobScheduler.runJobs(time, jobs)
- case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e)
+ case Success(jobs) =>
+ jobScheduler.runJobs(time, jobs)
+ case Failure(e) =>
+ jobScheduler.reportError("Error generating jobs for time " + time, e)
}
eventActor ! DoCheckpoint(time)
}
@@ -168,20 +224,32 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Clear DStream metadata for the given `time`. */
private def clearMetadata(time: Time) {
ssc.graph.clearMetadata(time)
- eventActor ! DoCheckpoint(time)
+
+ // If checkpointing is enabled, then checkpoint,
+ // else mark batch to be fully processed
+ if (shouldCheckpoint) {
+ eventActor ! DoCheckpoint(time)
+ } else {
+ markBatchFullyProcessed(time)
+ }
}
/** Clear DStream checkpoint data for the given `time`. */
private def clearCheckpointData(time: Time) {
ssc.graph.clearCheckpointData(time)
+ markBatchFullyProcessed(time)
}
/** Perform checkpoint for the give `time`. */
- private def doCheckpoint(time: Time) = synchronized {
- if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
+ private def doCheckpoint(time: Time) {
+ if (shouldCheckpoint && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
logInfo("Checkpointing graph for time " + time)
ssc.graph.updateCheckpointData(time)
checkpointWriter.write(new Checkpoint(ssc, time))
}
}
+
+ private def markBatchFullyProcessed(time: Time) {
+ lastProcessedBatch = time
+ }
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index de675d3c7f..04e0a6a283 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -39,7 +39,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
private val jobSets = new ConcurrentHashMap[Time, JobSet]
private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1)
- private val executor = Executors.newFixedThreadPool(numConcurrentJobs)
+ private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs)
private val jobGenerator = new JobGenerator(this)
val clock = jobGenerator.clock
val listenerBus = new StreamingListenerBus()
@@ -50,36 +50,54 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
private var eventActor: ActorRef = null
- def start() = synchronized {
- if (eventActor != null) {
- throw new SparkException("JobScheduler already started")
- }
+ def start(): Unit = synchronized {
+ if (eventActor != null) return // scheduler has already been started
+ logDebug("Starting JobScheduler")
eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
def receive = {
case event: JobSchedulerEvent => processEvent(event)
}
}), "JobScheduler")
+
listenerBus.start()
networkInputTracker = new NetworkInputTracker(ssc)
networkInputTracker.start()
- Thread.sleep(1000)
jobGenerator.start()
- logInfo("JobScheduler started")
+ logInfo("Started JobScheduler")
}
- def stop() = synchronized {
- if (eventActor != null) {
- jobGenerator.stop()
- networkInputTracker.stop()
- executor.shutdown()
- if (!executor.awaitTermination(2, TimeUnit.SECONDS)) {
- executor.shutdownNow()
- }
- listenerBus.stop()
- ssc.env.actorSystem.stop(eventActor)
- logInfo("JobScheduler stopped")
+ def stop(processAllReceivedData: Boolean): Unit = synchronized {
+ if (eventActor == null) return // scheduler has already been stopped
+ logDebug("Stopping JobScheduler")
+
+ // First, stop receiving
+ networkInputTracker.stop()
+
+ // Second, stop generating jobs. If it has to process all received data,
+ // then this will wait for all the processing through JobScheduler to be over.
+ jobGenerator.stop(processAllReceivedData)
+
+ // Stop the executor for receiving new jobs
+ logDebug("Stopping job executor")
+ jobExecutor.shutdown()
+
+ // Wait for the queued jobs to complete if indicated
+ val terminated = if (processAllReceivedData) {
+ jobExecutor.awaitTermination(1, TimeUnit.HOURS) // just a very large period of time
+ } else {
+ jobExecutor.awaitTermination(2, TimeUnit.SECONDS)
}
+ if (!terminated) {
+ jobExecutor.shutdownNow()
+ }
+ logDebug("Stopped job executor")
+
+ // Stop everything else
+ listenerBus.stop()
+ ssc.env.actorSystem.stop(eventActor)
+ eventActor = null
+ logInfo("Stopped JobScheduler")
}
def runJobs(time: Time, jobs: Seq[Job]) {
@@ -88,7 +106,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
} else {
val jobSet = new JobSet(time, jobs)
jobSets.put(time, jobSet)
- jobSet.jobs.foreach(job => executor.execute(new JobHandler(job)))
+ jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job)))
logInfo("Added jobs for time " + time)
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
index cad68e248a..067e804202 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
@@ -17,20 +17,14 @@
package org.apache.spark.streaming.scheduler
-import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
-import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
-import org.apache.spark.{SparkException, Logging, SparkEnv}
-import org.apache.spark.SparkContext._
-
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.Queue
-import scala.concurrent.duration._
+import scala.collection.mutable.{HashMap, Queue, SynchronizedMap}
import akka.actor._
-import akka.pattern.ask
-import akka.dispatch._
+import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.spark.SparkContext._
import org.apache.spark.storage.BlockId
-import org.apache.spark.streaming.{Time, StreamingContext}
+import org.apache.spark.streaming.{StreamingContext, Time}
+import org.apache.spark.streaming.dstream.{NetworkReceiver, StopReceiver}
import org.apache.spark.util.AkkaUtils
private[streaming] sealed trait NetworkInputTrackerMessage
@@ -52,8 +46,8 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
val networkInputStreams = ssc.graph.getNetworkInputStreams()
val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
- val receiverInfo = new HashMap[Int, ActorRef]
- val receivedBlockIds = new HashMap[Int, Queue[BlockId]]
+ val receiverInfo = new HashMap[Int, ActorRef] with SynchronizedMap[Int, ActorRef]
+ val receivedBlockIds = new HashMap[Int, Queue[BlockId]] with SynchronizedMap[Int, Queue[BlockId]]
val timeout = AkkaUtils.askTimeout(ssc.conf)
@@ -63,7 +57,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
var currentTime: Time = null
/** Start the actor and receiver execution thread. */
- def start() {
+ def start() = synchronized {
if (actor != null) {
throw new SparkException("NetworkInputTracker already started")
}
@@ -77,72 +71,99 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
}
/** Stop the receiver execution thread. */
- def stop() {
+ def stop() = synchronized {
if (!networkInputStreams.isEmpty && actor != null) {
- receiverExecutor.interrupt()
- receiverExecutor.stopReceivers()
+ // First, stop the receivers
+ receiverExecutor.stop()
+
+ // Finally, stop the actor
ssc.env.actorSystem.stop(actor)
+ actor = null
logInfo("NetworkInputTracker stopped")
}
}
- /** Return all the blocks received from a receiver. */
- def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized {
- val queue = receivedBlockIds.synchronized {
- receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]())
+ /** Register a receiver */
+ def registerReceiver(streamId: Int, receiverActor: ActorRef, sender: ActorRef) {
+ if (!networkInputStreamMap.contains(streamId)) {
+ throw new Exception("Register received for unexpected id " + streamId)
}
- val result = queue.synchronized {
- queue.dequeueAll(x => true)
- }
- logInfo("Stream " + receiverId + " received " + result.size + " blocks")
- result.toArray
+ receiverInfo += ((streamId, receiverActor))
+ logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address)
+ }
+
+ /** Deregister a receiver */
+ def deregisterReceiver(streamId: Int, message: String) {
+ receiverInfo -= streamId
+ logError("Deregistered receiver for network stream " + streamId + " with message:\n" + message)
+ }
+
+ /** Get all the received blocks for the given stream. */
+ def getBlocks(streamId: Int, time: Time): Array[BlockId] = {
+ val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[BlockId]())
+ val result = queue.dequeueAll(x => true).toArray
+ logInfo("Stream " + streamId + " received " + result.size + " blocks")
+ result
+ }
+
+ /** Add new blocks for the given stream */
+ def addBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) = {
+ val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[BlockId])
+ queue ++= blockIds
+ networkInputStreamMap(streamId).addMetadata(metadata)
+ logDebug("Stream " + streamId + " received new blocks: " + blockIds.mkString("[", ", ", "]"))
+ }
+
+ /** Check if any blocks are left to be processed */
+ def hasMoreReceivedBlockIds: Boolean = {
+ !receivedBlockIds.forall(_._2.isEmpty)
}
/** Actor to receive messages from the receivers. */
private class NetworkInputTrackerActor extends Actor {
def receive = {
- case RegisterReceiver(streamId, receiverActor) => {
- if (!networkInputStreamMap.contains(streamId)) {
- throw new Exception("Register received for unexpected id " + streamId)
- }
- receiverInfo += ((streamId, receiverActor))
- logInfo("Registered receiver for network stream " + streamId + " from "
- + sender.path.address)
+ case RegisterReceiver(streamId, receiverActor) =>
+ registerReceiver(streamId, receiverActor, sender)
+ sender ! true
+ case AddBlocks(streamId, blockIds, metadata) =>
+ addBlocks(streamId, blockIds, metadata)
+ case DeregisterReceiver(streamId, message) =>
+ deregisterReceiver(streamId, message)
sender ! true
- }
- case AddBlocks(streamId, blockIds, metadata) => {
- val tmp = receivedBlockIds.synchronized {
- if (!receivedBlockIds.contains(streamId)) {
- receivedBlockIds += ((streamId, new Queue[BlockId]))
- }
- receivedBlockIds(streamId)
- }
- tmp.synchronized {
- tmp ++= blockIds
- }
- networkInputStreamMap(streamId).addMetadata(metadata)
- }
- case DeregisterReceiver(streamId, msg) => {
- receiverInfo -= streamId
- logError("De-registered receiver for network stream " + streamId
- + " with message " + msg)
- // TODO: Do something about the corresponding NetworkInputDStream
- }
}
}
/** This thread class runs all the receivers on the cluster. */
- class ReceiverExecutor extends Thread {
- val env = ssc.env
-
- override def run() {
- try {
- SparkEnv.set(env)
- startReceivers()
- } catch {
- case ie: InterruptedException => logInfo("ReceiverExecutor interrupted")
- } finally {
- stopReceivers()
+ class ReceiverExecutor {
+ @transient val env = ssc.env
+ @transient val thread = new Thread() {
+ override def run() {
+ try {
+ SparkEnv.set(env)
+ startReceivers()
+ } catch {
+ case ie: InterruptedException => logInfo("ReceiverExecutor interrupted")
+ }
+ }
+ }
+
+ def start() {
+ thread.start()
+ }
+
+ def stop() {
+ // Send the stop signal to all the receivers
+ stopReceivers()
+
+ // Wait for the Spark job that runs the receivers to be over
+ // That is, for the receivers to quit gracefully.
+ thread.join(10000)
+
+ // Check if all the receivers have been deregistered or not
+ if (!receiverInfo.isEmpty) {
+ logWarning("All of the receivers have not deregistered, " + receiverInfo)
+ } else {
+ logInfo("All of the receivers have deregistered successfully")
}
}
@@ -150,7 +171,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
* Get the receivers from the NetworkInputDStreams, distributes them to the
* worker nodes as a parallel collection, and runs them.
*/
- def startReceivers() {
+ private def startReceivers() {
val receivers = networkInputStreams.map(nis => {
val rcvr = nis.getReceiver()
rcvr.setStreamId(nis.id)
@@ -186,13 +207,16 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
}
// Distribute the receivers and start them
+ logInfo("Starting " + receivers.length + " receivers")
ssc.sparkContext.runJob(tempRDD, startReceiver)
+ logInfo("All of the receivers have been terminated")
}
/** Stops the receivers. */
- def stopReceivers() {
+ private def stopReceivers() {
// Signal the receivers to stop
receiverInfo.values.foreach(_ ! StopReceiver)
+ logInfo("Sent stop signal to all " + receiverInfo.size + " receivers")
}
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
index c3a849d276..c5ef2cc8c3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
@@ -48,14 +48,11 @@ class SystemClock() extends Clock {
minPollTime
}
}
-
-
+
while (true) {
currentTime = System.currentTimeMillis()
waitTime = targetTime - currentTime
-
if (waitTime <= 0) {
-
return currentTime
}
val sleepTime =
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
index 559c247385..f71938ac55 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
@@ -17,44 +17,84 @@
package org.apache.spark.streaming.util
+import org.apache.spark.Logging
+
private[streaming]
-class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) {
+class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String)
+ extends Logging {
- private val thread = new Thread("RecurringTimer") {
+ private val thread = new Thread("RecurringTimer - " + name) {
+ setDaemon(true)
override def run() { loop }
}
-
- private var nextTime = 0L
+ @volatile private var prevTime = -1L
+ @volatile private var nextTime = -1L
+ @volatile private var stopped = false
+
+ /**
+ * Get the time when this timer will fire if it is started right now.
+ * The time will be a multiple of this timer's period and more than
+ * current system time.
+ */
def getStartTime(): Long = {
(math.floor(clock.currentTime.toDouble / period) + 1).toLong * period
}
+ /**
+ * Get the time when the timer will fire if it is restarted right now.
+ * This time depends on when the timer was started the first time, and was stopped
+ * for whatever reason. The time must be a multiple of this timer's period and
+ * more than current time.
+ */
def getRestartTime(originalStartTime: Long): Long = {
val gap = clock.currentTime - originalStartTime
(math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
}
- def start(startTime: Long): Long = {
+ /**
+ * Start at the given start time.
+ */
+ def start(startTime: Long): Long = synchronized {
nextTime = startTime
thread.start()
+ logInfo("Started timer for " + name + " at time " + nextTime)
nextTime
}
+ /**
+ * Start at the earliest time it can start based on the period.
+ */
def start(): Long = {
start(getStartTime())
}
- def stop() {
- thread.interrupt()
+ /**
+ * Stop the timer, and return the last time the callback was made.
+ * interruptTimer = true will interrupt the callback
+ * if it is in progress (not guaranteed to give correct time in this case).
+ */
+ def stop(interruptTimer: Boolean): Long = synchronized {
+ if (!stopped) {
+ stopped = true
+ if (interruptTimer) thread.interrupt()
+ thread.join()
+ logInfo("Stopped timer for " + name + " after time " + prevTime)
+ }
+ prevTime
}
-
+
+ /**
+ * Repeatedly call the callback every interval.
+ */
private def loop() {
try {
- while (true) {
+ while (!stopped) {
clock.waitTillTime(nextTime)
callback(nextTime)
+ prevTime = nextTime
nextTime += period
+ logDebug("Callback for " + name + " called at time " + prevTime)
}
} catch {
case e: InterruptedException =>
@@ -74,10 +114,10 @@ object RecurringTimer {
println("" + currentTime + ": " + (currentTime - lastRecurTime))
lastRecurTime = currentTime
}
- val timer = new RecurringTimer(new SystemClock(), period, onRecur)
+ val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test")
timer.start()
Thread.sleep(30 * 1000)
- timer.stop()
+ timer.stop(true)
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index bcb0c28bf0..bb73dbf29b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -324,7 +324,7 @@ class BasicOperationsSuite extends TestSuiteBase {
val updateStateOperation = (s: DStream[String]) => {
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
- Some(values.foldLeft(0)(_ + _) + state.getOrElse(0))
+ Some(values.sum + state.getOrElse(0))
}
s.map(x => (x, 1)).updateStateByKey[Int](updateFunc)
}
@@ -359,7 +359,7 @@ class BasicOperationsSuite extends TestSuiteBase {
// updateFunc clears a state when a StateObject is seen without new values twice in a row
val updateFunc = (values: Seq[Int], state: Option[StateObject]) => {
val stateObj = state.getOrElse(new StateObject)
- values.foldLeft(0)(_ + _) match {
+ values.sum match {
case 0 => stateObj.expireCounter += 1 // no new values
case n => { // has new values, increment and reset expireCounter
stateObj.counter += n
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 717da8e004..9cc27ef7f0 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -17,19 +17,22 @@
package org.apache.spark.streaming
-import org.scalatest.{FunSuite, BeforeAndAfter}
-import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.dstream.{DStream, NetworkReceiver}
+import org.apache.spark.util.{MetadataCleaner, Utils}
+import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.concurrent.Timeouts
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.SpanSugar._
-import org.apache.spark.{SparkException, SparkConf, SparkContext}
-import org.apache.spark.util.{Utils, MetadataCleaner}
-import org.apache.spark.streaming.dstream.DStream
-class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
+class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging {
val master = "local[2]"
val appName = this.getClass.getSimpleName
- val batchDuration = Seconds(1)
+ val batchDuration = Milliseconds(500)
val sparkHome = "someDir"
val envPair = "key" -> "value"
val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100
@@ -108,19 +111,31 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
myConf.set("spark.cleaner.ttl", ttl.toString)
val ssc1 = new StreamingContext(myConf, batchDuration)
+ addInputStream(ssc1).register
+ ssc1.start()
val cp = new Checkpoint(ssc1, Time(1000))
assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl)
ssc1.stop()
val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp))
assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl)
- ssc = new StreamingContext(null, cp, null)
+ ssc = new StreamingContext(null, newCp, null)
assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl)
}
- test("start multiple times") {
+ test("start and stop state check") {
ssc = new StreamingContext(master, appName, batchDuration)
addInputStream(ssc).register
+ assert(ssc.state === ssc.StreamingContextState.Initialized)
+ ssc.start()
+ assert(ssc.state === ssc.StreamingContextState.Started)
+ ssc.stop()
+ assert(ssc.state === ssc.StreamingContextState.Stopped)
+ }
+
+ test("start multiple times") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
ssc.start()
intercept[SparkException] {
ssc.start()
@@ -133,18 +148,61 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
ssc.start()
ssc.stop()
ssc.stop()
- ssc = null
}
+ test("stop before start and start after stop") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
+ ssc.stop() // stop before start should not throw exception
+ ssc.start()
+ ssc.stop()
+ intercept[SparkException] {
+ ssc.start() // start after stop should throw exception
+ }
+ }
+
+
test("stop only streaming context") {
ssc = new StreamingContext(master, appName, batchDuration)
sc = ssc.sparkContext
addInputStream(ssc).register
ssc.start()
ssc.stop(false)
- ssc = null
assert(sc.makeRDD(1 to 100).collect().size === 100)
ssc = new StreamingContext(sc, batchDuration)
+ addInputStream(ssc).register
+ ssc.start()
+ ssc.stop()
+ }
+
+ test("stop gracefully") {
+ val conf = new SparkConf().setMaster(master).setAppName(appName)
+ conf.set("spark.cleaner.ttl", "3600")
+ sc = new SparkContext(conf)
+ for (i <- 1 to 4) {
+ logInfo("==================================")
+ ssc = new StreamingContext(sc, batchDuration)
+ var runningCount = 0
+ TestReceiver.counter.set(1)
+ val input = ssc.networkStream(new TestReceiver)
+ input.count.foreachRDD(rdd => {
+ val count = rdd.first()
+ logInfo("Count = " + count)
+ runningCount += count.toInt
+ })
+ ssc.start()
+ ssc.awaitTermination(500)
+ ssc.stop(stopSparkContext = false, stopGracefully = true)
+ logInfo("Running count = " + runningCount)
+ logInfo("TestReceiver.counter = " + TestReceiver.counter.get())
+ assert(runningCount > 0)
+ assert(
+ (TestReceiver.counter.get() == runningCount + 1) ||
+ (TestReceiver.counter.get() == runningCount + 2),
+ "Received records = " + TestReceiver.counter.get() + ", " +
+ "processed records = " + runningCount
+ )
+ }
}
test("awaitTermination") {
@@ -199,7 +257,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
test("awaitTermination with error in job generation") {
ssc = new StreamingContext(master, appName, batchDuration)
val inputStream = addInputStream(ssc)
-
inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register
val exception = intercept[TestException] {
ssc.start()
@@ -215,4 +272,29 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
}
}
-class TestException(msg: String) extends Exception(msg) \ No newline at end of file
+class TestException(msg: String) extends Exception(msg)
+
+/** Custom receiver for testing whether all data received by a receiver gets processed or not */
+class TestReceiver extends NetworkReceiver[Int] {
+ protected lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
+ protected def onStart() {
+ blockGenerator.start()
+ logInfo("BlockGenerator started on thread " + receivingThread)
+ try {
+ while(true) {
+ blockGenerator += TestReceiver.counter.getAndIncrement
+ Thread.sleep(0)
+ }
+ } finally {
+ logInfo("Receiving stopped at count value of " + TestReceiver.counter.get())
+ }
+ }
+
+ protected def onStop() {
+ blockGenerator.stop()
+ }
+}
+
+object TestReceiver {
+ val counter = new AtomicInteger(1)
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 201630672a..aa2d5c2fc2 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -277,7 +277,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
- Thread.sleep(500) // Give some time for the forgetting old RDDs to complete
+ Thread.sleep(100) // Give some time for the forgetting old RDDs to complete
} catch {
case e: Exception => {e.printStackTrace(); throw e}
} finally {