aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortedyu <yuzhihong@gmail.com>2015-11-17 22:47:53 -0800
committerTathagata Das <tathagata.das1565@gmail.com>2015-11-17 22:47:53 -0800
commit446738e51fcda50cf2dc44123ff6bf12a1611dc0 (patch)
tree8940cccf64649eedfc1834e20146d7b1761cec5c
parent8fb775ba874dd0488667bf299a7b49760062dc00 (diff)
downloadspark-446738e51fcda50cf2dc44123ff6bf12a1611dc0.tar.gz
spark-446738e51fcda50cf2dc44123ff6bf12a1611dc0.tar.bz2
spark-446738e51fcda50cf2dc44123ff6bf12a1611dc0.zip
[SPARK-11761] Prevent the call to StreamingContext#stop() in the listener bus's thread
See discussion toward the tail of https://github.com/apache/spark/pull/9723 From zsxwing : ``` The user should not call stop or other long-time work in a listener since it will block the listener thread, and prevent from stopping SparkContext/StreamingContext. I cannot see an approach since we need to stop the listener bus's thread before stopping SparkContext/StreamingContext totally. ``` Proposed solution is to prevent the call to StreamingContext#stop() in the listener bus's thread. Author: tedyu <yuzhihong@gmail.com> Closes #9741 from tedyu/master.
-rw-r--r--core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala46
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala6
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala34
3 files changed, 67 insertions, 19 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala
index c20627b056..6c1fca71f2 100644
--- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala
@@ -19,6 +19,7 @@ package org.apache.spark.util
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
+import scala.util.DynamicVariable
import org.apache.spark.SparkContext
@@ -60,25 +61,27 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri
private val listenerThread = new Thread(name) {
setDaemon(true)
override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) {
- while (true) {
- eventLock.acquire()
- self.synchronized {
- processingEvent = true
- }
- try {
- val event = eventQueue.poll
- if (event == null) {
- // Get out of the while loop and shutdown the daemon thread
- if (!stopped.get) {
- throw new IllegalStateException("Polling `null` from eventQueue means" +
- " the listener bus has been stopped. So `stopped` must be true")
- }
- return
- }
- postToAll(event)
- } finally {
+ AsynchronousListenerBus.withinListenerThread.withValue(true) {
+ while (true) {
+ eventLock.acquire()
self.synchronized {
- processingEvent = false
+ processingEvent = true
+ }
+ try {
+ val event = eventQueue.poll
+ if (event == null) {
+ // Get out of the while loop and shutdown the daemon thread
+ if (!stopped.get) {
+ throw new IllegalStateException("Polling `null` from eventQueue means" +
+ " the listener bus has been stopped. So `stopped` must be true")
+ }
+ return
+ }
+ postToAll(event)
+ } finally {
+ self.synchronized {
+ processingEvent = false
+ }
}
}
}
@@ -177,3 +180,10 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri
*/
def onDropEvent(event: E): Unit
}
+
+private[spark] object AsynchronousListenerBus {
+ /* Allows for Context to check whether stop() call is made within listener thread
+ */
+ val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
+}
+
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
index 97113835f3..aee172a4f5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala
@@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver}
import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener}
import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab}
-import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils}
+import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookManager, ThreadUtils, Utils}
/**
* Main entry point for Spark Streaming functionality. It provides methods used to create
@@ -693,6 +693,10 @@ class StreamingContext private[streaming] (
*/
def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = {
var shutdownHookRefToRemove: AnyRef = null
+ if (AsynchronousListenerBus.withinListenerThread.value) {
+ throw new SparkException("Cannot stop StreamingContext within listener thread of" +
+ " AsynchronousListenerBus")
+ }
synchronized {
try {
state match {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index 5dc0472c77..df4575ab25 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, Synch
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
+import org.apache.spark.SparkException
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
@@ -161,6 +162,14 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
}
}
+ test("don't call ssc.stop in listener") {
+ ssc = new StreamingContext("local[2]", "ssc", Milliseconds(1000))
+ val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver)
+ inputStream.foreachRDD(_.count)
+
+ startStreamingContextAndCallStop(ssc)
+ }
+
test("onBatchCompleted with successful batch") {
ssc = new StreamingContext("local[2]", "test", Milliseconds(1000))
val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver)
@@ -207,6 +216,17 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
assert(failureReasons(1).contains("This is another failed job"))
}
+ private def startStreamingContextAndCallStop(_ssc: StreamingContext): Unit = {
+ val contextStoppingCollector = new StreamingContextStoppingCollector(_ssc)
+ _ssc.addStreamingListener(contextStoppingCollector)
+ val batchCounter = new BatchCounter(_ssc)
+ _ssc.start()
+ // Make sure running at least one batch
+ batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000)
+ _ssc.stop()
+ assert(contextStoppingCollector.sparkExSeen)
+ }
+
private def startStreamingContextAndCollectFailureReasons(
_ssc: StreamingContext, isFailed: Boolean = false): Map[Int, String] = {
val failureReasonsCollector = new FailureReasonsCollector()
@@ -320,3 +340,17 @@ class FailureReasonsCollector extends StreamingListener {
}
}
}
+/**
+ * A StreamingListener that calls StreamingContext.stop().
+ */
+class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener {
+ @volatile var sparkExSeen = false
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
+ try {
+ ssc.stop()
+ } catch {
+ case se: SparkException =>
+ sparkExSeen = true
+ }
+ }
+}