aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/Accumulable.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/util/JsonProtocol.scala6
-rw-r--r--core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala14
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala8
-rw-r--r--core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala16
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala21
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala23
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala24
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala38
13 files changed, 133 insertions, 45 deletions
diff --git a/core/src/main/scala/org/apache/spark/Accumulable.scala b/core/src/main/scala/org/apache/spark/Accumulable.scala
index 52f572b63f..601b503d12 100644
--- a/core/src/main/scala/org/apache/spark/Accumulable.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulable.scala
@@ -22,6 +22,7 @@ import java.io.{ObjectInputStream, Serializable}
import scala.collection.generic.Growable
import scala.reflect.ClassTag
+import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils
@@ -187,6 +188,13 @@ class Accumulable[R, T] private (
*/
private[spark] def setValueAny(newValue: Any): Unit = { setValue(newValue.asInstanceOf[R]) }
+ /**
+ * Create an [[AccumulableInfo]] representation of this [[Accumulable]] with the provided values.
+ */
+ private[spark] def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+ new AccumulableInfo(id, name, update, value, internal, countFailedValues)
+ }
+
// Called by Java when deserializing an object
private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 8d10bf588e..0a6ebcb3e0 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -323,8 +323,8 @@ class TaskMetrics(initialAccums: Seq[Accumulator[_]]) extends Serializable {
* field is always empty, since this represents the partial updates recorded in this task,
* not the aggregated value across multiple tasks.
*/
- def accumulatorUpdates(): Seq[AccumulableInfo] = accums.map { a =>
- new AccumulableInfo(a.id, a.name, Some(a.localValue), None, a.isInternal, a.countFailedValues)
+ def accumulatorUpdates(): Seq[AccumulableInfo] = {
+ accums.map { a => a.toInfo(Some(a.localValue), None) }
}
// If we are reconstructing this TaskMetrics on the driver, some metrics may already be set.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
index 9d45fff921..cedacad44a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala
@@ -35,6 +35,7 @@ import org.apache.spark.annotation.DeveloperApi
* @param value total accumulated value so far, maybe None if used on executors to describe a task
* @param internal whether this accumulator was internal
* @param countFailedValues whether to count this accumulator's partial value if the task failed
+ * @param metadata internal metadata associated with this accumulator, if any
*/
@DeveloperApi
case class AccumulableInfo private[spark] (
@@ -43,7 +44,9 @@ case class AccumulableInfo private[spark] (
update: Option[Any], // represents a partial update within a task
value: Option[Any],
private[spark] val internal: Boolean,
- private[spark] val countFailedValues: Boolean)
+ private[spark] val countFailedValues: Boolean,
+ // TODO: use this to identify internal task metrics instead of encoding it in the name
+ private[spark] val metadata: Option[String] = None)
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 897479b500..ee0b8a1c95 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1101,11 +1101,8 @@ class DAGScheduler(
acc ++= partialValue
// To avoid UI cruft, ignore cases where value wasn't updated
if (acc.name.isDefined && partialValue != acc.zero) {
- val name = acc.name
- stage.latestInfo.accumulables(id) = new AccumulableInfo(
- id, name, None, Some(acc.value), acc.isInternal, acc.countFailedValues)
- event.taskInfo.accumulables += new AccumulableInfo(
- id, name, Some(partialValue), Some(acc.value), acc.isInternal, acc.countFailedValues)
+ stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value))
+ event.taskInfo.accumulables += acc.toInfo(Some(partialValue), Some(acc.value))
}
}
} catch {
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index dc8070cf8a..a2487eeb04 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -290,7 +290,8 @@ private[spark] object JsonProtocol {
("Update" -> accumulableInfo.update.map { v => accumValueToJson(name, v) }) ~
("Value" -> accumulableInfo.value.map { v => accumValueToJson(name, v) }) ~
("Internal" -> accumulableInfo.internal) ~
- ("Count Failed Values" -> accumulableInfo.countFailedValues)
+ ("Count Failed Values" -> accumulableInfo.countFailedValues) ~
+ ("Metadata" -> accumulableInfo.metadata)
}
/**
@@ -728,7 +729,8 @@ private[spark] object JsonProtocol {
val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) }
val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false)
val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false)
- new AccumulableInfo(id, name, update, value, internal, countFailedValues)
+ val metadata = (json \ "Metadata").extractOpt[String]
+ new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata)
}
/**
diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
index 15be0b194e..67c4595ed1 100644
--- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala
@@ -551,8 +551,6 @@ private[spark] object TaskMetricsSuite extends Assertions {
* Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the
* info as an accumulator update.
*/
- def makeInfo(a: Accumulable[_, _]): AccumulableInfo = {
- new AccumulableInfo(a.id, a.name, Some(a.value), None, a.isInternal, a.countFailedValues)
- }
+ def makeInfo(a: Accumulable[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None)
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index d9c71ec2ea..62972a0738 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1581,12 +1581,9 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
assert(Accumulators.get(acc1.id).isDefined)
assert(Accumulators.get(acc2.id).isDefined)
assert(Accumulators.get(acc3.id).isDefined)
- val accInfo1 = new AccumulableInfo(
- acc1.id, acc1.name, Some(15L), None, internal = false, countFailedValues = false)
- val accInfo2 = new AccumulableInfo(
- acc2.id, acc2.name, Some(13L), None, internal = false, countFailedValues = false)
- val accInfo3 = new AccumulableInfo(
- acc3.id, acc3.name, Some(18L), None, internal = false, countFailedValues = false)
+ val accInfo1 = acc1.toInfo(Some(15L), None)
+ val accInfo2 = acc2.toInfo(Some(13L), None)
+ val accInfo3 = acc3.toInfo(Some(18L), None)
val accumUpdates = Seq(accInfo1, accInfo2, accInfo3)
val exceptionFailure = new ExceptionFailure(new SparkException("fondue?"), accumUpdates)
submit(new MyRDD(sc, 1, Nil), Array(0))
@@ -1954,10 +1951,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo],
taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = {
val accumUpdates = reason match {
- case Success =>
- task.initialAccumulators.map { a =>
- new AccumulableInfo(a.id, a.name, Some(a.zero), None, a.isInternal, a.countFailedValues)
- }
+ case Success => task.initialAccumulators.map { a => a.toInfo(Some(a.zero), None) }
case ef: ExceptionFailure => ef.accumUpdates
case _ => Seq.empty[AccumulableInfo]
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index a2e7436564..2c99dd5afb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -165,9 +165,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val taskSet = FakeTask.createTaskSet(1)
val clock = new ManualClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
- val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a =>
- new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues)
- }
+ val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => a.toInfo(Some(0L), None) }
// Offer a host with NO_PREF as the constraint,
// we should get a nopref task immediately since that's what we only have
@@ -186,9 +184,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
val taskSet = FakeTask.createTaskSet(3)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task =>
- task.initialAccumulators.map { a =>
- new AccumulableInfo(a.id, a.name, Some(0L), None, a.isInternal, a.countFailedValues)
- }
+ task.initialAccumulators.map { a => a.toInfo(Some(0L), None) }
}
// First three offers should all find tasks
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 57021d1d3d..48951c3168 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -374,15 +374,18 @@ class JsonProtocolSuite extends SparkFunSuite {
test("AccumulableInfo backward compatibility") {
// "Internal" property of AccumulableInfo was added in 1.5.1
val accumulableInfo = makeAccumulableInfo(1, internal = true, countFailedValues = true)
- val oldJson = JsonProtocol.accumulableInfoToJson(accumulableInfo)
- .removeField({ _._1 == "Internal" })
+ val accumulableInfoJson = JsonProtocol.accumulableInfoToJson(accumulableInfo)
+ val oldJson = accumulableInfoJson.removeField({ _._1 == "Internal" })
val oldInfo = JsonProtocol.accumulableInfoFromJson(oldJson)
assert(!oldInfo.internal)
// "Count Failed Values" property of AccumulableInfo was added in 2.0.0
- val oldJson2 = JsonProtocol.accumulableInfoToJson(accumulableInfo)
- .removeField({ _._1 == "Count Failed Values" })
+ val oldJson2 = accumulableInfoJson.removeField({ _._1 == "Count Failed Values" })
val oldInfo2 = JsonProtocol.accumulableInfoFromJson(oldJson2)
assert(!oldInfo2.countFailedValues)
+ // "Metadata" property of AccumulableInfo was added in 2.0.0
+ val oldJson3 = accumulableInfoJson.removeField({ _._1 == "Metadata" })
+ val oldInfo3 = JsonProtocol.accumulableInfoFromJson(oldJson3)
+ assert(oldInfo3.metadata.isEmpty)
}
test("ExceptionFailure backward compatibility: accumulator updates") {
@@ -820,9 +823,10 @@ private[spark] object JsonProtocolSuite extends Assertions {
private def makeAccumulableInfo(
id: Int,
internal: Boolean = false,
- countFailedValues: Boolean = false): AccumulableInfo =
+ countFailedValues: Boolean = false,
+ metadata: Option[String] = None): AccumulableInfo =
new AccumulableInfo(id, Some(s"Accumulable$id"), Some(s"delta$id"), Some(s"val$id"),
- internal, countFailedValues)
+ internal, countFailedValues, metadata)
/**
* Creates a TaskMetrics object describing a task that read data from Hadoop (if hasHadoopInput is
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 950dc78162..6b43d273fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.metric
import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext}
+import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.util.Utils
/**
@@ -27,9 +28,16 @@ import org.apache.spark.util.Utils
* An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
*/
private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
- name: String, val param: SQLMetricParam[R, T])
+ name: String,
+ val param: SQLMetricParam[R, T])
extends Accumulable[R, T](param.zero, param, Some(name), internal = true) {
+ // Provide special identifier as metadata so we can tell that this is a `SQLMetric` later
+ override def toInfo(update: Option[Any], value: Option[Any]): AccumulableInfo = {
+ new AccumulableInfo(id, Some(name), update, value, isInternal, countFailedValues,
+ Some(SQLMetrics.ACCUM_IDENTIFIER))
+ }
+
def reset(): Unit = {
this.value = param.zero
}
@@ -73,6 +81,14 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr
// Although there is a boxing here, it's fine because it's only called in SQLListener
override def value: Long = _value
+
+ // Needed for SQLListenerSuite
+ override def equals(other: Any): Boolean = {
+ other match {
+ case o: LongSQLMetricValue => value == o.value
+ case _ => false
+ }
+ }
}
/**
@@ -126,6 +142,9 @@ private object StaticsLongSQLMetricParam extends LongSQLMetricParam(
private[sql] object SQLMetrics {
+ // Identifier for distinguishing SQL metrics from other accumulators
+ private[sql] val ACCUM_IDENTIFIER = "sql"
+
private def createLongMetric(
sc: SparkContext,
name: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 544606f116..835e7ba6c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -23,7 +23,7 @@ import org.apache.spark.{JobExecutionStatus, Logging, SparkConf}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricParam, SQLMetricValue}
+import org.apache.spark.sql.execution.metric._
import org.apache.spark.ui.SparkUI
@DeveloperApi
@@ -314,14 +314,17 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
}
+
+/**
+ * A [[SQLListener]] for rendering the SQL UI in the history server.
+ */
private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
extends SQLListener(conf) {
private var sqlTabAttached = false
- override def onExecutorMetricsUpdate(
- executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized {
- // Do nothing
+ override def onExecutorMetricsUpdate(u: SparkListenerExecutorMetricsUpdate): Unit = {
+ // Do nothing; these events are not logged
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
@@ -329,9 +332,15 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
taskEnd.taskInfo.taskId,
taskEnd.stageId,
taskEnd.stageAttemptId,
- taskEnd.taskInfo.accumulables.map { a =>
- val newValue = new LongSQLMetricValue(a.update.map(_.asInstanceOf[Long]).getOrElse(0L))
- a.copy(update = Some(newValue))
+ taskEnd.taskInfo.accumulables.flatMap { a =>
+ // Filter out accumulators that are not SQL metrics
+ // For now we assume all SQL metrics are Long's that have been JSON serialized as String's
+ if (a.metadata.exists(_ == SQLMetrics.ACCUM_IDENTIFIER)) {
+ val newValue = new LongSQLMetricValue(a.update.map(_.toString.toLong).getOrElse(0L))
+ Some(a.copy(update = Some(newValue)))
+ } else {
+ None
+ }
},
finishTask = true)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 82f6811503..2260e48702 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{JsonProtocol, Utils}
class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
@@ -356,6 +356,28 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}
}
+ test("metrics can be loaded by history server") {
+ val metric = new LongSQLMetric("zanzibar", LongSQLMetricParam)
+ metric += 10L
+ val metricInfo = metric.toInfo(Some(metric.localValue), None)
+ metricInfo.update match {
+ case Some(v: LongSQLMetricValue) => assert(v.value === 10L)
+ case Some(v) => fail(s"metric value was not a LongSQLMetricValue: ${v.getClass.getName}")
+ case _ => fail("metric update is missing")
+ }
+ assert(metricInfo.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER))
+ // After serializing to JSON, the original value type is lost, but we can still
+ // identify that it's a SQL metric from the metadata
+ val metricInfoJson = JsonProtocol.accumulableInfoToJson(metricInfo)
+ val metricInfoDeser = JsonProtocol.accumulableInfoFromJson(metricInfoJson)
+ metricInfoDeser.update match {
+ case Some(v: String) => assert(v.toLong === 10L)
+ case Some(v) => fail(s"deserialized metric value was not a string: ${v.getClass.getName}")
+ case _ => fail("deserialized metric update is missing")
+ }
+ assert(metricInfoDeser.metadata === Some(SQLMetrics.ACCUM_IDENTIFIER))
+ }
+
}
private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 2c408c8878..085e4a49a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -26,8 +26,9 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution}
-import org.apache.spark.sql.execution.metric.LongSQLMetricValue
+import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.ui.SparkUI
class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
import testImplicits._
@@ -335,8 +336,43 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1)
}
+ test("SPARK-13055: history listener only tracks SQL metrics") {
+ val listener = new SQLHistoryListener(sparkContext.conf, mock(classOf[SparkUI]))
+ // We need to post other events for the listener to track our accumulators.
+ // These are largely just boilerplate unrelated to what we're trying to test.
+ val df = createTestDataFrame
+ val executionStart = SparkListenerSQLExecutionStart(
+ 0, "", "", "", SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), 0)
+ val stageInfo = createStageInfo(0, 0)
+ val jobStart = SparkListenerJobStart(0, 0, Seq(stageInfo), createProperties(0))
+ val stageSubmitted = SparkListenerStageSubmitted(stageInfo)
+ // This task has both accumulators that are SQL metrics and accumulators that are not.
+ // The listener should only track the ones that are actually SQL metrics.
+ val sqlMetric = SQLMetrics.createLongMetric(sparkContext, "beach umbrella")
+ val nonSqlMetric = sparkContext.accumulator[Int](0, "baseball")
+ val sqlMetricInfo = sqlMetric.toInfo(Some(sqlMetric.localValue), None)
+ val nonSqlMetricInfo = nonSqlMetric.toInfo(Some(nonSqlMetric.localValue), None)
+ val taskInfo = createTaskInfo(0, 0)
+ taskInfo.accumulables ++= Seq(sqlMetricInfo, nonSqlMetricInfo)
+ val taskEnd = SparkListenerTaskEnd(0, 0, "just-a-task", null, taskInfo, null)
+ listener.onOtherEvent(executionStart)
+ listener.onJobStart(jobStart)
+ listener.onStageSubmitted(stageSubmitted)
+ // Before SPARK-13055, this throws ClassCastException because the history listener would
+ // assume that the accumulator value is of type Long, but this may not be true for
+ // accumulators that are not SQL metrics.
+ listener.onTaskEnd(taskEnd)
+ val trackedAccums = listener.stageIdToStageMetrics.values.flatMap { stageMetrics =>
+ stageMetrics.taskIdToMetricUpdates.values.flatMap(_.accumulatorUpdates)
+ }
+ // Listener tracks only SQL metrics, not other accumulators
+ assert(trackedAccums.size === 1)
+ assert(trackedAccums.head === sqlMetricInfo)
+ }
+
}
+
class SQLListenerMemoryLeakSuite extends SparkFunSuite {
test("no memory leak") {