-rw-r--r--  core/src/main/scala/org/apache/spark/util/ListenerBus.scala               8
-rw-r--r--  core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala  20
-rw-r--r--  yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala   17
3 files changed, 29 insertions, 16 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
index d60b8b9a31..a725767d08 100644
--- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -19,9 +19,12 @@ package org.apache.spark.util
import java.util.concurrent.CopyOnWriteArrayList
+import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import scala.util.control.NonFatal
import org.apache.spark.Logging
+import org.apache.spark.scheduler.SparkListener
/**
* An event bus which posts events to its listeners.
@@ -64,4 +67,9 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
*/
def onPostEvent(listener: L, event: E): Unit
+ private[spark] def findListenersByClass[T <: L : ClassTag](): Seq[T] = {
+ val c = implicitly[ClassTag[T]].runtimeClass
+ listeners.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq
+ }
+
}
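
Together with the test changes below, this patch moves listener registration from sc.addSparkListener calls onto the spark.extraListeners configuration, and findListenersByClass is what lets a test recover the instance the SparkContext created for it. The helper sidesteps type erasure with a ClassTag and matches on exact runtime class. A minimal self-contained sketch of the same pattern, with hypothetical DemoBus/DemoListener names standing in for the real trait:

    import java.util.concurrent.CopyOnWriteArrayList

    import scala.collection.JavaConversions._
    import scala.reflect.ClassTag

    // Illustrative stand-in for the trait's listener storage
    // (names are hypothetical, not part of the patch).
    class DemoBus[L <: AnyRef] {
      private val listeners = new CopyOnWriteArrayList[L]

      def addListener(listener: L): Unit = listeners.add(listener)

      // The technique the patch adds: recover the erased type parameter
      // via ClassTag, then filter on exact runtime class and cast.
      def findListenersByClass[T <: L : ClassTag](): Seq[T] = {
        val c = implicitly[ClassTag[T]].runtimeClass
        listeners.filter(_.getClass == c).map(_.asInstanceOf[T]).toSeq
      }
    }

    object DemoBusApp extends App {
      class DemoListener
      class OtherListener

      val bus = new DemoBus[AnyRef]
      bus.addListener(new DemoListener)
      bus.addListener(new OtherListener)
      assert(bus.findListenersByClass[DemoListener]().size == 1)
    }

Note the exact getClass == c comparison: subclass instances would not be returned, so callers must look up the precise concrete class they registered.
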
diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala
index 9cdb42814c..c93d16f8a1 100644
--- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.deploy
import java.net.URL
+import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.io.Source
@@ -65,16 +66,17 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext {
new MySparkConf().setAll(getAll)
}
}
- val conf = new MySparkConf()
+ val conf = new MySparkConf().set(
+ "spark.extraListeners", classOf[SaveExecutorInfo].getName)
sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
- val listener = new SaveExecutorInfo
- sc.addSparkListener(listener)
-
// Trigger a job so that executors get added
sc.parallelize(1 to 100, 4).map(_.toString).count()
assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+ val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo]
+ assert(listeners.size === 1)
+ val listener = listeners(0)
listener.addedExecutorInfos.values.foreach { info =>
assert(info.logUrlMap.nonEmpty)
info.logUrlMap.values.foreach { logUrl =>
@@ -82,12 +84,12 @@ class LogUrlsStandaloneSuite extends FunSuite with LocalSparkContext {
}
}
}
+}
- private class SaveExecutorInfo extends SparkListener {
- val addedExecutorInfos = mutable.Map[String, ExecutorInfo]()
+private[spark] class SaveExecutorInfo extends SparkListener {
+ val addedExecutorInfos = mutable.Map[String, ExecutorInfo]()
- override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
- addedExecutorInfos(executor.executorId) = executor.executorInfo
- }
+ override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
+ addedExecutorInfos(executor.executorId) = executor.executorInfo
}
}
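
The suite now hands the listener's class name to spark.extraListeners instead of calling sc.addSparkListener after startup; the SparkContext instantiates and registers the listener while it is being constructed, so it is in place before executors come up, and SaveExecutorInfo must therefore become a top-level, reflectively constructible class. A sketch of the registration round trip under the same assumptions as the test (this code must sit under the org.apache.spark package, since listenerBus is private[spark]):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.deploy.SaveExecutorInfo

    val conf = new SparkConf()
      .setMaster("local-cluster[2,1,512]")
      .setAppName("test")
      // SparkContext reflectively constructs one instance of this class
      // and registers it on its listener bus during startup.
      .set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
    val sc = new SparkContext(conf)

    // Recover the bus-owned instance rather than keeping a local reference.
    val listener = sc.listenerBus.findListenersByClass[SaveExecutorInfo]().head
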
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 76952e3341..a18c94d4ab 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils}
import org.apache.spark.scheduler.cluster.ExecutorInfo
-import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded}
+import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListener, SparkListenerExecutorAdded}
import org.apache.spark.util.Utils
/**
@@ -282,10 +282,10 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
}
-private class SaveExecutorInfo extends SparkListener {
+private[spark] class SaveExecutorInfo extends SparkListener {
val addedExecutorInfos = mutable.Map[String, ExecutorInfo]()
- override def onExecutorAdded(executor : SparkListenerExecutorAdded) {
+ override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
addedExecutorInfos(executor.executorId) = executor.executorInfo
}
}
@@ -293,7 +293,6 @@ private class SaveExecutorInfo extends SparkListener {
private object YarnClusterDriver extends Logging with Matchers {
val WAIT_TIMEOUT_MILLIS = 10000
- var listener: SaveExecutorInfo = null
def main(args: Array[String]): Unit = {
if (args.length != 1) {
@@ -306,10 +305,9 @@ private object YarnClusterDriver extends Logging with Matchers {
System.exit(1)
}
- listener = new SaveExecutorInfo
val sc = new SparkContext(new SparkConf()
+ .set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
.setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns"))
- sc.addSparkListener(listener)
val status = new File(args(0))
var result = "failure"
try {
@@ -323,7 +321,12 @@ private object YarnClusterDriver extends Logging with Matchers {
}
// verify log urls are present
- listener.addedExecutorInfos.values.foreach { info =>
+ val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo]
+ assert(listeners.size === 1)
+ val listener = listeners(0)
+ val executorInfos = listener.addedExecutorInfos.values
+ assert(executorInfos.nonEmpty)
+ executorInfos.foreach { info =>
assert(info.logUrlMap.nonEmpty)
}
}
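
With the mutable listener field gone, the YARN driver verifies log URLs through the same bus lookup as the standalone suite, and the new assert(executorInfos.nonEmpty) closes a hole where an empty map would have passed the foreach vacuously. The shared shape of that verification, condensed (a running sc and the suites' WAIT_TIMEOUT_MILLIS are assumed):

    // Run a small job so SparkListenerExecutorAdded events get posted.
    sc.parallelize(1 to 100, 4).map(_.toString).count()
    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))

    // Exactly one instance was created from spark.extraListeners.
    val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo]()
    assert(listeners.size == 1)

    // The listener must have seen executors, each of which must expose log URLs.
    val executorInfos = listeners(0).addedExecutorInfos.values
    assert(executorInfos.nonEmpty)
    executorInfos.foreach { info => assert(info.logUrlMap.nonEmpty) }
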