From 5c0dafc2c8734a421206a808b73be67b66264dd7 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Tue, 13 May 2014 18:32:32 -0700
Subject: [SPARK-1816] LiveListenerBus dies if a listener throws an exception

The solution is to wrap the posting of each event to each listener in a try / catch that logs any exception.

Author: Andrew Or

Closes #759 from andrewor14/listener-die and squashes the following commits:

aee5107 [Andrew Or] Merge branch 'master' of github.com:apache/spark into listener-die
370939f [Andrew Or] Remove two layers of indirection
422d278 [Andrew Or] Explicitly throw an exception instead of 1 / 0
0df0e2a [Andrew Or] Try/catch and log exceptions when posting events
---
 .../apache/spark/scheduler/LiveListenerBus.scala  | 36 ++++++++++++----
 .../apache/spark/scheduler/SparkListenerBus.scala | 50 +++++++++++++++-------
 .../main/scala/org/apache/spark/util/Utils.scala  |  2 +-
 .../spark/scheduler/SparkListenerSuite.scala      | 50 ++++++++++++++++++++--
 4 files changed, 109 insertions(+), 29 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
index dec3316bf7..36a6e6338f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -20,6 +20,7 @@ package org.apache.spark.scheduler
 import java.util.concurrent.{LinkedBlockingQueue, Semaphore}

 import org.apache.spark.Logging
+import org.apache.spark.util.Utils

 /**
  * Asynchronously passes SparkListenerEvents to registered SparkListeners.
@@ -42,7 +43,7 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {

   private val listenerThread = new Thread("SparkListenerBus") {
     setDaemon(true)
-    override def run() {
+    override def run(): Unit = Utils.logUncaughtExceptions {
       while (true) {
         eventLock.acquire()
         // Atomically remove and process this event
@@ -77,11 +78,8 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
     val eventAdded = eventQueue.offer(event)
     if (eventAdded) {
       eventLock.release()
-    } else if (!queueFullErrorMessageLogged) {
-      logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
-        "This likely means one of the SparkListeners is too slow and cannot keep up with the " +
-        "rate at which tasks are being started by the scheduler.")
-      queueFullErrorMessageLogged = true
+    } else {
+      logQueueFullErrorMessage()
     }
   }

@@ -96,13 +94,18 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
       if (System.currentTimeMillis > finishTime) {
         return false
       }
-      /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify
-       * add overhead in the general case. */
+      /* Sleep rather than using wait/notify, because this is used only for testing and
+       * wait/notify add overhead in the general case. */
       Thread.sleep(10)
     }
     true
   }

+  /**
+   * For testing only. Return whether the listener daemon thread is still alive.
+   */
+  def listenerThreadIsAlive: Boolean = synchronized { listenerThread.isAlive }
+
   /**
    * Return whether the event queue is empty.
    *
@@ -111,6 +114,23 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
    */
   def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty }

+  /**
+   * Log an error message to indicate that the event queue is full. Do this only once.
+   */
+  private def logQueueFullErrorMessage(): Unit = {
+    if (!queueFullErrorMessageLogged) {
+      if (listenerThread.isAlive) {
+        logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
+          "This likely means one of the SparkListeners is too slow and cannot keep up with " +
+          "the rate at which tasks are being started by the scheduler.")
+      } else {
+        logError("SparkListenerBus thread is dead! This means SparkListenerEvents have not " +
+          "been (and will no longer be) propagated to listeners for some time.")
+      }
+      queueFullErrorMessageLogged = true
+    }
+  }
+
   def stop() {
     if (!started) {
       throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 0286aac876..ed9fb24bc8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -20,10 +20,13 @@ package org.apache.spark.scheduler
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer

+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
+
 /**
  * A SparkListenerEvent bus that relays events to its listeners
  */
-private[spark] trait SparkListenerBus {
+private[spark] trait SparkListenerBus extends Logging {

   // SparkListeners attached to this event bus
   protected val sparkListeners = new ArrayBuffer[SparkListener]
@@ -34,38 +37,53 @@ private[spark] trait SparkListenerBus {
   }

   /**
-   * Post an event to all attached listeners. This does nothing if the event is
-   * SparkListenerShutdown.
+   * Post an event to all attached listeners.
+   * This does nothing if the event is SparkListenerShutdown.
    */
   def postToAll(event: SparkListenerEvent) {
     event match {
       case stageSubmitted: SparkListenerStageSubmitted =>
-        sparkListeners.foreach(_.onStageSubmitted(stageSubmitted))
+        foreachListener(_.onStageSubmitted(stageSubmitted))
       case stageCompleted: SparkListenerStageCompleted =>
-        sparkListeners.foreach(_.onStageCompleted(stageCompleted))
+        foreachListener(_.onStageCompleted(stageCompleted))
       case jobStart: SparkListenerJobStart =>
-        sparkListeners.foreach(_.onJobStart(jobStart))
+        foreachListener(_.onJobStart(jobStart))
       case jobEnd: SparkListenerJobEnd =>
-        sparkListeners.foreach(_.onJobEnd(jobEnd))
+        foreachListener(_.onJobEnd(jobEnd))
       case taskStart: SparkListenerTaskStart =>
-        sparkListeners.foreach(_.onTaskStart(taskStart))
+        foreachListener(_.onTaskStart(taskStart))
       case taskGettingResult: SparkListenerTaskGettingResult =>
-        sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
+        foreachListener(_.onTaskGettingResult(taskGettingResult))
       case taskEnd: SparkListenerTaskEnd =>
-        sparkListeners.foreach(_.onTaskEnd(taskEnd))
+        foreachListener(_.onTaskEnd(taskEnd))
       case environmentUpdate: SparkListenerEnvironmentUpdate =>
-        sparkListeners.foreach(_.onEnvironmentUpdate(environmentUpdate))
+        foreachListener(_.onEnvironmentUpdate(environmentUpdate))
       case blockManagerAdded: SparkListenerBlockManagerAdded =>
-        sparkListeners.foreach(_.onBlockManagerAdded(blockManagerAdded))
+        foreachListener(_.onBlockManagerAdded(blockManagerAdded))
       case blockManagerRemoved: SparkListenerBlockManagerRemoved =>
-        sparkListeners.foreach(_.onBlockManagerRemoved(blockManagerRemoved))
+        foreachListener(_.onBlockManagerRemoved(blockManagerRemoved))
       case unpersistRDD: SparkListenerUnpersistRDD =>
-        sparkListeners.foreach(_.onUnpersistRDD(unpersistRDD))
+        foreachListener(_.onUnpersistRDD(unpersistRDD))
       case applicationStart: SparkListenerApplicationStart =>
-        sparkListeners.foreach(_.onApplicationStart(applicationStart))
+        foreachListener(_.onApplicationStart(applicationStart))
       case applicationEnd: SparkListenerApplicationEnd =>
-        sparkListeners.foreach(_.onApplicationEnd(applicationEnd))
+        foreachListener(_.onApplicationEnd(applicationEnd))
       case SparkListenerShutdown =>
     }
   }
+
+  /**
+   * Apply the given function to all attached listeners, catching and logging any exception.
+   */
+  private def foreachListener(f: SparkListener => Unit): Unit = {
+    sparkListeners.foreach { listener =>
+      try {
+        f(listener)
+      } catch {
+        case e: Exception =>
+          logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
+      }
+    }
+  }
+
 }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 0631e54237..99ef6dd1fa 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1128,7 +1128,7 @@ private[spark] object Utils extends Logging {
   }

   /**
-   * Executes the given block, printing and re-throwing any uncaught exceptions.
+   * Execute the given block, logging and re-throwing any uncaught exception.
    * This is particularly useful for wrapping code that runs in a thread, to ensure
    * that exceptions are printed, and to avoid having to catch Throwable.
    */
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 4e9fd07e68..5426e578a9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -331,16 +331,47 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     }
   }

-  def checkNonZeroAvg(m: Traversable[Long], msg: String) {
+  test("SparkListener moves on if a listener throws an exception") {
+    val badListener = new BadListener
+    val jobCounter1 = new BasicJobCounter
+    val jobCounter2 = new BasicJobCounter
+    val bus = new LiveListenerBus
+
+    // Propagate events to bad listener first
+    bus.addListener(badListener)
+    bus.addListener(jobCounter1)
+    bus.addListener(jobCounter2)
+    bus.start()
+
+    // Post events to all listeners, and wait until the queue is drained
+    (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, JobSucceeded)) }
+    assert(bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
+    // The exception should be caught, and the event should be propagated to other listeners
+    assert(bus.listenerThreadIsAlive)
+    assert(jobCounter1.count === 5)
+    assert(jobCounter2.count === 5)
+  }
+
+  /**
+   * Assert that the given list of numbers has an average that is greater than zero.
+   */
+  private def checkNonZeroAvg(m: Traversable[Long], msg: String) {
     assert(m.sum / m.size.toDouble > 0.0, msg)
   }

-  class BasicJobCounter extends SparkListener {
+  /**
+   * A simple listener that counts the number of jobs observed.
+   */
+  private class BasicJobCounter extends SparkListener {
     var count = 0
     override def onJobEnd(job: SparkListenerJobEnd) = count += 1
   }

-  class SaveStageAndTaskInfo extends SparkListener {
+  /**
+   * A simple listener that saves all task infos and task metrics.
+   */
+  private class SaveStageAndTaskInfo extends SparkListener {
     val stageInfos = mutable.Map[StageInfo, Seq[(TaskInfo, TaskMetrics)]]()
     var taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]()

@@ -358,7 +389,10 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     }
   }

-  class SaveTaskEvents extends SparkListener {
+  /**
+   * A simple listener that saves the task indices for all task events.
+   */
+  private class SaveTaskEvents extends SparkListener {
     val startedTasks = new mutable.HashSet[Int]()
     val startedGettingResultTasks = new mutable.HashSet[Int]()
     val endedTasks = new mutable.HashSet[Int]()
@@ -377,4 +411,12 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
       startedGettingResultTasks += taskGettingResult.taskInfo.index
     }
   }
+
+  /**
+   * A simple listener that throws an exception on job end.
+   */
+  private class BadListener extends SparkListener {
+    override def onJobEnd(jobEnd: SparkListenerJobEnd) = { throw new Exception }
+  }
+
 }
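
To see the isolation pattern in miniature, the sketch below distills what this patch does:
each listener callback runs inside its own try / catch, so a throwing listener can neither
prevent later listeners from receiving the event nor kill the thread that dispatches it.
This is a standalone illustration only; the Bus and Listener types and the stderr logging
are stand-ins for this sketch, not Spark's actual SparkListenerBus API.

// Minimal, self-contained sketch (Scala 2) of per-listener exception isolation.
object ListenerIsolationSketch {

  trait Listener {
    def onEvent(event: String): Unit
  }

  class Bus {
    private val listeners = scala.collection.mutable.ArrayBuffer[Listener]()

    def addListener(listener: Listener): Unit = listeners += listener

    // Analogous to foreachListener above: catch and log per listener,
    // so the loop always reaches every remaining listener.
    def postToAll(event: String): Unit = {
      listeners.foreach { listener =>
        try {
          listener.onEvent(event)
        } catch {
          case e: Exception =>
            Console.err.println(s"Listener ${listener.getClass.getName} threw an exception: $e")
        }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val bus = new Bus
    var count = 0
    // Register the bad listener first, like BadListener in the test above.
    bus.addListener(new Listener { def onEvent(event: String) = throw new Exception("bad listener") })
    bus.addListener(new Listener { def onEvent(event: String) = count += 1 })
    (1 to 5).foreach(i => bus.postToAll(s"event-$i"))
    assert(count == 5) // the well-behaved listener still saw every event
  }
}

Catching Exception rather than Throwable mirrors the patch's choice: truly fatal errors
(for example OutOfMemoryError) still escape foreachListener and are then logged and
re-thrown by the Utils.logUncaughtExceptions wrapper around the bus thread's run() method.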