commit ebfd91c542aaead343cb154277fcf9114382fee7
Author:    zsxwing <zsxwing@gmail.com>  2015-08-07 00:09:58 -0700
Committer: Reynold Xin <rxin@databricks.com>  2015-08-07 00:09:58 -0700
Tree:      89a15f29d335398b02f37749cfaa7d8c5a28abb9 /sql
Parent:    e57d6b56137bf3557efe5acea3ad390c1987b257
[SPARK-9467][SQL] Add SQLMetric to specialize accumulators to avoid boxing

This PR adds SQLMetric/SQLMetricParam/SQLMetricValue to specialize accumulators
to avoid boxing. All SQL metrics should use these classes rather than `Accumulator`.

Author: zsxwing <zsxwing@gmail.com>

Closes #7996 from zsxwing/sql-accu and squashes the following commits:

14a5f0a [zsxwing] Address comments
367ca23 [zsxwing] Use localValue directly to avoid changing Accumulable
42f50c3 [zsxwing] Add SQLMetric to specialize accumulators to avoid boxing
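To make the boxing argument concrete, here is a minimal, self-contained Scala sketch of the pattern (LongMetricValue, LongMetric, and MetricDemo are illustrative stand-ins, not the classes added by this patch): the accumulated value is a mutable wrapper, so each update mutates a primitive field in place, and boxing happens only once, when the final value is read.

    // Mutable wrapper: each update touches only the primitive field.
    class LongMetricValue(private var _value: Long) extends Serializable {
      def add(incr: Long): LongMetricValue = { _value += incr; this }
      // Boxing happens only here, when the final value is read (e.g., by the UI).
      def value: Long = _value
    }

    // A generic Accumulable[R, T] declares `+=(term: T)` with an erased T, so every
    // Long passed to it is boxed; overriding `+=` with a concrete Long parameter
    // gives the JVM a primitive entry point, which is what SQLMetric's subclasses do.
    class LongMetric(val name: String) {
      private val local = new LongMetricValue(0L)
      def +=(term: Long): Unit = local.add(term) // no allocation per update
      def localValue: LongMetricValue = local
    }

    object MetricDemo extends App {
      val numRows = new LongMetric("number of rows")
      var i = 0
      while (i < 1000000) { numRows += 1L; i += 1 } // hot loop: no Long boxing
      println(s"${numRows.name} = ${numRows.localValue.value}")
    }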
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala       |  33
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala  |  11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala         | 149
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala            |  17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala         |   8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala    | 145
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala       |   5
7 files changed, 337 insertions(+), 31 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 719ad432e2..1915496d16 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{Accumulator, Logging}
+import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.SQLContext
@@ -32,6 +32,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.metric.{IntSQLMetric, LongSQLMetric, SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.DataType
object SparkPlan {
@@ -84,22 +85,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
*/
protected[sql] def trackNumOfRowsEnabled: Boolean = false
- private lazy val numOfRowsAccumulator = sparkContext.internalAccumulator(0L, "number of rows")
+ private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] =
+ if (trackNumOfRowsEnabled) {
+ Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))
+ } else {
+ Map.empty
+ }
/**
- * Return all accumulators containing metrics of this SparkPlan.
+ * Return all metrics of this SparkPlan.
*/
- private[sql] def accumulators: Map[String, Accumulator[_]] = if (trackNumOfRowsEnabled) {
- Map("numRows" -> numOfRowsAccumulator)
- } else {
- Map.empty
- }
+ private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics
+
+ /**
+ * Return an IntSQLMetric for the given name.
+ */
+ private[sql] def intMetric(name: String): IntSQLMetric =
+ metrics(name).asInstanceOf[IntSQLMetric]
/**
- * Return the accumulator according to the name.
+ * Return a LongSQLMetric for the given name.
*/
- private[sql] def accumulator[T](name: String): Accumulator[T] =
- accumulators(name).asInstanceOf[Accumulator[T]]
+ private[sql] def longMetric(name: String): LongSQLMetric =
+ metrics(name).asInstanceOf[LongSQLMetric]
// TODO: Move to `DistributedPlan`
/** Specifies how data is partitioned across different nodes in the cluster. */
@@ -148,7 +157,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
prepare()
if (trackNumOfRowsEnabled) {
- val numRows = accumulator[Long]("numRows")
+ val numRows = longMetric("numRows")
doExecute().map { row =>
numRows += 1
row
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index f4677b4ee8..0680f31d40 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.metric.SQLMetrics
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
@@ -81,13 +82,13 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
- private[sql] override lazy val accumulators = Map(
- "numInputRows" -> sparkContext.internalAccumulator(0L, "number of input rows"),
- "numOutputRows" -> sparkContext.internalAccumulator(0L, "number of output rows"))
+ private[sql] override lazy val metrics = Map(
+ "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
protected override def doExecute(): RDD[InternalRow] = {
- val numInputRows = accumulator[Long]("numInputRows")
- val numOutputRows = accumulator[Long]("numOutputRows")
+ val numInputRows = longMetric("numInputRows")
+ val numOutputRows = longMetric("numOutputRows")
child.execute().mapPartitions { iter =>
val predicate = newPredicate(condition, child.output)
iter.filter { row =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala
new file mode 100644
index 0000000000..3b907e5da7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala
@@ -0,0 +1,149 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.metric
+
+import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}
+
+/**
+ * Create a layer for specialized metrics. We cannot add `@specialized` to
+ * `Accumulable/AccumulableParam` because it will break Java source compatibility.
+ *
+ * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
+ */
+private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
+ name: String, val param: SQLMetricParam[R, T])
+ extends Accumulable[R, T](param.zero, param, Some(name), true)
+
+/**
+ * Create a layer for specialized metrics. We cannot add `@specialized` to
+ * `Accumulable/AccumulableParam` because it will break Java source compatibility.
+ */
+private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
+
+ def zero: R
+}
+
+/**
+ * Create a layer for specialized metrics. We cannot add `@specialized` to
+ * `Accumulable/AccumulableParam` because it will break Java source compatibility.
+ */
+private[sql] trait SQLMetricValue[T] extends Serializable {
+
+ def value: T
+
+ override def toString: String = value.toString
+}
+
+/**
+ * A wrapper around Long to avoid boxing and unboxing when using Accumulator.
+ */
+private[sql] class LongSQLMetricValue(private var _value: Long) extends SQLMetricValue[Long] {
+
+ def add(incr: Long): LongSQLMetricValue = {
+ _value += incr
+ this
+ }
+
+ // Although there is boxing here, it's fine because it's only called in SQLListener
+ override def value: Long = _value
+}
+
+/**
+ * A wrapper around Int to avoid boxing and unboxing when using Accumulator.
+ */
+private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] {
+
+ def add(term: Int): IntSQLMetricValue = {
+ _value += term
+ this
+ }
+
+ // Although there is boxing here, it's fine because it's only called in SQLListener
+ override def value: Int = _value
+}
+
+/**
+ * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
+ * `+=` and `add`.
+ */
+private[sql] class LongSQLMetric private[metric](name: String)
+ extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) {
+
+ override def +=(term: Long): Unit = {
+ localValue.add(term)
+ }
+
+ override def add(term: Long): Unit = {
+ localValue.add(term)
+ }
+}
+
+/**
+ * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's
+ * `+=` and `add`.
+ */
+private[sql] class IntSQLMetric private[metric](name: String)
+ extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) {
+
+ override def +=(term: Int): Unit = {
+ localValue.add(term)
+ }
+
+ override def add(term: Int): Unit = {
+ localValue.add(term)
+ }
+}
+
+private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] {
+
+ override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
+
+ override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
+ r1.add(r2.value)
+
+ override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
+
+ override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L)
+}
+
+private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] {
+
+ override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t)
+
+ override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue =
+ r1.add(r2.value)
+
+ override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero
+
+ override def zero: IntSQLMetricValue = new IntSQLMetricValue(0)
+}
+
+private[sql] object SQLMetrics {
+
+ def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = {
+ val acc = new IntSQLMetric(name)
+ sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ acc
+ }
+
+ def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ val acc = new LongSQLMetric(name)
+ sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ acc
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala
index e7b1dd1ffa..2fd4fc658d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala
@@ -21,11 +21,12 @@ import scala.collection.mutable
import com.google.common.annotations.VisibleForTesting
-import org.apache.spark.{AccumulatorParam, JobExecutionStatus, Logging}
+import org.apache.spark.{JobExecutionStatus, Logging}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue}
private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging {
@@ -36,8 +37,6 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit
// Old data in the following fields must be removed in "trimExecutionsIfNecessary".
// If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data
-
- // VisibleForTesting
private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]()
/**
@@ -270,9 +269,10 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit
accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield {
accumulatorUpdate
}
- }.filter { case (id, _) => executionUIData.accumulatorMetrics.keySet(id) }
+ }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
- executionUIData.accumulatorMetrics(accumulatorId).accumulatorParam)
+ executionUIData.accumulatorMetrics(accumulatorId).metricParam).
+ mapValues(_.asInstanceOf[SQLMetricValue[_]].value)
case None =>
// This execution has been dropped
Map.empty
@@ -281,10 +281,11 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit
private def mergeAccumulatorUpdates(
accumulatorUpdates: Seq[(Long, Any)],
- paramFunc: Long => AccumulatorParam[Any]): Map[Long, Any] = {
+ paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = {
accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
val param = paramFunc(accumulatorId)
- (accumulatorId, values.map(_._2).reduceLeft(param.addInPlace))
+ (accumulatorId,
+ values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace))
}
}
@@ -336,7 +337,7 @@ private[ui] class SQLExecutionUIData(
private[ui] case class SQLPlanMetric(
name: String,
accumulatorId: Long,
- accumulatorParam: AccumulatorParam[Any])
+ metricParam: SQLMetricParam[SQLMetricValue[Any], Any])
/**
* Store all accumulatorUpdates for all tasks in a Spark stage.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala
index 7910c163ba..1ba50b95be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala
@@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable
-import org.apache.spark.AccumulatorParam
import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue}
/**
* A graph used for storing information of an executionPlan of DataFrame.
@@ -61,9 +61,9 @@ private[sql] object SparkPlanGraph {
nodeIdGenerator: AtomicLong,
nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = {
- val metrics = plan.accumulators.toSeq.map { case (key, accumulator) =>
- SQLPlanMetric(accumulator.name.getOrElse(key), accumulator.id,
- accumulator.param.asInstanceOf[AccumulatorParam[Any]])
+ val metrics = plan.metrics.toSeq.map { case (key, metric) =>
+ SQLPlanMetric(metric.name.getOrElse(key), metric.id,
+ metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]])
}
val node = SparkPlanGraphNode(
nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala
new file mode 100644
index 0000000000..d22160f538
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/metric/SQLMetricsSuite.scala
@@ -0,0 +1,145 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.metric
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import scala.collection.mutable
+
+import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._
+import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.util.Utils
+
+
+class SQLMetricsSuite extends SparkFunSuite {
+
+ test("LongSQLMetric should not box Long") {
+ val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long")
+ val f = () => { l += 1L }
+ BoxingFinder.getClassReader(f.getClass).foreach { cl =>
+ val boxingFinder = new BoxingFinder()
+ cl.accept(boxingFinder, 0)
+ assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}")
+ }
+ }
+
+ test("IntSQLMetric should not box Int") {
+ val l = SQLMetrics.createIntMetric(TestSQLContext.sparkContext, "Int")
+ val f = () => { l += 1 }
+ BoxingFinder.getClassReader(f.getClass).foreach { cl =>
+ val boxingFinder = new BoxingFinder()
+ cl.accept(boxingFinder, 0)
+ assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}")
+ }
+ }
+
+ test("Normal accumulator should do boxing") {
+ // We need this test to make sure BoxingFinder works.
+ val l = TestSQLContext.sparkContext.accumulator(0L)
+ val f = () => { l += 1L }
+ BoxingFinder.getClassReader(f.getClass).foreach { cl =>
+ val boxingFinder = new BoxingFinder()
+ cl.accept(boxingFinder, 0)
+ assert(boxingFinder.boxingInvokes.nonEmpty, "Should find boxing in this test")
+ }
+ }
+}
+
+private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String)
+
+/**
+ * If `method` is null, recursively search all methods of this class to see whether they do any
+ * boxing. If `method` is specified, only search that method to speed up the search.
+ *
+ * The search skips the methods in `visitedMethods` to avoid potential infinite cycles.
+ */
+private class BoxingFinder(
+ method: MethodIdentifier[_] = null,
+ val boxingInvokes: mutable.Set[String] = mutable.Set.empty,
+ visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty)
+ extends ClassVisitor(ASM4) {
+
+ private val primitiveBoxingClassName =
+ Set("java/lang/Long",
+ "java/lang/Double",
+ "java/lang/Integer",
+ "java/lang/Float",
+ "java/lang/Short",
+ "java/lang/Character",
+ "java/lang/Byte",
+ "java/lang/Boolean")
+
+ override def visitMethod(
+ access: Int, name: String, desc: String, sig: String, exceptions: Array[String]):
+ MethodVisitor = {
+ if (method != null && (method.name != name || method.desc != desc)) {
+ // If method is specified, skip other methods.
+ return new MethodVisitor(ASM4) {}
+ }
+
+ new MethodVisitor(ASM4) {
+ override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
+ if (op == INVOKESPECIAL && name == "<init>" || op == INVOKESTATIC && name == "valueOf") {
+ if (primitiveBoxingClassName.contains(owner)) {
+ // Found a boxing call, e.g., new java.lang.Long(l) or java.lang.Long.valueOf(l)
+ boxingInvokes.add(s"$owner.$name")
+ }
+ } else {
+ // scalastyle:off classforname
+ val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false,
+ Thread.currentThread.getContextClassLoader)
+ // scalastyle:on classforname
+ val m = MethodIdentifier(classOfMethodOwner, name, desc)
+ if (!visitedMethods.contains(m)) {
+ // Keep track of visited methods to avoid potential infinite cycles
+ visitedMethods += m
+ BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl =>
+ cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0)
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+private object BoxingFinder {
+
+ def getClassReader(cls: Class[_]): Option[ClassReader] = {
+ val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
+ val resourceStream = cls.getResourceAsStream(className)
+ val baos = new ByteArrayOutputStream(128)
+ // Copy data over, before delegating to ClassReader -
+ // else we can run out of open file handles.
+ Utils.copyStream(resourceStream, baos, true)
+ // ASM4 doesn't support Java 8 classes, which require ASM5.
+ // So if the class is a Java 8 class (e.g., java.lang.Long when running on a JDK 8 runtime),
+ // ClassReader will throw IllegalArgumentException.
+ // However, since this is only for testing, it's safe to skip such classes.
+ try {
+ Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray)))
+ } catch {
+ case _: IllegalArgumentException => None
+ }
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala
index f1fcaf5953..69a561e16a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ui/SQLListenerSuite.scala
@@ -21,6 +21,7 @@ import java.util.Properties
import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite}
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.sql.metric.LongSQLMetricValue
import org.apache.spark.scheduler._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.SQLExecution
@@ -65,9 +66,9 @@ class SQLListenerSuite extends SparkFunSuite {
speculative = false
)
- private def createTaskMetrics(accumulatorUpdates: Map[Long, Any]): TaskMetrics = {
+ private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = {
val metrics = new TaskMetrics
- metrics.setAccumulatorsUpdater(() => accumulatorUpdates)
+ metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_)))
metrics.updateAccumulators()
metrics
}