-rw-r--r--  project/MimaBuild.scala | 24
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 14
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala | 48
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala | 12
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala | 151
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala | 1
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala | 2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala | 124
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala | 56
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala | 154
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala | 5
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala | 62
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala | 4
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala | 108
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala | 2
15 files changed, 552 insertions(+), 215 deletions(-)
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index e7c9c47c96..5ea4817bfd 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -58,17 +58,19 @@ object MimaBuild {
SparkBuild.SPARK_VERSION match {
case v if v.startsWith("1.0") =>
Seq(
- excludePackage("org.apache.spark.api.java"),
- excludePackage("org.apache.spark.streaming.api.java"),
- excludePackage("org.apache.spark.mllib")
- ) ++
- excludeSparkClass("rdd.ClassTags") ++
- excludeSparkClass("util.XORShiftRandom") ++
- excludeSparkClass("mllib.recommendation.MFDataGenerator") ++
- excludeSparkClass("mllib.optimization.SquaredGradient") ++
- excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++
- excludeSparkClass("mllib.regression.LassoWithSGD") ++
- excludeSparkClass("mllib.regression.LinearRegressionWithSGD")
+ excludePackage("org.apache.spark.api.java"),
+ excludePackage("org.apache.spark.streaming.api.java"),
+ excludePackage("org.apache.spark.mllib")
+ ) ++
+ excludeSparkClass("rdd.ClassTags") ++
+ excludeSparkClass("util.XORShiftRandom") ++
+ excludeSparkClass("mllib.recommendation.MFDataGenerator") ++
+ excludeSparkClass("mllib.optimization.SquaredGradient") ++
+ excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++
+ excludeSparkClass("mllib.regression.LassoWithSGD") ++
+ excludeSparkClass("mllib.regression.LinearRegressionWithSGD") ++
+ excludeSparkClass("streaming.dstream.NetworkReceiver") ++
+ excludeSparkClass("streaming.dstream.NetworkReceiver#NetworkReceiverActor")
case _ => Seq()
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index baf80fe2a9..93023e8dce 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -194,19 +194,19 @@ class CheckpointWriter(
}
}
- def stop() {
- synchronized {
- if (stopped) {
- return
- }
- stopped = true
- }
+ def stop(): Unit = synchronized {
+ if (stopped) return
+
executor.shutdown()
val startTime = System.currentTimeMillis()
val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS)
+ if (!terminated) {
+ executor.shutdownNow()
+ }
val endTime = System.currentTimeMillis()
logInfo("CheckpointWriter executor terminated ? " + terminated +
", waited for " + (endTime - startTime) + " ms.")
+ stopped = true
}
private def fs = synchronized {
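The rewritten stop() above is the classic two-phase ExecutorService shutdown: shutdown() to stop accepting work, a bounded awaitTermination(), and shutdownNow() as a fallback. A minimal, self-contained sketch of the same idiom (names are illustrative, not Spark code):

    import java.util.concurrent.{Executors, TimeUnit}

    object ExecutorShutdownSketch {
      def main(args: Array[String]) {
        val executor = Executors.newFixedThreadPool(1)
        executor.execute(new Runnable {
          def run() { Thread.sleep(2000) }   // stand-in for a pending checkpoint write
        })

        // Phase 1: stop accepting new work, give queued work a bounded grace period
        executor.shutdown()
        val terminated = executor.awaitTermination(10, TimeUnit.SECONDS)

        // Phase 2: if the grace period expired, interrupt whatever is left
        if (!terminated) {
          executor.shutdownNow()
        }
        println("terminated cleanly: " + terminated)
      }
    }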
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index e198c69470..a4e236c65f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -158,6 +158,15 @@ class StreamingContext private[streaming] (
private[streaming] val waiter = new ContextWaiter
+ /** Enumeration to identify current state of the StreamingContext */
+ private[streaming] object StreamingContextState extends Enumeration {
+ type StreamingContextState = Value
+ val Initialized, Started, Stopped = Value
+ }
+
+ import StreamingContextState._
+ private[streaming] var state = Initialized
+
/**
* Return the associated Spark context
*/
@@ -405,9 +414,18 @@ class StreamingContext private[streaming] (
/**
* Start the execution of the streams.
*/
- def start() = synchronized {
+ def start(): Unit = synchronized {
+ // Throw exception if the context has already been started once
+ // or if a stopped context is being started again
+ if (state == Started) {
+ throw new SparkException("StreamingContext has already been started")
+ }
+ if (state == Stopped) {
+ throw new SparkException("StreamingContext has already been stopped")
+ }
validate()
scheduler.start()
+ state = Started
}
/**
@@ -428,14 +446,38 @@ class StreamingContext private[streaming] (
}
/**
- * Stop the execution of the streams.
+ * Stop the execution of the streams immediately (does not wait for all received data
+ * to be processed).
* @param stopSparkContext Stop the associated SparkContext or not
+ *
*/
def stop(stopSparkContext: Boolean = true): Unit = synchronized {
- scheduler.stop()
+ stop(stopSparkContext, false)
+ }
+
+ /**
+ * Stop the execution of the streams, with option of ensuring all received data
+ * has been processed.
+ * @param stopSparkContext Stop the associated SparkContext or not
+ * @param stopGracefully Stop gracefully by waiting for the processing of all
+ * received data to be completed
+ */
+ def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized {
+ // Warn (but do not fail) if the context is stopped twice,
+ // or stopped before it has been started
+ if (state == Initialized) {
+ logWarning("StreamingContext has not been started yet")
+ return
+ }
+ if (state == Stopped) {
+ logWarning("StreamingContext has already been stopped")
+ return
+ } // no need to throw an exception as it's okay to stop twice
+ scheduler.stop(stopGracefully)
logInfo("StreamingContext stopped successfully")
waiter.notifyStop()
if (stopSparkContext) sc.stop()
+ state = Stopped
}
}
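With the Initialized/Started/Stopped state machine above, start() and stop() become safe against misuse: double start and start-after-stop throw, while double stop only warns. A hedged usage sketch of the new lifecycle and the graceful-stop overload (assumes a local master and a socket source on port 9999):

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object LifecycleSketch {
      def main(args: Array[String]) {
        val ssc = new StreamingContext("local[2]", "LifecycleSketch", Seconds(1))
        ssc.socketTextStream("localhost", 9999).print()   // assumed source on port 9999

        ssc.start()                    // Initialized -> Started
        // ssc.start()                 // would throw SparkException: already started
        Thread.sleep(3000)             // let a few batches run

        // Graceful stop: wait until all received data has been processed,
        // but keep the underlying SparkContext alive for reuse.
        ssc.stop(stopSparkContext = false, stopGracefully = true)

        // ssc.stop()                  // a second stop only logs a warning
        // ssc.start()                 // would throw SparkException: already stopped
        ssc.sparkContext.stop()
      }
    }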
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
index b705d2ec9a..c800602d09 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
@@ -509,8 +509,16 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* Stop the execution of the streams.
* @param stopSparkContext Stop the associated SparkContext or not
*/
- def stop(stopSparkContext: Boolean): Unit = {
- ssc.stop(stopSparkContext)
+ def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext)
+
+ /**
+ * Stop the execution of the streams.
+ * @param stopSparkContext Stop the associated SparkContext or not
+ * @param stopGracefully Stop gracefully by waiting for the processing of all
+ * received data to be completed
+ */
+ def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = {
+ ssc.stop(stopSparkContext, stopGracefully)
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
index 72ad0bae75..d19a635fe8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala
@@ -17,7 +17,7 @@
package org.apache.spark.streaming.dstream
-import java.util.concurrent.ArrayBlockingQueue
+import java.util.concurrent.{TimeUnit, ArrayBlockingQueue}
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
@@ -34,6 +34,7 @@ import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.rdd.{RDD, BlockRDD}
import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver}
+import org.apache.spark.util.AkkaUtils
/**
* Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]]
@@ -69,7 +70,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
// then this returns an empty RDD. This may happen when recovering from a
// master failure
if (validTime >= graph.startTime) {
- val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime)
+ val blockIds = ssc.scheduler.networkInputTracker.getBlocks(id, validTime)
Some(new BlockRDD[T](ssc.sc, blockIds))
} else {
Some(new BlockRDD[T](ssc.sc, Array[BlockId]()))
@@ -79,7 +80,7 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte
private[streaming] sealed trait NetworkReceiverMessage
-private[streaming] case class StopReceiver(msg: String) extends NetworkReceiverMessage
+private[streaming] case class StopReceiver() extends NetworkReceiverMessage
private[streaming] case class ReportBlock(blockId: BlockId, metadata: Any)
extends NetworkReceiverMessage
private[streaming] case class ReportError(msg: String) extends NetworkReceiverMessage
@@ -90,13 +91,31 @@ private[streaming] case class ReportError(msg: String) extends NetworkReceiverMe
*/
abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging {
+ /** Local SparkEnv */
lazy protected val env = SparkEnv.get
+ /** Remote Akka actor for the NetworkInputTracker */
+ lazy protected val trackerActor = {
+ val ip = env.conf.get("spark.driver.host", "localhost")
+ val port = env.conf.getInt("spark.driver.port", 7077)
+ val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
+ env.actorSystem.actorSelection(url)
+ }
+
+ /** Akka actor for receiving messages from the NetworkInputTracker in the driver */
lazy protected val actor = env.actorSystem.actorOf(
Props(new NetworkReceiverActor()), "NetworkReceiver-" + streamId)
+ /** Timeout for Akka actor messages */
+ lazy protected val askTimeout = AkkaUtils.askTimeout(env.conf)
+
+ /** Thread that starts the receiver and stays blocked while data is being received */
lazy protected val receivingThread = Thread.currentThread()
+ /** Exceptions that occur while receiving data */
+ protected lazy val exceptions = new ArrayBuffer[Exception]
+
+ /** Identifier of the stream this receiver is associated with */
protected var streamId: Int = -1
/**
@@ -112,7 +131,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
def getLocationPreference() : Option[String] = None
/**
- * Starts the receiver. First is accesses all the lazy members to
+ * Start the receiver. First it accesses all the lazy members to
* materialize them. Then it calls the user-defined onStart() method to start
* other threads, etc. required to receive the data.
*/
@@ -124,83 +143,107 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
receivingThread
// Call user-defined onStart()
+ logInfo("Starting receiver")
onStart()
+
+ // Wait until interrupt is called on this thread
+ while(true) Thread.sleep(100000)
} catch {
case ie: InterruptedException =>
- logInfo("Receiving thread interrupted")
+ logInfo("Receiving thread has been interrupted, receiver " + streamId + " stopped")
case e: Exception =>
- stopOnError(e)
+ logError("Error receiving data in receiver " + streamId, e)
+ exceptions += e
+ }
+
+ // Call user-defined onStop()
+ logInfo("Stopping receiver")
+ try {
+ onStop()
+ } catch {
+ case e: Exception =>
+ logError("Error stopping receiver " + streamId, e)
+ exceptions += e
+ }
+
+ val message = if (exceptions.isEmpty) {
+ null
+ } else if (exceptions.size == 1) {
+ val e = exceptions.head
+ "Exception in receiver " + streamId + ": " + e.getMessage + "\n" + e.getStackTraceString
+ } else {
+ "Multiple exceptions in receiver " + streamId + "(" + exceptions.size + "):\n"
+ exceptions.zipWithIndex.map {
+ case (e, i) => "Exception " + i + ": " + e.getMessage + "\n" + e.getStackTraceString
+ }.mkString("\n")
}
+ logInfo("Deregistering receiver " + streamId)
+ val future = trackerActor.ask(DeregisterReceiver(streamId, message))(askTimeout)
+ Await.result(future, askTimeout)
+ logInfo("Deregistered receiver " + streamId)
+ env.actorSystem.stop(actor)
+ logInfo("Stopped receiver " + streamId)
}
/**
- * Stops the receiver. First it interrupts the main receiving thread,
- * that is, the thread that called receiver.start(). Then it calls the user-defined
- * onStop() method to stop other threads and/or do cleanup.
+ * Stop the receiver. First it interrupts the main receiving thread,
+ * that is, the thread that called receiver.start().
*/
def stop() {
+ // Stop receiving by interrupting the receiving thread
receivingThread.interrupt()
- onStop()
- // TODO: terminate the actor
+ logInfo("Interrupted receiving thread " + receivingThread + " for stopping")
}
/**
- * Stops the receiver and reports exception to the tracker.
+ * Stop the receiver and report the exception to the tracker.
* This should be called whenever an exception is to be handled on any thread
* of the receiver.
*/
protected def stopOnError(e: Exception) {
logError("Error receiving data", e)
+ exceptions += e
stop()
- actor ! ReportError(e.toString)
}
-
/**
- * Pushes a block (as an ArrayBuffer filled with data) into the block manager.
+ * Push a block (as an ArrayBuffer filled with data) into the block manager.
*/
def pushBlock(blockId: BlockId, arrayBuffer: ArrayBuffer[T], metadata: Any, level: StorageLevel) {
env.blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], level)
- actor ! ReportBlock(blockId, metadata)
+ trackerActor ! AddBlocks(streamId, Array(blockId), metadata)
+ logDebug("Pushed block " + blockId)
}
/**
- * Pushes a block (as bytes) into the block manager.
+ * Push a block (as bytes) into the block manager.
*/
def pushBlock(blockId: BlockId, bytes: ByteBuffer, metadata: Any, level: StorageLevel) {
env.blockManager.putBytes(blockId, bytes, level)
- actor ! ReportBlock(blockId, metadata)
+ trackerActor ! AddBlocks(streamId, Array(blockId), metadata)
+ }
+
+ /** Set the ID of the DStream that this receiver is associated with */
+ protected[streaming] def setStreamId(id: Int) {
+ streamId = id
}
/** A helper actor that communicates with the NetworkInputTracker */
private class NetworkReceiverActor extends Actor {
- logInfo("Attempting to register with tracker")
- val ip = env.conf.get("spark.driver.host", "localhost")
- val port = env.conf.getInt("spark.driver.port", 7077)
- val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port)
- val tracker = env.actorSystem.actorSelection(url)
- val timeout = 5.seconds
override def preStart() {
- val future = tracker.ask(RegisterReceiver(streamId, self))(timeout)
- Await.result(future, timeout)
+ logInfo("Registered receiver " + streamId)
+ val future = trackerActor.ask(RegisterReceiver(streamId, self))(askTimeout)
+ Await.result(future, askTimeout)
}
override def receive() = {
- case ReportBlock(blockId, metadata) =>
- tracker ! AddBlocks(streamId, Array(blockId), metadata)
- case ReportError(msg) =>
- tracker ! DeregisterReceiver(streamId, msg)
- case StopReceiver(msg) =>
+ case StopReceiver =>
+ logInfo("Received stop signal")
stop()
- tracker ! DeregisterReceiver(streamId, msg)
}
}
- protected[streaming] def setStreamId(id: Int) {
- streamId = id
- }
-
/**
* Batches objects created by a [[org.apache.spark.streaming.dstream.NetworkReceiver]] and puts
* them into appropriately named blocks at regular intervals. This class starts two threads,
@@ -214,23 +257,26 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
val clock = new SystemClock()
val blockInterval = env.conf.getLong("spark.streaming.blockInterval", 200)
- val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
+ val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer,
+ "BlockGenerator")
val blockStorageLevel = storageLevel
val blocksForPushing = new ArrayBlockingQueue[Block](1000)
val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } }
var currentBuffer = new ArrayBuffer[T]
+ var stopped = false
def start() {
blockIntervalTimer.start()
blockPushingThread.start()
- logInfo("Data handler started")
+ logInfo("Started BlockGenerator")
}
def stop() {
- blockIntervalTimer.stop()
- blockPushingThread.interrupt()
- logInfo("Data handler stopped")
+ blockIntervalTimer.stop(false)
+ stopped = true
+ blockPushingThread.join()
+ logInfo("Stopped BlockGenerator")
}
def += (obj: T): Unit = synchronized {
@@ -248,24 +294,35 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging
}
} catch {
case ie: InterruptedException =>
- logInfo("Block interval timer thread interrupted")
+ logInfo("Block updating timer thread was interrupted")
case e: Exception =>
- NetworkReceiver.this.stop()
+ NetworkReceiver.this.stopOnError(e)
}
}
private def keepPushingBlocks() {
- logInfo("Block pushing thread started")
+ logInfo("Started block pushing thread")
try {
- while(true) {
+ while(!stopped) {
+ Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match {
+ case Some(block) =>
+ NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel)
+ case None =>
+ }
+ }
+ // Push out the blocks that are still left
+ logInfo("Pushing out the last " + blocksForPushing.size() + " blocks")
+ while (!blocksForPushing.isEmpty) {
val block = blocksForPushing.take()
NetworkReceiver.this.pushBlock(block.id, block.buffer, block.metadata, storageLevel)
+ logInfo("Blocks left to push " + blocksForPushing.size())
}
+ logInfo("Stopped blocks pushing thread")
} catch {
case ie: InterruptedException =>
- logInfo("Block pushing thread interrupted")
+ logInfo("Block pushing thread was interrupted")
case e: Exception =>
- NetworkReceiver.this.stop()
+ NetworkReceiver.this.stopOnError(e)
}
}
}
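The new keepPushingBlocks() swaps a blocking take() for a poll-with-timeout so the thread can observe the stopped flag, then drains whatever is still queued before exiting. A standalone sketch of that stop-and-drain idiom (String stands in for the block type; names are illustrative):

    import java.util.concurrent.{ArrayBlockingQueue, TimeUnit}

    object DrainSketch {
      @volatile var stopped = false
      val queue = new ArrayBlockingQueue[String](1000)

      def keepPushing(push: String => Unit) {
        // Poll with a timeout so the stopped flag is re-checked periodically
        while (!stopped) {
          Option(queue.poll(100, TimeUnit.MILLISECONDS)).foreach(push)
        }
        // Once a stop is requested, push out whatever is still queued
        while (!queue.isEmpty) {
          push(queue.take())
        }
      }

      def main(args: Array[String]) {
        val pusher = new Thread() { override def run() { keepPushing(println) } }
        pusher.start()
        (1 to 5).foreach(i => queue.put("block-" + i))
        stopped = true
        pusher.join()   // returns only after the remaining blocks have been pushed
      }
    }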
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
index 2cdd13f205..63d94d1cc6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala
@@ -67,7 +67,6 @@ class SocketReceiver[T: ClassTag](
protected def onStop() {
blockGenerator.stop()
}
-
}
private[streaming]
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
index bd78bae8a5..44eb2750c6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala
@@ -174,10 +174,10 @@ private[streaming] class ActorReceiver[T: ClassTag](
blocksGenerator.start()
supervisor
logInfo("Supervision tree for receivers initialized at:" + supervisor.path)
+
}
protected def onStop() = {
supervisor ! PoisonPill
}
-
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index c7306248b1..92d885c4bc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -39,16 +39,22 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
private val ssc = jobScheduler.ssc
private val graph = ssc.graph
+
val clock = {
val clockClass = ssc.sc.conf.get(
"spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
Class.forName(clockClass).newInstance().asInstanceOf[Clock]
}
+
private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
- longTime => eventActor ! GenerateJobs(new Time(longTime)))
- private lazy val checkpointWriter =
- if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
- new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
+ longTime => eventActor ! GenerateJobs(new Time(longTime)), "JobGenerator")
+
+ // This is marked lazy so that this is initialized after checkpoint duration has been set
+ // in the context and the generator has been started.
+ private lazy val shouldCheckpoint = ssc.checkpointDuration != null && ssc.checkpointDir != null
+
+ private lazy val checkpointWriter = if (shouldCheckpoint) {
+ new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration)
} else {
null
}
@@ -57,17 +63,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
// This not being null means the scheduler has been started and not stopped
private var eventActor: ActorRef = null
+ // Last batch whose completion, checkpointing and metadata cleanup have been completed
+ private var lastProcessedBatch: Time = null
+
/** Start generation of jobs */
- def start() = synchronized {
- if (eventActor != null) {
- throw new SparkException("JobGenerator already started")
- }
+ def start(): Unit = synchronized {
+ if (eventActor != null) return // generator has already been started
eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
def receive = {
- case event: JobGeneratorEvent =>
- logDebug("Got event of type " + event.getClass.getName)
- processEvent(event)
+ case event: JobGeneratorEvent => processEvent(event)
}
}), "JobGenerator")
if (ssc.isCheckpointPresent) {
@@ -77,30 +82,79 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
}
}
- /** Stop generation of jobs */
- def stop() = synchronized {
- if (eventActor != null) {
- timer.stop()
- ssc.env.actorSystem.stop(eventActor)
- if (checkpointWriter != null) checkpointWriter.stop()
- ssc.graph.stop()
- logInfo("JobGenerator stopped")
+ /**
+ * Stop generation of jobs. processReceivedData = true makes this wait until jobs
+ * of the current ongoing time interval have been generated and processed, and the
+ * corresponding checkpoints have been written.
+ */
+ def stop(processReceivedData: Boolean): Unit = synchronized {
+ if (eventActor == null) return // generator has already been stopped
+
+ if (processReceivedData) {
+ logInfo("Stopping JobGenerator gracefully")
+ val timeWhenStopStarted = System.currentTimeMillis()
+ val stopTimeout = 10 * ssc.graph.batchDuration.milliseconds
+ val pollTime = 100
+
+ // To prevent the graceful stop from getting stuck permanently
+ def hasTimedOut = {
+ val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout
+ if (timedOut) logWarning("Timed out while stopping the job generator")
+ timedOut
+ }
+
+ // Wait until all the received blocks in the network input tracker have
+ // been consumed by network input DStreams, and jobs have been generated with them
+ logInfo("Waiting for all received blocks to be consumed for job generation")
+ while(!hasTimedOut && jobScheduler.networkInputTracker.hasMoreReceivedBlockIds) {
+ Thread.sleep(pollTime)
+ }
+ logInfo("Waited for all received blocks to be consumed for job generation")
+
+ // Stop generating jobs
+ val stopTime = timer.stop(false)
+ graph.stop()
+ logInfo("Stopped generation timer")
+
+ // Wait for the jobs to complete and checkpoints to be written
+ def haveAllBatchesBeenProcessed = {
+ lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime
+ }
+ logInfo("Waiting for jobs to be processed and checkpoints to be written")
+ while (!hasTimedOut && !haveAllBatchesBeenProcessed) {
+ Thread.sleep(pollTime)
+ }
+ logInfo("Waited for jobs to be processed and checkpoints to be written")
+ } else {
+ logInfo("Stopping JobGenerator immediately")
+ // Stop timer and graph immediately, ignore unprocessed data and pending jobs
+ timer.stop(true)
+ graph.stop()
}
+
+ // Stop the actor and checkpoint writer
+ if (shouldCheckpoint) checkpointWriter.stop()
+ ssc.env.actorSystem.stop(eventActor)
+ logInfo("Stopped JobGenerator")
}
/**
- * On batch completion, clear old metadata and checkpoint computation.
+ * Callback called when a batch has been completely processed.
*/
def onBatchCompletion(time: Time) {
eventActor ! ClearMetadata(time)
}
-
+
+ /**
+ * Callback called when the checkpoint of a batch has been written.
+ */
def onCheckpointCompletion(time: Time) {
eventActor ! ClearCheckpointData(time)
}
/** Processes all events */
private def processEvent(event: JobGeneratorEvent) {
+ logDebug("Got event " + event)
event match {
case GenerateJobs(time) => generateJobs(time)
case ClearMetadata(time) => clearMetadata(time)
@@ -114,7 +168,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val startTime = new Time(timer.getStartTime())
graph.start(startTime - graph.batchDuration)
timer.start(startTime.milliseconds)
- logInfo("JobGenerator started at " + startTime)
+ logInfo("Started JobGenerator at " + startTime)
}
/** Restarts the generator based on the information in checkpoint */
@@ -152,15 +206,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
// Restart the timer
timer.start(restartTime.milliseconds)
- logInfo("JobGenerator restarted at " + restartTime)
+ logInfo("Restarted JobGenerator at " + restartTime)
}
/** Generate jobs and perform checkpoint for the given `time`. */
private def generateJobs(time: Time) {
SparkEnv.set(ssc.env)
Try(graph.generateJobs(time)) match {
- case Success(jobs) => jobScheduler.runJobs(time, jobs)
- case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e)
+ case Success(jobs) =>
+ jobScheduler.runJobs(time, jobs)
+ case Failure(e) =>
+ jobScheduler.reportError("Error generating jobs for time " + time, e)
}
eventActor ! DoCheckpoint(time)
}
@@ -168,20 +224,32 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Clear DStream metadata for the given `time`. */
private def clearMetadata(time: Time) {
ssc.graph.clearMetadata(time)
- eventActor ! DoCheckpoint(time)
+
+ // If checkpointing is enabled, then checkpoint,
+ // else mark batch to be fully processed
+ if (shouldCheckpoint) {
+ eventActor ! DoCheckpoint(time)
+ } else {
+ markBatchFullyProcessed(time)
+ }
}
/** Clear DStream checkpoint data for the given `time`. */
private def clearCheckpointData(time: Time) {
ssc.graph.clearCheckpointData(time)
+ markBatchFullyProcessed(time)
}
/** Perform checkpoint for the given `time`. */
- private def doCheckpoint(time: Time) = synchronized {
- if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
+ private def doCheckpoint(time: Time) {
+ if (shouldCheckpoint && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
logInfo("Checkpointing graph for time " + time)
ssc.graph.updateCheckpointData(time)
checkpointWriter.write(new Checkpoint(ssc, time))
}
}
+
+ private def markBatchFullyProcessed(time: Time) {
+ lastProcessedBatch = time
+ }
}
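Graceful stop in JobGenerator is two bounded waits: first until the tracker holds no unconsumed blocks, then until lastProcessedBatch catches up with the timer's stop time, each guarded by hasTimedOut. A small sketch of that poll-until-condition-or-timeout pattern (helper names are illustrative, not Spark API):

    object GracefulWaitSketch {
      /** Poll `condition` every `pollTimeMs` ms until it holds or `timeoutMs` elapses.
        * Returns true if the condition was met, false if the wait timed out. */
      def waitFor(condition: => Boolean, timeoutMs: Long, pollTimeMs: Long = 100): Boolean = {
        val start = System.currentTimeMillis()
        while (System.currentTimeMillis() - start <= timeoutMs) {
          if (condition) return true
          Thread.sleep(pollTimeMs)
        }
        false
      }

      def main(args: Array[String]) {
        @volatile var blocksConsumed = false
        new Thread() { override def run() { Thread.sleep(300); blocksConsumed = true } }.start()

        // Same shape as JobGenerator.stop(true): a bounded wait, with a warning on timeout
        if (waitFor(blocksConsumed, timeoutMs = 2000)) {
          println("All received blocks consumed, safe to stop the timer")
        } else {
          println("Timed out while stopping the job generator")
        }
      }
    }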
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index de675d3c7f..04e0a6a283 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -39,7 +39,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
private val jobSets = new ConcurrentHashMap[Time, JobSet]
private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1)
- private val executor = Executors.newFixedThreadPool(numConcurrentJobs)
+ private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs)
private val jobGenerator = new JobGenerator(this)
val clock = jobGenerator.clock
val listenerBus = new StreamingListenerBus()
@@ -50,36 +50,54 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
private var eventActor: ActorRef = null
- def start() = synchronized {
- if (eventActor != null) {
- throw new SparkException("JobScheduler already started")
- }
+ def start(): Unit = synchronized {
+ if (eventActor != null) return // scheduler has already been started
+ logDebug("Starting JobScheduler")
eventActor = ssc.env.actorSystem.actorOf(Props(new Actor {
def receive = {
case event: JobSchedulerEvent => processEvent(event)
}
}), "JobScheduler")
+
listenerBus.start()
networkInputTracker = new NetworkInputTracker(ssc)
networkInputTracker.start()
- Thread.sleep(1000)
jobGenerator.start()
- logInfo("JobScheduler started")
+ logInfo("Started JobScheduler")
}
- def stop() = synchronized {
- if (eventActor != null) {
- jobGenerator.stop()
- networkInputTracker.stop()
- executor.shutdown()
- if (!executor.awaitTermination(2, TimeUnit.SECONDS)) {
- executor.shutdownNow()
- }
- listenerBus.stop()
- ssc.env.actorSystem.stop(eventActor)
- logInfo("JobScheduler stopped")
+ def stop(processAllReceivedData: Boolean): Unit = synchronized {
+ if (eventActor == null) return // scheduler has already been stopped
+ logDebug("Stopping JobScheduler")
+
+ // First, stop receiving
+ networkInputTracker.stop()
+
+ // Second, stop generating jobs. If it has to process all received data,
+ // then this will wait for all the processing through the JobScheduler to finish.
+ jobGenerator.stop(processAllReceivedData)
+
+ // Stop the executor for receiving new jobs
+ logDebug("Stopping job executor")
+ jobExecutor.shutdown()
+
+ // Wait for the queued jobs to complete if indicated
+ val terminated = if (processAllReceivedData) {
+ jobExecutor.awaitTermination(1, TimeUnit.HOURS) // just a very large period of time
+ } else {
+ jobExecutor.awaitTermination(2, TimeUnit.SECONDS)
}
+ if (!terminated) {
+ jobExecutor.shutdownNow()
+ }
+ logDebug("Stopped job executor")
+
+ // Stop everything else
+ listenerBus.stop()
+ ssc.env.actorSystem.stop(eventActor)
+ eventActor = null
+ logInfo("Stopped JobScheduler")
}
def runJobs(time: Time, jobs: Seq[Job]) {
@@ -88,7 +106,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
} else {
val jobSet = new JobSet(time, jobs)
jobSets.put(time, jobSet)
- jobSet.jobs.foreach(job => executor.execute(new JobHandler(job)))
+ jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job)))
logInfo("Added jobs for time " + time)
}
}
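The ordering in JobScheduler.stop() is deliberate: receivers first so no new data arrives, the generator second so no new jobs are created, and the job executor last, with a wait that is effectively unbounded when draining gracefully. A compact sketch of that ordering with stand-in components (hypothetical names, not the actual Spark classes):

    import java.util.concurrent.{ExecutorService, Executors, TimeUnit}

    object StopOrderSketch {
      def stop(stopReceivers: () => Unit,
               stopGenerator: Boolean => Unit,
               jobExecutor: ExecutorService,
               processAllReceivedData: Boolean) {
        stopReceivers()                         // 1. no new data enters the system
        stopGenerator(processAllReceivedData)   // 2. no new jobs (drains first if graceful)
        jobExecutor.shutdown()                  // 3. no new jobs accepted for execution

        // 4. graceful stop waits "forever" (one hour); immediate stop waits briefly
        val terminated =
          if (processAllReceivedData) jobExecutor.awaitTermination(1, TimeUnit.HOURS)
          else jobExecutor.awaitTermination(2, TimeUnit.SECONDS)
        if (!terminated) jobExecutor.shutdownNow()
      }

      def main(args: Array[String]) {
        stop(() => println("receivers stopped"),
             graceful => println("generator stopped, graceful = " + graceful),
             Executors.newFixedThreadPool(1),
             processAllReceivedData = true)
      }
    }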
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
index cad68e248a..067e804202 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala
@@ -17,20 +17,14 @@
package org.apache.spark.streaming.scheduler
-import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
-import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
-import org.apache.spark.{SparkException, Logging, SparkEnv}
-import org.apache.spark.SparkContext._
-
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.Queue
-import scala.concurrent.duration._
+import scala.collection.mutable.{HashMap, Queue, SynchronizedMap}
import akka.actor._
-import akka.pattern.ask
-import akka.dispatch._
+import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.spark.SparkContext._
import org.apache.spark.storage.BlockId
-import org.apache.spark.streaming.{Time, StreamingContext}
+import org.apache.spark.streaming.{StreamingContext, Time}
+import org.apache.spark.streaming.dstream.{NetworkReceiver, StopReceiver}
import org.apache.spark.util.AkkaUtils
private[streaming] sealed trait NetworkInputTrackerMessage
@@ -52,8 +46,8 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
val networkInputStreams = ssc.graph.getNetworkInputStreams()
val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*)
val receiverExecutor = new ReceiverExecutor()
- val receiverInfo = new HashMap[Int, ActorRef]
- val receivedBlockIds = new HashMap[Int, Queue[BlockId]]
+ val receiverInfo = new HashMap[Int, ActorRef] with SynchronizedMap[Int, ActorRef]
+ val receivedBlockIds = new HashMap[Int, Queue[BlockId]] with SynchronizedMap[Int, Queue[BlockId]]
val timeout = AkkaUtils.askTimeout(ssc.conf)
@@ -63,7 +57,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
var currentTime: Time = null
/** Start the actor and receiver execution thread. */
- def start() {
+ def start() = synchronized {
if (actor != null) {
throw new SparkException("NetworkInputTracker already started")
}
@@ -77,72 +71,99 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
}
/** Stop the receiver execution thread. */
- def stop() {
+ def stop() = synchronized {
if (!networkInputStreams.isEmpty && actor != null) {
- receiverExecutor.interrupt()
- receiverExecutor.stopReceivers()
+ // First, stop the receivers
+ receiverExecutor.stop()
+
+ // Finally, stop the actor
ssc.env.actorSystem.stop(actor)
+ actor = null
logInfo("NetworkInputTracker stopped")
}
}
- /** Return all the blocks received from a receiver. */
- def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized {
- val queue = receivedBlockIds.synchronized {
- receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]())
+ /** Register a receiver */
+ def registerReceiver(streamId: Int, receiverActor: ActorRef, sender: ActorRef) {
+ if (!networkInputStreamMap.contains(streamId)) {
+ throw new Exception("Register received for unexpected id " + streamId)
}
- val result = queue.synchronized {
- queue.dequeueAll(x => true)
- }
- logInfo("Stream " + receiverId + " received " + result.size + " blocks")
- result.toArray
+ receiverInfo += ((streamId, receiverActor))
+ logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address)
+ }
+
+ /** Deregister a receiver */
+ def deregisterReceiver(streamId: Int, message: String) {
+ receiverInfo -= streamId
+ logError("Deregistered receiver for network stream " + streamId + " with message:\n" + message)
+ }
+
+ /** Get all the received blocks for the given stream. */
+ def getBlocks(streamId: Int, time: Time): Array[BlockId] = {
+ val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[BlockId]())
+ val result = queue.dequeueAll(x => true).toArray
+ logInfo("Stream " + streamId + " received " + result.size + " blocks")
+ result
+ }
+
+ /** Add new blocks for the given stream */
+ def addBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) = {
+ val queue = receivedBlockIds.getOrElseUpdate(streamId, new Queue[BlockId])
+ queue ++= blockIds
+ networkInputStreamMap(streamId).addMetadata(metadata)
+ logDebug("Stream " + streamId + " received new blocks: " + blockIds.mkString("[", ", ", "]"))
+ }
+
+ /** Check if any blocks are left to be processed */
+ def hasMoreReceivedBlockIds: Boolean = {
+ !receivedBlockIds.forall(_._2.isEmpty)
}
/** Actor to receive messages from the receivers. */
private class NetworkInputTrackerActor extends Actor {
def receive = {
- case RegisterReceiver(streamId, receiverActor) => {
- if (!networkInputStreamMap.contains(streamId)) {
- throw new Exception("Register received for unexpected id " + streamId)
- }
- receiverInfo += ((streamId, receiverActor))
- logInfo("Registered receiver for network stream " + streamId + " from "
- + sender.path.address)
+ case RegisterReceiver(streamId, receiverActor) =>
+ registerReceiver(streamId, receiverActor, sender)
+ sender ! true
+ case AddBlocks(streamId, blockIds, metadata) =>
+ addBlocks(streamId, blockIds, metadata)
+ case DeregisterReceiver(streamId, message) =>
+ deregisterReceiver(streamId, message)
sender ! true
- }
- case AddBlocks(streamId, blockIds, metadata) => {
- val tmp = receivedBlockIds.synchronized {
- if (!receivedBlockIds.contains(streamId)) {
- receivedBlockIds += ((streamId, new Queue[BlockId]))
- }
- receivedBlockIds(streamId)
- }
- tmp.synchronized {
- tmp ++= blockIds
- }
- networkInputStreamMap(streamId).addMetadata(metadata)
- }
- case DeregisterReceiver(streamId, msg) => {
- receiverInfo -= streamId
- logError("De-registered receiver for network stream " + streamId
- + " with message " + msg)
- // TODO: Do something about the corresponding NetworkInputDStream
- }
}
}
/** This thread class runs all the receivers on the cluster. */
- class ReceiverExecutor extends Thread {
- val env = ssc.env
-
- override def run() {
- try {
- SparkEnv.set(env)
- startReceivers()
- } catch {
- case ie: InterruptedException => logInfo("ReceiverExecutor interrupted")
- } finally {
- stopReceivers()
+ class ReceiverExecutor {
+ @transient val env = ssc.env
+ @transient val thread = new Thread() {
+ override def run() {
+ try {
+ SparkEnv.set(env)
+ startReceivers()
+ } catch {
+ case ie: InterruptedException => logInfo("ReceiverExecutor interrupted")
+ }
+ }
+ }
+
+ def start() {
+ thread.start()
+ }
+
+ def stop() {
+ // Send the stop signal to all the receivers
+ stopReceivers()
+
+ // Wait for the Spark job that runs the receivers to be over
+ // That is, for the receivers to quit gracefully.
+ thread.join(10000)
+
+ // Check if all the receivers have been deregistered or not
+ if (!receiverInfo.isEmpty) {
+ logWarning("All of the receivers have not deregistered, " + receiverInfo)
+ } else {
+ logInfo("All of the receivers have deregistered successfully")
}
}
@@ -150,7 +171,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
* Get the receivers from the NetworkInputDStreams, distribute them to the
* worker nodes as a parallel collection, and run them.
*/
- def startReceivers() {
+ private def startReceivers() {
val receivers = networkInputStreams.map(nis => {
val rcvr = nis.getReceiver()
rcvr.setStreamId(nis.id)
@@ -186,13 +207,16 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging {
}
// Distribute the receivers and start them
+ logInfo("Starting " + receivers.length + " receivers")
ssc.sparkContext.runJob(tempRDD, startReceiver)
+ logInfo("All of the receivers have been terminated")
}
/** Stops the receivers. */
- def stopReceivers() {
+ private def stopReceivers() {
// Signal the receivers to stop
receiverInfo.values.foreach(_ ! StopReceiver)
+ logInfo("Sent stop signal to all " + receiverInfo.size + " receivers")
}
}
}
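The tracker now gets its thread safety from the SynchronizedMap mixin instead of explicit synchronized blocks, which keeps call sites like getBlocks and addBlocks free of locking code. A self-contained sketch of the same idiom (String stands in for BlockId; as in the real code, only individual map operations are synchronized):

    import scala.collection.mutable.{HashMap, Queue, SynchronizedMap}

    object BlockTrackingSketch {
      // Every individual map operation is synchronized by the mixin
      val receivedBlockIds =
        new HashMap[Int, Queue[String]] with SynchronizedMap[Int, Queue[String]]

      def addBlocks(streamId: Int, blockIds: Seq[String]) {
        receivedBlockIds.getOrElseUpdate(streamId, new Queue[String]) ++= blockIds
      }

      def getBlocks(streamId: Int): Seq[String] =
        receivedBlockIds.getOrElseUpdate(streamId, new Queue[String]).dequeueAll(_ => true)

      def hasMoreReceivedBlockIds: Boolean = !receivedBlockIds.forall(_._2.isEmpty)

      def main(args: Array[String]) {
        addBlocks(0, Seq("block-1", "block-2"))
        println(hasMoreReceivedBlockIds)   // true
        println(getBlocks(0))              // drains both blocks
        println(hasMoreReceivedBlockIds)   // false
      }
    }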
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
index c3a849d276..c5ef2cc8c3 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
@@ -48,14 +48,11 @@ class SystemClock() extends Clock {
minPollTime
}
}
-
-
+
while (true) {
currentTime = System.currentTimeMillis()
waitTime = targetTime - currentTime
-
if (waitTime <= 0) {
-
return currentTime
}
val sleepTime =
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
index 559c247385..f71938ac55 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
@@ -17,44 +17,84 @@
package org.apache.spark.streaming.util
+import org.apache.spark.Logging
+
private[streaming]
-class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) {
+class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String)
+ extends Logging {
- private val thread = new Thread("RecurringTimer") {
+ private val thread = new Thread("RecurringTimer - " + name) {
+ setDaemon(true)
override def run() { loop }
}
-
- private var nextTime = 0L
+ @volatile private var prevTime = -1L
+ @volatile private var nextTime = -1L
+ @volatile private var stopped = false
+
+ /**
+ * Get the time when this timer will fire if it is started right now.
+ * The time will be a multiple of this timer's period and later than the
+ * current system time.
+ */
def getStartTime(): Long = {
(math.floor(clock.currentTime.toDouble / period) + 1).toLong * period
}
+ /**
+ * Get the time when the timer will fire if it is restarted right now.
+ * This time depends on when the timer was first started, irrespective of when
+ * it was stopped. The time will be a multiple of this timer's period and
+ * later than the current time.
+ */
def getRestartTime(originalStartTime: Long): Long = {
val gap = clock.currentTime - originalStartTime
(math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
}
- def start(startTime: Long): Long = {
+ /**
+ * Start at the given start time.
+ */
+ def start(startTime: Long): Long = synchronized {
nextTime = startTime
thread.start()
+ logInfo("Started timer for " + name + " at time " + nextTime)
nextTime
}
+ /**
+ * Start at the earliest time it can start based on the period.
+ */
def start(): Long = {
start(getStartTime())
}
- def stop() {
- thread.interrupt()
+ /**
+ * Stop the timer, and return the time of the last callback invocation.
+ * interruptTimer = true will interrupt a callback that is in progress
+ * (the returned time is not guaranteed to be correct in that case).
+ */
+ def stop(interruptTimer: Boolean): Long = synchronized {
+ if (!stopped) {
+ stopped = true
+ if (interruptTimer) thread.interrupt()
+ thread.join()
+ logInfo("Stopped timer for " + name + " after time " + prevTime)
+ }
+ prevTime
}
-
+
+ /**
+ * Repeatedly call the callback every interval.
+ */
private def loop() {
try {
- while (true) {
+ while (!stopped) {
clock.waitTillTime(nextTime)
callback(nextTime)
+ prevTime = nextTime
nextTime += period
+ logDebug("Callback for " + name + " called at time " + prevTime)
}
} catch {
case e: InterruptedException =>
@@ -74,10 +114,10 @@ object RecurringTimer {
println("" + currentTime + ": " + (currentTime - lastRecurTime))
lastRecurTime = currentTime
}
- val timer = new RecurringTimer(new SystemClock(), period, onRecur)
+ val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test")
timer.start()
Thread.sleep(30 * 1000)
- timer.stop()
+ timer.stop(true)
}
}
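The value returned by stop(interruptTimer = false) is the time of the last completed callback, which is what JobGenerator compares against lastProcessedBatch during graceful shutdown. A hedged usage sketch (assumes the RecurringTimer and SystemClock defined in this file; since RecurringTimer is private[streaming], this would have to live in the same package):

    // package org.apache.spark.streaming.util   // required by private[streaming]

    object TimerSketch {
      def main(args: Array[String]) {
        val timer = new RecurringTimer(new SystemClock(), 1000,
          (t: Long) => println("callback fired for time " + t), "Sketch")
        timer.start()                    // first callback at the next period boundary
        Thread.sleep(3500)

        // stop(false) lets an in-progress callback finish, joins the timer thread,
        // and returns the time of the last callback that completed
        val lastCallbackTime = timer.stop(interruptTimer = false)
        println("Last callback time: " + lastCallbackTime)
      }
    }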
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index bcb0c28bf0..bb73dbf29b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -324,7 +324,7 @@ class BasicOperationsSuite extends TestSuiteBase {
val updateStateOperation = (s: DStream[String]) => {
val updateFunc = (values: Seq[Int], state: Option[Int]) => {
- Some(values.foldLeft(0)(_ + _) + state.getOrElse(0))
+ Some(values.sum + state.getOrElse(0))
}
s.map(x => (x, 1)).updateStateByKey[Int](updateFunc)
}
@@ -359,7 +359,7 @@ class BasicOperationsSuite extends TestSuiteBase {
// updateFunc clears a state when a StateObject is seen without new values twice in a row
val updateFunc = (values: Seq[Int], state: Option[StateObject]) => {
val stateObj = state.getOrElse(new StateObject)
- values.foldLeft(0)(_ + _) match {
+ values.sum match {
case 0 => stateObj.expireCounter += 1 // no new values
case n => { // has new values, increment and reset expireCounter
stateObj.counter += n
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 717da8e004..9cc27ef7f0 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -17,19 +17,22 @@
package org.apache.spark.streaming
-import org.scalatest.{FunSuite, BeforeAndAfter}
-import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.dstream.{DStream, NetworkReceiver}
+import org.apache.spark.util.{MetadataCleaner, Utils}
+import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.concurrent.Timeouts
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.SpanSugar._
-import org.apache.spark.{SparkException, SparkConf, SparkContext}
-import org.apache.spark.util.{Utils, MetadataCleaner}
-import org.apache.spark.streaming.dstream.DStream
-class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
+class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging {
val master = "local[2]"
val appName = this.getClass.getSimpleName
- val batchDuration = Seconds(1)
+ val batchDuration = Milliseconds(500)
val sparkHome = "someDir"
val envPair = "key" -> "value"
val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100
@@ -108,19 +111,31 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName)
myConf.set("spark.cleaner.ttl", ttl.toString)
val ssc1 = new StreamingContext(myConf, batchDuration)
+ addInputStream(ssc1).register
+ ssc1.start()
val cp = new Checkpoint(ssc1, Time(1000))
assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl)
ssc1.stop()
val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp))
assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl)
- ssc = new StreamingContext(null, cp, null)
+ ssc = new StreamingContext(null, newCp, null)
assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl)
}
- test("start multiple times") {
+ test("start and stop state check") {
ssc = new StreamingContext(master, appName, batchDuration)
addInputStream(ssc).register
+ assert(ssc.state === ssc.StreamingContextState.Initialized)
+ ssc.start()
+ assert(ssc.state === ssc.StreamingContextState.Started)
+ ssc.stop()
+ assert(ssc.state === ssc.StreamingContextState.Stopped)
+ }
+
+ test("start multiple times") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
ssc.start()
intercept[SparkException] {
ssc.start()
@@ -133,18 +148,61 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
ssc.start()
ssc.stop()
ssc.stop()
- ssc = null
}
+ test("stop before start and start after stop") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ addInputStream(ssc).register
+ ssc.stop() // stop before start should not throw exception
+ ssc.start()
+ ssc.stop()
+ intercept[SparkException] {
+ ssc.start() // start after stop should throw exception
+ }
+ }
+
+
test("stop only streaming context") {
ssc = new StreamingContext(master, appName, batchDuration)
sc = ssc.sparkContext
addInputStream(ssc).register
ssc.start()
ssc.stop(false)
- ssc = null
assert(sc.makeRDD(1 to 100).collect().size === 100)
ssc = new StreamingContext(sc, batchDuration)
+ addInputStream(ssc).register
+ ssc.start()
+ ssc.stop()
+ }
+
+ test("stop gracefully") {
+ val conf = new SparkConf().setMaster(master).setAppName(appName)
+ conf.set("spark.cleaner.ttl", "3600")
+ sc = new SparkContext(conf)
+ for (i <- 1 to 4) {
+ logInfo("==================================")
+ ssc = new StreamingContext(sc, batchDuration)
+ var runningCount = 0
+ TestReceiver.counter.set(1)
+ val input = ssc.networkStream(new TestReceiver)
+ input.count.foreachRDD(rdd => {
+ val count = rdd.first()
+ logInfo("Count = " + count)
+ runningCount += count.toInt
+ })
+ ssc.start()
+ ssc.awaitTermination(500)
+ ssc.stop(stopSparkContext = false, stopGracefully = true)
+ logInfo("Running count = " + runningCount)
+ logInfo("TestReceiver.counter = " + TestReceiver.counter.get())
+ assert(runningCount > 0)
+ assert(
+ (TestReceiver.counter.get() == runningCount + 1) ||
+ (TestReceiver.counter.get() == runningCount + 2),
+ "Received records = " + TestReceiver.counter.get() + ", " +
+ "processed records = " + runningCount
+ )
+ }
}
test("awaitTermination") {
@@ -199,7 +257,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
test("awaitTermination with error in job generation") {
ssc = new StreamingContext(master, appName, batchDuration)
val inputStream = addInputStream(ssc)
-
inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register
val exception = intercept[TestException] {
ssc.start()
@@ -215,4 +272,29 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts {
}
}
-class TestException(msg: String) extends Exception(msg) \ No newline at end of file
+class TestException(msg: String) extends Exception(msg)
+
+/** Custom receiver for testing whether all data received by a receiver gets processed or not */
+class TestReceiver extends NetworkReceiver[Int] {
+ protected lazy val blockGenerator = new BlockGenerator(StorageLevel.MEMORY_ONLY)
+ protected def onStart() {
+ blockGenerator.start()
+ logInfo("BlockGenerator started on thread " + receivingThread)
+ try {
+ while(true) {
+ blockGenerator += TestReceiver.counter.getAndIncrement
+ Thread.sleep(0)
+ }
+ } finally {
+ logInfo("Receiving stopped at count value of " + TestReceiver.counter.get())
+ }
+ }
+
+ protected def onStop() {
+ blockGenerator.stop()
+ }
+}
+
+object TestReceiver {
+ val counter = new AtomicInteger(1)
+}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 201630672a..aa2d5c2fc2 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -277,7 +277,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms")
assert(output.size === numExpectedOutput, "Unexpected number of outputs generated")
- Thread.sleep(500) // Give some time for the forgetting old RDDs to complete
+ Thread.sleep(100) // Give some time for the forgetting old RDDs to complete
} catch {
case e: Exception => {e.printStackTrace(); throw e}
} finally {