From ad0954f6de29761e0e7e543212c5bfe1fdcbed9f Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 17 Jul 2015 14:00:31 -0700 Subject: [SPARK-5681] [STREAMING] Move 'stopReceivers' to the event loop to resolve the race condition This is an alternative way to fix `SPARK-5681`. It minimizes the changes. Closes #4467 Author: zsxwing Author: Liang-Chi Hsieh Closes #6294 from zsxwing/pr4467 and squashes the following commits: 709ac1f [zsxwing] Fix the comment e103e8a [zsxwing] Move ReceiverTracker.stop into ReceiverTracker.stop f637142 [zsxwing] Address minor code style comments a178d37 [zsxwing] Move 'stopReceivers' to the event looop to resolve the race condition 51fb07e [zsxwing] Fix the code style 3cb19a3 [zsxwing] Merge branch 'master' into pr4467 b4c29e7 [zsxwing] Stop receiver only if we start it c41ee94 [zsxwing] Make stopReceivers private 7c73c1f [zsxwing] Use trackerStateLock to protect trackerState a8120c0 [zsxwing] Merge branch 'master' into pr4467 7b1d9af [zsxwing] "case Throwable" => "case NonFatal" 15ed4a1 [zsxwing] Register before starting the receiver fff63f9 [zsxwing] Use a lock to eliminate the race condition when stopping receivers and registering receivers happen at the same time. e0ef72a [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into tracker_status_timeout 19b76d9 [Liang-Chi Hsieh] Remove timeout. 34c18dc [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into tracker_status_timeout c419677 [Liang-Chi Hsieh] Fix style. 9e1a760 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into tracker_status_timeout 355f9ce [Liang-Chi Hsieh] Separate register and start events for receivers. 3d568e8 [Liang-Chi Hsieh] Let receivers get registered first before going started. ae0d9fd [Liang-Chi Hsieh] Merge branch 'master' into tracker_status_timeout 77983f3 [Liang-Chi Hsieh] Add tracker status and stop to receive messages when stopping tracker. 
--- .../streaming/receiver/ReceiverSupervisor.scala | 42 ++++--- .../receiver/ReceiverSupervisorImpl.scala | 2 +- .../streaming/scheduler/ReceiverTracker.scala | 139 ++++++++++++++------- .../org/apache/spark/streaming/ReceiverSuite.scala | 2 + .../spark/streaming/StreamingContextSuite.scala | 15 +++ 5 files changed, 138 insertions(+), 62 deletions(-) (limited to 'streaming') diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index eeb14ca3a4..6467029a27 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -22,6 +22,7 @@ import java.util.concurrent.CountDownLatch import scala.collection.mutable.ArrayBuffer import scala.concurrent._ +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StreamBlockId @@ -36,7 +37,7 @@ private[streaming] abstract class ReceiverSupervisor( conf: SparkConf ) extends Logging { - /** Enumeration to identify current state of the StreamingContext */ + /** Enumeration to identify current state of the Receiver */ object ReceiverState extends Enumeration { type CheckpointState = Value val Initialized, Started, Stopped = Value @@ -97,8 +98,8 @@ private[streaming] abstract class ReceiverSupervisor( /** Called when supervisor is stopped */ protected def onStop(message: String, error: Option[Throwable]) { } - /** Called when receiver is started */ - protected def onReceiverStart() { } + /** Called when receiver is started. 
Return true if the driver accepts us */ + protected def onReceiverStart(): Boolean /** Called when receiver is stopped */ protected def onReceiverStop(message: String, error: Option[Throwable]) { } @@ -121,13 +122,17 @@ private[streaming] abstract class ReceiverSupervisor( /** Start receiver */ def startReceiver(): Unit = synchronized { try { - logInfo("Starting receiver") - receiver.onStart() - logInfo("Called receiver onStart") - onReceiverStart() - receiverState = Started + if (onReceiverStart()) { + logInfo("Starting receiver") + receiverState = Started + receiver.onStart() + logInfo("Called receiver onStart") + } else { + // The driver refused us + stop("Registered unsuccessfully because Driver refused to start receiver " + streamId, None) + } } catch { - case t: Throwable => + case NonFatal(t) => stop("Error starting receiver " + streamId, Some(t)) } } @@ -136,12 +141,19 @@ private[streaming] abstract class ReceiverSupervisor( def stopReceiver(message: String, error: Option[Throwable]): Unit = synchronized { try { logInfo("Stopping receiver with message: " + message + ": " + error.getOrElse("")) - receiverState = Stopped - receiver.onStop() - logInfo("Called receiver onStop") - onReceiverStop(message, error) + receiverState match { + case Initialized => + logWarning("Skip stopping receiver because it has not yet stared") + case Started => + receiverState = Stopped + receiver.onStop() + logInfo("Called receiver onStop") + onReceiverStop(message, error) + case Stopped => + logWarning("Receiver has been stopped") + } } catch { - case t: Throwable => + case NonFatal(t) => logError("Error stopping receiver " + streamId + t.getStackTraceString) } } @@ -167,7 +179,7 @@ private[streaming] abstract class ReceiverSupervisor( }(futureExecutionContext) } - /** Check if receiver has been marked for stopping */ + /** Check if receiver has been marked for starting */ def isReceiverStarted(): Boolean = { logDebug("state = " + receiverState) receiverState == Started diff 
--git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 6078cdf8f8..f6ba66b3ae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -162,7 +162,7 @@ private[streaming] class ReceiverSupervisorImpl( env.rpcEnv.stop(endpoint) } - override protected def onReceiverStart() { + override protected def onReceiverStart(): Boolean = { val msg = RegisterReceiver( streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) trackerEndpoint.askWithRetry[Boolean](msg) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 644e581cd8..6910d81d98 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,7 +20,6 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap} import scala.language.existentials import scala.math.max -import org.apache.spark.rdd._ import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{Logging, SparkEnv, SparkException} @@ -47,6 +46,8 @@ private[streaming] case class ReportError(streamId: Int, message: String, error: private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, error: String) extends ReceiverTrackerMessage +private[streaming] case object StopAllReceivers extends ReceiverTrackerMessage + /** * This class manages the execution of the receivers of ReceiverInputDStreams. 
Instance of * this class must be created after all input streams have been added and StreamingContext.start() @@ -71,13 +72,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ) private val listenerBus = ssc.scheduler.listenerBus + /** Enumeration to identify current state of the ReceiverTracker */ + object TrackerState extends Enumeration { + type TrackerState = Value + val Initialized, Started, Stopping, Stopped = Value + } + import TrackerState._ + + /** State of the tracker. Only updated inside the synchronized start() and stop() */ + @volatile private var trackerState = Initialized + // endpoint is created when generator starts. // This not being null means the tracker has been started and not stopped private var endpoint: RpcEndpointRef = null /** Start the endpoint and receiver execution thread. */ def start(): Unit = synchronized { - if (endpoint != null) { + if (isTrackerStarted) { throw new SparkException("ReceiverTracker already started") } @@ -86,20 +97,46 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") + trackerState = Started } } /** Stop the receiver execution thread. */ def stop(graceful: Boolean): Unit = synchronized { - if (!receiverInputStreams.isEmpty && endpoint != null) { + if (isTrackerStarted) { // First, stop the receivers - if (!skipReceiverLaunch) receiverExecutor.stop(graceful) + trackerState = Stopping + if (!skipReceiverLaunch) { + // Send the stop signal to all the receivers + endpoint.askWithRetry[Boolean](StopAllReceivers) + + // Wait for the Spark job that runs the receivers to be over + // That is, for the receivers to quit gracefully.
+ receiverExecutor.awaitTermination(10000) + + if (graceful) { + val pollTime = 100 + logInfo("Waiting for receiver job to terminate gracefully") + while (receiverInfo.nonEmpty || receiverExecutor.running) { + Thread.sleep(pollTime) + } + logInfo("Waited for receiver job to terminate gracefully") + } + + // Check if all the receivers have been deregistered or not + if (receiverInfo.nonEmpty) { + logWarning("Not all of the receivers have deregistered, " + receiverInfo) + } else { + logInfo("All of the receivers have deregistered successfully") + } + } // Finally, stop the endpoint ssc.env.rpcEnv.stop(endpoint) endpoint = null receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") + trackerState = Stopped } } @@ -145,14 +182,23 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false host: String, receiverEndpoint: RpcEndpointRef, senderAddress: RpcAddress - ) { + ): Boolean = { if (!receiverInputStreamIds.contains(streamId)) { throw new SparkException("Register received for unexpected id " + streamId) } - receiverInfo(streamId) = ReceiverInfo( - streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) - listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) - logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) + + if (isTrackerStopping || isTrackerStopped) { + false + } else { + // "stopReceivers" won't happen at the same time because both "registerReceiver" and "stopReceivers" are + // called in the event loop. So here we can assume "stopReceivers" has not yet been called. If + // "stopReceivers" is called later, it should be able to see this receiver.
+ receiverInfo(streamId) = ReceiverInfo( + streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) + listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) + true + } } /** Deregister a receiver */ @@ -220,20 +266,33 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterReceiver(streamId, typ, host, receiverEndpoint) => - registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) - context.reply(true) + val successful = + registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + context.reply(successful) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) + case StopAllReceivers => + assert(isTrackerStopping || isTrackerStopped) + stopReceivers() + context.reply(true) + } + + /** Send stop signal to the receivers. */ + private def stopReceivers() { + // Signal the receivers to stop + receiverInfo.values.flatMap { info => Option(info.endpoint)} + .foreach { _.send(StopReceiver) } + logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") } } /** This thread class runs all the receivers on the cluster. 
*/ class ReceiverLauncher { @transient val env = ssc.env - @volatile @transient private var running = false + @volatile @transient var running = false @transient val thread = new Thread() { override def run() { try { @@ -249,31 +308,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false thread.start() } - def stop(graceful: Boolean) { - // Send the stop signal to all the receivers - stopReceivers() - - // Wait for the Spark job that runs the receivers to be over - // That is, for the receivers to quit gracefully. - thread.join(10000) - - if (graceful) { - val pollTime = 100 - logInfo("Waiting for receiver job to terminate gracefully") - while (receiverInfo.nonEmpty || running) { - Thread.sleep(pollTime) - } - logInfo("Waited for receiver job to terminate gracefully") - } - - // Check if all the receivers have been deregistered or not - if (receiverInfo.nonEmpty) { - logWarning("Not all of the receivers have deregistered, " + receiverInfo) - } else { - logInfo("All of the receivers have deregistered successfully") - } - } - /** * Get the list of executors excluding driver */ @@ -358,17 +392,30 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") running = true - ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) - running = false - logInfo("All of the receivers have been terminated") + try { + ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver)) + logInfo("All of the receivers have been terminated") + } finally { + running = false + } } - /** Stops the receivers. 
*/ - private def stopReceivers() { - // Signal the receivers to stop - receiverInfo.values.flatMap { info => Option(info.endpoint)} - .foreach { _.send(StopReceiver) } - logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") + /** + * Wait until the Spark job that runs the receivers is terminated, or return when + * `milliseconds` elapses + */ + def awaitTermination(milliseconds: Long): Unit = { + thread.join(milliseconds) } } + + /** Check if tracker has been marked for starting */ + private def isTrackerStarted(): Boolean = trackerState == Started + + /** Check if tracker has been marked for stopping */ + private def isTrackerStopping(): Boolean = trackerState == Stopping + + /** Check if tracker has been stopped */ + private def isTrackerStopped(): Boolean = trackerState == Stopped + } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 5d7127627e..13b4d17c86 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -346,6 +346,8 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable { def reportError(message: String, throwable: Throwable) { errors += throwable } + + override protected def onReceiverStart(): Boolean = true } /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index f588cf5bc1..4bba9691f8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -285,6 +285,21 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo } } + test("stop gracefully even if a receiver misses StopReceiver") { + // This is not a
deterministic unit test. But if this unit test is flaky, then there is definitely + // something wrong. See SPARK-5681 + val conf = new SparkConf().setMaster(master).setAppName(appName) + sc = new SparkContext(conf) + ssc = new StreamingContext(sc, Milliseconds(100)) + val input = ssc.receiverStream(new TestReceiver) + input.foreachRDD(_ => {}) + ssc.start() + // Call `ssc.stop` at once so that it's possible that the receiver will miss "StopReceiver" + failAfter(30000 millis) { + ssc.stop(stopSparkContext = true, stopGracefully = true) + } + } + test("stop slow receiver gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) conf.set("spark.streaming.gracefulStopTimeout", "20000s") -- cgit v1.2.3