-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala    |  2
-rw-r--r--  streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala | 19
-rw-r--r--  streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala     | 58
3 files changed, 75 insertions, 4 deletions
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
index 0e0f5bd3b9..b3ffc71904 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala
@@ -73,7 +73,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging {
logDebug("Stopping JobScheduler")
// First, stop receiving
- receiverTracker.stop()
+ receiverTracker.stop(processAllReceivedData)
// Second, stop generating jobs. If it has to process all received data,
// then this will wait for all the processing through JobScheduler to be over.
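For context, here is a minimal caller-side sketch (not part of this patch) of how an application requests the graceful shutdown that processAllReceivedData propagates down to ReceiverTracker; the new test below exercises the same path with ssc.stop(stopSparkContext = false, stopGracefully = true). The queueStream input, the timeouts, and the object name are illustrative assumptions, not taken from this commit.

    // Minimal sketch, assuming Spark Streaming's public API of this era.
    import scala.collection.mutable
    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object GracefulStopSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("GracefulStopSketch")
        val ssc = new StreamingContext(conf, Seconds(1))

        // A trivial input and output so the context has batches to schedule.
        val queue = mutable.Queue(ssc.sparkContext.makeRDD(1 to 100))
        ssc.queueStream(queue).count().print()

        ssc.start()
        ssc.awaitTermination(5000)  // run a few batches in this sketch
        // stopGracefully = true: stop receiving first, then drain already-received data.
        ssc.stop(stopSparkContext = true, stopGracefully = true)
      }
    }
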
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 4f99886973..00456ab2a0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -86,10 +86,10 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
/** Stop the receiver execution thread. */
- def stop() = synchronized {
+ def stop(graceful: Boolean) = synchronized {
if (!receiverInputStreams.isEmpty && actor != null) {
// First, stop the receivers
- if (!skipReceiverLaunch) receiverExecutor.stop()
+ if (!skipReceiverLaunch) receiverExecutor.stop(graceful)
// Finally, stop the actor
ssc.env.actorSystem.stop(actor)
@@ -218,6 +218,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
/** This thread class runs all the receivers on the cluster. */
class ReceiverLauncher {
@transient val env = ssc.env
+ @volatile @transient private var running = false
@transient val thread = new Thread() {
override def run() {
try {
@@ -233,7 +234,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
thread.start()
}
- def stop() {
+ def stop(graceful: Boolean) {
// Send the stop signal to all the receivers
stopReceivers()
@@ -241,6 +242,16 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
// That is, for the receivers to quit gracefully.
thread.join(10000)
+ if (graceful) {
+ val pollTime = 100
+ def done = { receiverInfo.isEmpty && !running }
+ logInfo("Waiting for receiver job to terminate gracefully")
+ while(!done) {
+ Thread.sleep(pollTime)
+ }
+ logInfo("Waited for receiver job to terminate gracefully")
+ }
+
// Check if all the receivers have been deregistered or not
if (!receiverInfo.isEmpty) {
logWarning("All of the receivers have not deregistered, " + receiverInfo)
@@ -295,7 +306,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
// Distribute the receivers and start them
logInfo("Starting " + receivers.length + " receivers")
+ running = true
ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
+ running = false
logInfo("All of the receivers have been terminated")
}
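The graceful path added above amounts to a @volatile flag bracketing the blocking receiver job plus a poll-until-idle loop in stop(graceful). Below is a standalone sketch of that pattern; the names (Launcher, requestShutdown, doBlockingWork) are hypothetical stand-ins, not Spark code. In the actual patch the poll also checks receiverInfo.isEmpty, so stop returns only after every receiver has deregistered.

    // Standalone sketch of the shutdown pattern used by ReceiverLauncher above:
    // a @volatile flag is set around the blocking work, and stop(graceful = true)
    // polls until the flag clears. All names here are hypothetical.
    class Launcher {
      @volatile private var running = false

      private val thread = new Thread() {
        override def run(): Unit = {
          running = true
          try doBlockingWork()        // stands in for ssc.sparkContext.runJob(...)
          finally running = false
        }
      }

      def start(): Unit = thread.start()

      def stop(graceful: Boolean): Unit = {
        requestShutdown()             // stands in for stopReceivers()
        thread.join(10000)            // give the job a chance to exit on its own
        if (graceful) {
          val pollTimeMs = 100
          while (running) {           // wait until the blocking job has returned
            Thread.sleep(pollTimeMs)
          }
        }
      }

      // Placeholders so the sketch compiles and runs on its own.
      private def doBlockingWork(): Unit = Thread.sleep(2000)
      private def requestShutdown(): Unit = ()
    }
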
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 9f352bdcb0..0b5af25e0f 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -205,6 +205,32 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
}
}
+ test("stop slow receiver gracefully") {
+ val conf = new SparkConf().setMaster(master).setAppName(appName)
+ conf.set("spark.streaming.gracefulStopTimeout", "20000")
+ sc = new SparkContext(conf)
+ logInfo("==================================\n\n\n")
+ ssc = new StreamingContext(sc, Milliseconds(100))
+ var runningCount = 0
+ SlowTestReceiver.receivedAllRecords = false
+ // Create a test receiver that sleeps in onStop()
+ val totalNumRecords = 15
+ val recordsPerSecond = 1
+ val input = ssc.receiverStream(new SlowTestReceiver(totalNumRecords, recordsPerSecond))
+ input.count().foreachRDD { rdd =>
+ val count = rdd.first()
+ runningCount += count.toInt
+ logInfo("Count = " + count + ", Running count = " + runningCount)
+ }
+ ssc.start()
+ ssc.awaitTermination(500)
+ ssc.stop(stopSparkContext = false, stopGracefully = true)
+ logInfo("Running count = " + runningCount)
+ assert(runningCount > 0)
+ assert(runningCount == totalNumRecords)
+ Thread.sleep(100)
+ }
+
test("awaitTermination") {
ssc = new StreamingContext(master, appName, batchDuration)
val inputStream = addInputStream(ssc)
@@ -319,6 +345,38 @@ object TestReceiver {
val counter = new AtomicInteger(1)
}
+/** Custom receiver for testing whether a slow receiver can be shut down gracefully or not */
+class SlowTestReceiver(totalRecords: Int, recordsPerSecond: Int) extends Receiver[Int](StorageLevel.MEMORY_ONLY) with Logging {
+
+ var receivingThreadOption: Option[Thread] = None
+
+ def onStart() {
+ val thread = new Thread() {
+ override def run() {
+ logInfo("Receiving started")
+ for(i <- 1 to totalRecords) {
+ Thread.sleep(1000 / recordsPerSecond)
+ store(i)
+ }
+ SlowTestReceiver.receivedAllRecords = true
+ logInfo(s"Received all $totalRecords records")
+ }
+ }
+ receivingThreadOption = Some(thread)
+ thread.start()
+ }
+
+ def onStop() {
+ // Simulate slow receiver by waiting for all records to be produced
+ while(!SlowTestReceiver.receivedAllRecords) Thread.sleep(100)
+ // no cleanup to be done; the receiving thread should stop on its own
+ }
+}
+
+object SlowTestReceiver {
+ var receivedAllRecords = false
+}
+
/** Streaming application for testing DStream and RDD creation sites */
package object testPackage extends Assertions {
def test() {