Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala | 57
-rw-r--r--  core/src/main/scala/org/apache/spark/util/ListenerBus.scala | 3
-rw-r--r--  core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala | 56
3 files changed, 95 insertions(+), 21 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 7f5aef1c75..a7adddb6c8 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import scala.language.implicitConversions
import java.io._
+import java.lang.reflect.Constructor
import java.net.URI
import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
@@ -387,9 +388,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
executorAllocationManager.foreach(_.start())
-  // At this point, all relevant SparkListeners have been registered, so begin releasing events
-  listenerBus.start()
-
private[spark] val cleaner: Option[ContextCleaner] = {
if (conf.getBoolean("spark.cleaner.referenceTracking", true)) {
Some(new ContextCleaner(this))
@@ -399,6 +397,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
cleaner.foreach(_.start())
+  setupAndStartListenerBus()
postEnvironmentUpdate()
postApplicationStart()
@@ -1563,6 +1562,58 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
+  /**
+   * Registers listeners specified in spark.extraListeners, then starts the listener bus.
+   * This should be called after all internal listeners have been registered with the listener bus
+   * (e.g. after the web UI and event logging listeners have been registered).
+   */
+  private def setupAndStartListenerBus(): Unit = {
+    // Use reflection to instantiate listeners specified via `spark.extraListeners`
+    try {
+      val listenerClassNames: Seq[String] =
+        conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "")
+      for (className <- listenerClassNames) {
+        // Use reflection to find the right constructor
+        val constructors = {
+          val listenerClass = Class.forName(className)
+          listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]]
+        }
+        val constructorTakingSparkConf = constructors.find { c =>
+          c.getParameterTypes.sameElements(Array(classOf[SparkConf]))
+        }
+        lazy val zeroArgumentConstructor = constructors.find { c =>
+          c.getParameterTypes.isEmpty
+        }
+        val listener: SparkListener = {
+          if (constructorTakingSparkConf.isDefined) {
+            constructorTakingSparkConf.get.newInstance(conf)
+          } else if (zeroArgumentConstructor.isDefined) {
+            zeroArgumentConstructor.get.newInstance()
+          } else {
+            throw new SparkException(
+              s"$className did not have a zero-argument constructor or a" +
+                " single-argument constructor that accepts SparkConf. Note: if the class is" +
+                " defined inside of another Scala class, then its constructors may accept an" +
+                " implicit parameter that references the enclosing class; in this case, you must" +
+                " define the listener as a top-level class in order to prevent this extra" +
+                " parameter from breaking Spark's ability to find a valid constructor.")
+          }
+        }
+        listenerBus.addListener(listener)
+        logInfo(s"Registered listener $className")
+      }
+    } catch {
+      case e: Exception =>
+        try {
+          stop()
+        } finally {
+          throw new SparkException("Exception when registering SparkListener", e)
+        }
+    }
+
+    listenerBus.start()
+  }
+
/** Post the application start event */
private def postApplicationStart() {
// Note: this code assumes that the task scheduler has been initialized and has contacted
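Taken together, the SparkContext changes above mean that every class named in spark.extraListeners is instantiated reflectively, preferring a constructor that takes a SparkConf and falling back to a zero-argument one, and is registered before listenerBus.start() so it observes events from application start. Below is a minimal usage sketch against this patch; the listener class MyJobCounter, the app name, and the job are hypothetical, while the spark.extraListeners key and the constructor-resolution order come from the code above.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}

// Hypothetical top-level listener. Its (SparkConf) constructor is the one
// setupAndStartListenerBus() prefers over a zero-argument constructor.
class MyJobCounter(conf: SparkConf) extends SparkListener {
  @volatile var jobsEnded = 0
  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = jobsEnded += 1
}

object ExtraListenersExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local")
      .setAppName("extra-listeners-example")
      // Comma-separated fully-qualified class names, split and trimmed
      // exactly as setupAndStartListenerBus() does.
      .set("spark.extraListeners", classOf[MyJobCounter].getName)
    val sc = new SparkContext(conf)
    sc.parallelize(1 to 100, 4).count()  // run a job the listener can observe
    sc.stop()
  }
}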
diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
index bd0aa4dc46..d60b8b9a31 100644
--- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -28,7 +28,8 @@ import org.apache.spark.Logging
*/
private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
-  private val listeners = new CopyOnWriteArrayList[L]
+  // Marked `private[spark]` for access in tests.
+  private[spark] val listeners = new CopyOnWriteArrayList[L]
/**
* Add a listener to listen events. This method is thread-safe and can be called in any thread.
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 0fb1bdd30d..3a41ee8d4a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -20,26 +20,22 @@ package org.apache.spark.scheduler
import java.util.concurrent.Semaphore
import scala.collection.mutable
+import scala.collection.JavaConversions._
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
-import org.scalatest.Matchers
+import org.scalatest.{FunSuite, Matchers}
-import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.ResetSystemProperties
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext}
-class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers with BeforeAndAfter
-  with BeforeAndAfterAll with ResetSystemProperties {
+class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
+  with ResetSystemProperties {
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
val jobCompletionTime = 1421191296660L
-  before {
-    sc = new SparkContext("local", "SparkListenerSuite")
-  }
-
test("basic creation and shutdown of LiveListenerBus") {
val counter = new BasicJobCounter
val bus = new LiveListenerBus
@@ -127,6 +123,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
}
test("basic creation of StageInfo") {
+    sc = new SparkContext("local", "SparkListenerSuite")
val listener = new SaveStageAndTaskInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -148,6 +145,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
}
test("basic creation of StageInfo with shuffle") {
+    sc = new SparkContext("local", "SparkListenerSuite")
val listener = new SaveStageAndTaskInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -185,6 +183,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
}
test("StageInfo with fewer tasks than partitions") {
+    sc = new SparkContext("local", "SparkListenerSuite")
val listener = new SaveStageAndTaskInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -201,6 +200,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
}
test("local metrics") {
+    sc = new SparkContext("local", "SparkListenerSuite")
val listener = new SaveStageAndTaskInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@@ -267,6 +267,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
}
test("onTaskGettingResult() called when result fetched remotely") {
+    sc = new SparkContext("local", "SparkListenerSuite")
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
@@ -287,6 +288,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
}
test("onTaskGettingResult() not called when result sent directly") {
+    sc = new SparkContext("local", "SparkListenerSuite")
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
@@ -302,6 +304,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
}
test("onTaskEnd() should be called for all started tasks, even after job has been killed") {
+    sc = new SparkContext("local", "SparkListenerSuite")
val WAIT_TIMEOUT_MILLIS = 10000
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
@@ -356,6 +359,17 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
assert(jobCounter2.count === 5)
}
+  test("registering listeners via spark.extraListeners") {
+    val conf = new SparkConf().setMaster("local").setAppName("test")
+      .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," +
+        classOf[BasicJobCounter].getName)
+    sc = new SparkContext(conf)
+    sc.listenerBus.listeners.collect { case x: BasicJobCounter => x }.size should be (1)
+    sc.listenerBus.listeners.collect {
+      case x: ListenerThatAcceptsSparkConf => x
+    }.size should be (1)
+  }
+
/**
* Assert that the given list of numbers has an average that is greater than zero.
*/
@@ -364,14 +378,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
}
/**
-   * A simple listener that counts the number of jobs observed.
-   */
-  private class BasicJobCounter extends SparkListener {
-    var count = 0
-    override def onJobEnd(job: SparkListenerJobEnd) = count += 1
-  }
-
-  /**
* A simple listener that saves all task infos and task metrics.
*/
private class SaveStageAndTaskInfo extends SparkListener {
@@ -423,3 +429,19 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
}
}
+
+// These classes can't be declared inside of the SparkListenerSuite class because we don't want
+// their constructors to contain references to SparkListenerSuite:
+
+/**
+ * A simple listener that counts the number of jobs observed.
+ */
+private class BasicJobCounter extends SparkListener {
+  var count = 0
+  override def onJobEnd(job: SparkListenerJobEnd) = count += 1
+}
+
+private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListener {
+  var count = 0
+  override def onJobEnd(job: SparkListenerJobEnd) = count += 1
+}