author     Wenchen Fan <wenchen@databricks.com>    2015-10-15 14:50:58 -0700
committer  Andrew Or <andrew@databricks.com>       2015-10-15 14:50:58 -0700
commit     6a2359ff1f7ad2233af2c530313d6ec2ecf70d19 (patch)
tree       c71d03c68072808a7a008720ada69d9ce483aea0 /sql
parent     3b364ff0a4f38c2b8023429a55623de32be5f329 (diff)
[SPARK-10412] [SQL] report memory usage for tungsten sql physical operator
https://issues.apache.org/jira/browse/SPARK-10412

some screenshots:

### aggregate:
![screen shot 2015-10-12 at 2 23 11 pm](https://cloud.githubusercontent.com/assets/3182036/10439534/618320a4-70ef-11e5-94d8-62ea7f2d1531.png)

### join
![screen shot 2015-10-12 at 2 23 29 pm](https://cloud.githubusercontent.com/assets/3182036/10439537/6724797c-70ef-11e5-8f75-0cf5cbd42048.png)

Author: Wenchen Fan <wenchen@databricks.com>
Author: Wenchen Fan <cloud0fan@163.com>

Closes #8931 from cloud-fan/viz.
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala               | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala      | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala                          | 72
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala                                       | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala                           |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala                             |  9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala                          |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala |  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala                     | 13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala                        | 20
10 files changed, 116 insertions, 43 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index c342940e6e..0d3a4b36c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -49,7 +49,9 @@ case class TungstenAggregate(
override private[sql] lazy val metrics = Map(
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+ "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+ "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
override def outputsUnsafeRows: Boolean = true
@@ -79,6 +81,8 @@ case class TungstenAggregate(
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val numInputRows = longMetric("numInputRows")
val numOutputRows = longMetric("numOutputRows")
+ val dataSize = longMetric("dataSize")
+ val spillSize = longMetric("spillSize")
/**
* Set up the underlying unsafe data structures used before computing the parent partition.
@@ -97,7 +101,9 @@ case class TungstenAggregate(
child.output,
testFallbackStartsAt,
numInputRows,
- numOutputRows)
+ numOutputRows,
+ dataSize,
+ spillSize)
}
/** Compute a partition using the iterator already set up previously. */
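The hunk above shows the pattern every Tungsten operator in this patch follows: declare the metrics once in the operator's `metrics` map, then resolve them with `longMetric` on the driver so the accumulators can be captured by the per-partition closure. The plain-Scala sketch below mirrors that declare-by-name / look-up-by-name shape; `LongMetric`, `SizedOperator` and `MetricLookupSketch` are illustrative stand-ins, not Spark classes (the real metrics are accumulators created by `SQLMetrics.createSizeMetric`).

import java.util.concurrent.atomic.AtomicLong

object MetricLookupSketch extends App {

  // A thread-safe counter that mirrors LongSQLMetric's `+=` (illustrative only).
  final class LongMetric(val description: String) {
    private val total = new AtomicLong(0L)
    def +=(delta: Long): Unit = total.addAndGet(delta)
    def value: Long = total.get()
  }

  class SizedOperator {
    // Declared once per operator, keyed by a stable name, like the `metrics` map above.
    val metrics: Map[String, LongMetric] = Map(
      "dataSize"  -> new LongMetric("data size"),
      "spillSize" -> new LongMetric("spill size"))

    // Looked up by name before the per-partition work starts, like `longMetric(...)`.
    def longMetric(name: String): LongMetric = metrics(name)
  }

  val op = new SizedOperator
  op.longMetric("dataSize") += 1024L
  println(op.longMetric("dataSize").value)   // prints 1024
}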
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index fe708a5f71..7cd0f7b81e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -87,7 +87,9 @@ class TungstenAggregationIterator(
originalInputAttributes: Seq[Attribute],
testFallbackStartsAt: Option[Int],
numInputRows: LongSQLMetric,
- numOutputRows: LongSQLMetric)
+ numOutputRows: LongSQLMetric,
+ dataSize: LongSQLMetric,
+ spillSize: LongSQLMetric)
extends Iterator[UnsafeRow] with Logging {
// The parent partition iterator, to be initialized later in `start`
@@ -110,6 +112,10 @@ class TungstenAggregationIterator(
s"$allAggregateExpressions should have no more than 2 kinds of modes.")
}
+ // Remember the spill size of this task before executing this operator, so that we can
+ // figure out how many bytes this operator spilled.
+ private val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
+
//
// The modes of AggregateExpressions. Right now, we can handle the following mode:
// - Partial-only:
@@ -842,6 +848,8 @@ class TungstenAggregationIterator(
val mapMemory = hashMap.getPeakMemoryUsedBytes
val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
val peakMemory = Math.max(mapMemory, sorterMemory)
+ dataSize += peakMemory
+ spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
TaskContext.get().internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory)
}
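The key trick in this hunk is measuring spill as a delta of the task-level counter, because several operators can run in the same task and `memoryBytesSpilled` is shared between them. Below is a self-contained sketch of that bookkeeping, with `TaskCounters` and the plain vars standing in for `TaskContext.get().taskMetrics()` and the `LongSQLMetric` accumulator (illustrative names only).

object SpillDeltaSketch extends App {

  // Shared by every operator running in the same task (stand-in for taskMetrics()).
  object TaskCounters {
    var memoryBytesSpilled: Long = 0L
  }

  // Stand-in for this operator's spillSize accumulator.
  var spillSize = 0L

  // An earlier operator in the same task already spilled some bytes.
  TaskCounters.memoryBytesSpilled += 1000L

  // Snapshot before this operator runs...
  val spillSizeBefore = TaskCounters.memoryBytesSpilled

  // ...the operator executes and spills 4096 bytes of its own...
  TaskCounters.memoryBytesSpilled += 4096L

  // ...and only the delta is charged to this operator, not the earlier 1000 bytes.
  spillSize += TaskCounters.memoryBytesSpilled - spillSizeBefore
  assert(spillSize == 4096L)
}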
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 7a2a98ec18..075b7ad881 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.metric
+import org.apache.spark.util.Utils
import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}
/**
@@ -35,6 +36,12 @@ private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
*/
private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {
+ /**
+ * A function that defines how we aggregate the final accumulator results among all tasks,
+ * and represent them as a string for a SQL physical operator.
+ */
+ val stringValue: Seq[T] => String
+
def zero: R
}
@@ -64,25 +71,11 @@ private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetr
}
/**
- * A wrapper of Int to avoid boxing and unboxing when using Accumulator
- */
-private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] {
-
- def add(term: Int): IntSQLMetricValue = {
- _value += term
- this
- }
-
- // Although there is a boxing here, it's fine because it's only called in SQLListener
- override def value: Int = _value
-}
-
-/**
* A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
* `+=` and `add`.
*/
-private[sql] class LongSQLMetric private[metric](name: String)
- extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) {
+private[sql] class LongSQLMetric private[metric](name: String, param: LongSQLMetricParam)
+ extends SQLMetric[LongSQLMetricValue, Long](name, param) {
override def +=(term: Long): Unit = {
localValue.add(term)
@@ -93,7 +86,8 @@ private[sql] class LongSQLMetric private[metric](name: String)
}
}
-private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] {
+private class LongSQLMetricParam(val stringValue: Seq[Long] => String, initialValue: Long)
+ extends SQLMetricParam[LongSQLMetricValue, Long] {
override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)
@@ -102,20 +96,56 @@ private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Lon
override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero
- override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L)
+ override def zero: LongSQLMetricValue = new LongSQLMetricValue(initialValue)
}
private[sql] object SQLMetrics {
- def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
- val acc = new LongSQLMetric(name)
+ private def createLongMetric(
+ sc: SparkContext,
+ name: String,
+ stringValue: Seq[Long] => String,
+ initialValue: Long): LongSQLMetric = {
+ val param = new LongSQLMetricParam(stringValue, initialValue)
+ val acc = new LongSQLMetric(name, param)
sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
acc
}
+ def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ createLongMetric(sc, name, _.sum.toString, 0L)
+ }
+
+ /**
+ * Create a metric to report the size information (including total, min, med, max) like data size,
+ * spill size, etc.
+ */
+ def createSizeMetric(sc: SparkContext, name: String): LongSQLMetric = {
+ val stringValue = (values: Seq[Long]) => {
+ // This is a workaround for SPARK-11013.
+ // We use -1 as the initial value of the accumulator. If the accumulator is valid, we will update
+ // it at the end of the task, and the value will be at least 0.
+ val validValues = values.filter(_ >= 0)
+ val Seq(sum, min, med, max) = {
+ val metric = if (validValues.length == 0) {
+ Seq.fill(4)(0L)
+ } else {
+ val sorted = validValues.sorted
+ Seq(sorted.sum, sorted(0), sorted(validValues.length / 2), sorted(validValues.length - 1))
+ }
+ metric.map(Utils.bytesToString)
+ }
+ s"\n$sum ($min, $med, $max)"
+ }
+ // The final result of this metric in the physical operator UI may look like:
+ // data size total (min, med, max):
+ // 100GB (100MB, 1GB, 10GB)
+ createLongMetric(sc, s"$name total (min, med, max)", stringValue, -1L)
+ }
+
/**
* A metric whose value will be ignored. Use this one when we need a metric parameter but don't
* care about the value.
*/
- val nullLongMetric = new LongSQLMetric("null")
+ val nullLongMetric = new LongSQLMetric("null", new LongSQLMetricParam(_.sum.toString, 0L))
}
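The heart of the change is `createSizeMetric`: the accumulator starts at -1 (the SPARK-11013 workaround), tasks that never touched it are filtered out, and the remaining per-task values are summarized as total (min, med, max). The sketch below restates that aggregation in standalone Scala; `bytesToString` here is a simplified stand-in for `org.apache.spark.util.Utils.bytesToString`, so the exact formatting may differ from Spark's.

object SizeMetricDisplay extends App {

  // Simplified stand-in for Utils.bytesToString (formatting only, not Spark's exact output).
  def bytesToString(size: Long): String = {
    val KB = 1L << 10; val MB = 1L << 20; val GB = 1L << 30
    if (size >= GB) f"${size.toDouble / GB}%.1f GB"
    else if (size >= MB) f"${size.toDouble / MB}%.1f MB"
    else if (size >= KB) f"${size.toDouble / KB}%.1f KB"
    else s"$size B"
  }

  // Same shape as the stringValue closure in createSizeMetric above.
  def stringValue(values: Seq[Long]): String = {
    // -1 means "this task never updated the accumulator" (the SPARK-11013 workaround).
    val validValues = values.filter(_ >= 0)
    val Seq(sum, min, med, max) = {
      val metric = if (validValues.isEmpty) {
        Seq.fill(4)(0L)
      } else {
        val sorted = validValues.sorted
        Seq(sorted.sum, sorted.head, sorted(sorted.length / 2), sorted.last)
      }
      metric.map(bytesToString)
    }
    s"\n$sum ($min, $med, $max)"
  }

  // Two tasks reported sizes; one task left the accumulator at its initial -1.
  println(stringValue(Seq(-1L, 10L << 20, 30L << 20)))   // 40.0 MB (10.0 MB, 30.0 MB, 30.0 MB)
}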
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 27f26245a5..9385e5734d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter
@@ -93,10 +94,17 @@ case class TungstenSort(
override def requiredChildDistribution: Seq[Distribution] =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+ override private[sql] lazy val metrics = Map(
+ "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
+ "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
+
protected override def doExecute(): RDD[InternalRow] = {
val schema = child.schema
val childOutput = child.output
+ val dataSize = longMetric("dataSize")
+ val spillSize = longMetric("spillSize")
+
/**
* Set up the sorter in each partition before computing the parent partition.
* This makes sure our sorter is not starved by other sorters used in the same task.
@@ -131,7 +139,15 @@ case class TungstenSort(
partitionIndex: Int,
sorter: UnsafeExternalRowSorter,
parentIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+ // Remember the spill size of this task before executing this operator, so that we can
+ // figure out how many bytes this operator spilled.
+ val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
+
val sortedIterator = sorter.sort(parentIterator.asInstanceOf[Iterator[UnsafeRow]])
+
+ dataSize += sorter.getPeakMemoryUsage
+ spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
+
taskContext.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
sortedIterator
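TungstenSort applies the same recipe as the aggregate: the sorter's peak memory feeds "data size" and the task-level spill delta feeds "spill size". A toy, self-contained version of that wiring follows; `ToySorter` and the mutable vars are illustrative stand-ins for UnsafeExternalRowSorter, the task metrics and the SQL accumulators.

object SortMetricsSketch extends App {

  // Illustrative stand-in for UnsafeExternalRowSorter: reports peak memory and spills via a callback.
  final class ToySorter {
    var peakMemoryUsage: Long = 0L
    def sort(rows: Seq[Int], onSpill: Long => Unit): Seq[Int] = {
      peakMemoryUsage = rows.length * 8L   // pretend each row costs 8 bytes in memory
      if (rows.length > 2) onSpill(16L)    // pretend 16 bytes had to be spilled to disk
      rows.sorted
    }
  }

  var taskMemoryBytesSpilled = 0L          // stand-in for taskMetrics().memoryBytesSpilled
  var dataSize = 0L                        // stand-in for the "data size" accumulator
  var spillSize = 0L                       // stand-in for the "spill size" accumulator

  val spillSizeBefore = taskMemoryBytesSpilled
  val sorter = new ToySorter
  val sortedRows = sorter.sort(Seq(3, 1, 4, 1, 5), bytes => taskMemoryBytesSpilled += bytes)

  dataSize += sorter.peakMemoryUsage                      // 40: peak memory becomes data size
  spillSize += taskMemoryBytesSpilled - spillSizeBefore   // 16: only this operator's own spill

  println(s"sorted=$sortedRows dataSize=$dataSize spillSize=$spillSize")
}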
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
index a4dbd2e197..e74d6fb396 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala
@@ -100,7 +100,7 @@ private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution")
// scalastyle:on
}
- private def planVisualization(metrics: Map[Long, Any], graph: SparkPlanGraph): Seq[Node] = {
+ private def planVisualization(metrics: Map[Long, String], graph: SparkPlanGraph): Seq[Node] = {
val metadata = graph.nodes.flatMap { node =>
val nodeId = s"plan-meta-data-${node.id}"
<div id={nodeId}>{node.desc}</div>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index d6472400a6..b302b51999 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -252,7 +252,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
/**
* Get all accumulator updates from all tasks which belong to this execution and merge them.
*/
- def getExecutionMetrics(executionId: Long): Map[Long, Any] = synchronized {
+ def getExecutionMetrics(executionId: Long): Map[Long, String] = synchronized {
_executionIdToData.get(executionId) match {
case Some(executionUIData) =>
val accumulatorUpdates = {
@@ -264,8 +264,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
}
}.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
- executionUIData.accumulatorMetrics(accumulatorId).metricParam).
- mapValues(_.asInstanceOf[SQLMetricValue[_]].value)
+ executionUIData.accumulatorMetrics(accumulatorId).metricParam)
case None =>
// This execution has been dropped
Map.empty
@@ -274,11 +273,11 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
private def mergeAccumulatorUpdates(
accumulatorUpdates: Seq[(Long, Any)],
- paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = {
+ paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, String] = {
accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
val param = paramFunc(accumulatorId)
(accumulatorId,
- values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace))
+ param.stringValue(values.map(_._2.asInstanceOf[SQLMetricValue[Any]].value)))
}
}
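With this change the listener no longer folds partial `SQLMetricValue`s into a single number; it groups the raw per-task updates by accumulator id and hands the whole sequence to the metric's own `stringValue` function. A minimal standalone sketch of that merge step, with a simplified stand-in for the `SQLMetricParam` lookup:

object MergeUpdatesSketch extends App {

  // (accumulator id, value reported by one task)
  val accumulatorUpdates: Seq[(Long, Long)] = Seq((1L, 10L), (1L, 20L), (2L, -1L), (2L, 7L))

  // Per-accumulator display function, playing the role of SQLMetricParam.stringValue.
  val stringValueOf: Map[Long, Seq[Long] => String] = Map(
    1L -> ((values: Seq[Long]) => values.sum.toString),                        // a plain counter metric
    2L -> ((values: Seq[Long]) => values.filter(_ >= 0).sum.toString + " B"))  // a size metric, -1 filtered out

  // Group the raw updates by accumulator id and format each group with its own function.
  val merged: Map[Long, String] = accumulatorUpdates.groupBy(_._1).map {
    case (accumulatorId, values) =>
      accumulatorId -> stringValueOf(accumulatorId)(values.map(_._2))
  }

  println(merged)   // e.g. Map(1 -> 30, 2 -> 7 B)
}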
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index ae3d752dde..f1fce5478a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue}
private[ui] case class SparkPlanGraph(
nodes: Seq[SparkPlanGraphNode], edges: Seq[SparkPlanGraphEdge]) {
- def makeDotFile(metrics: Map[Long, Any]): String = {
+ def makeDotFile(metrics: Map[Long, String]): String = {
val dotFile = new StringBuilder
dotFile.append("digraph G {\n")
nodes.foreach(node => dotFile.append(node.makeDotNode(metrics) + "\n"))
@@ -87,7 +87,7 @@ private[sql] object SparkPlanGraph {
private[ui] case class SparkPlanGraphNode(
id: Long, name: String, desc: String, metrics: Seq[SQLPlanMetric]) {
- def makeDotNode(metricsValue: Map[Long, Any]): String = {
+ def makeDotNode(metricsValue: Map[Long, String]): String = {
val values = {
for (metric <- metrics;
value <- metricsValue.get(metric.accumulatorId)) yield {
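On the graph side the node now receives already-formatted strings and only has to stitch them into its label. A toy sketch of that lookup-and-render step follows; `PlanMetric`, `PlanNode` and the label format are illustrative only, not the actual dot output SparkPlanGraphNode produces.

object DotNodeSketch extends App {

  case class PlanMetric(name: String, accumulatorId: Long)

  case class PlanNode(id: Long, name: String, metrics: Seq[PlanMetric]) {
    // Pair each declared metric with its merged string, skipping metrics with no value yet,
    // then join everything into a multi-line label.
    def label(metricsValue: Map[Long, String]): String = {
      val values = for {
        metric <- metrics
        value  <- metricsValue.get(metric.accumulatorId)
      } yield s"${metric.name}: $value"
      (name +: values).mkString("\n")
    }
  }

  val node = PlanNode(0L, "TungstenSort", Seq(
    PlanMetric("data size total (min, med, max)", 1L),
    PlanMetric("spill size total (min, med, max)", 2L)))

  // Only the data-size accumulator has reported values; the spill-size line is omitted.
  println(node.label(Map(1L -> "40.0 MB (10.0 MB, 30.0 MB, 30.0 MB)")))
}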
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index 0cc4988ff6..cc0ac1b07c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -39,7 +39,8 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte
}
val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
- 0, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
+ 0, Seq.empty, newMutableProjection, Seq.empty, None,
+ dummyAccum, dummyAccum, dummyAccum, dummyAccum)
val numPages = iter.getHashMap.getNumDataPages
assert(numPages === 1)
} finally {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 6afffae161..cdd885ba14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -93,7 +93,16 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
}.toMap
(node.id, node.name -> nodeMetrics)
}.toMap
- assert(expectedMetrics === actualMetrics)
+
+ assert(expectedMetrics.keySet === actualMetrics.keySet)
+ for (nodeId <- expectedMetrics.keySet) {
+ val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId)
+ val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
+ assert(expectedNodeName === actualNodeName)
+ for (metricName <- expectedMetricsMap.keySet) {
+ assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName))
+ }
+ }
} else {
// TODO Remove this "else" once we fix the race condition of missing the JobStarted event.
// Since we cannot track all jobs, the metric values could be wrong and we should not check
@@ -489,7 +498,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
// Because "save" will create a new DataFrame internally, we cannot get the real metric id.
// However, we still can check the value.
- assert(metricValues.values.toSeq === Seq(2L))
+ assert(metricValues.values.toSeq === Seq("2"))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 727cf3665a..cc1c1e10e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -74,6 +74,10 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
}
test("basic") {
+ def checkAnswer(actual: Map[Long, String], expected: Map[Long, Long]): Unit = {
+ assert(actual === expected.mapValues(_.toString))
+ }
+
val listener = new SQLListener(sqlContext.sparkContext.conf)
val executionId = 0
val df = createTestDataFrame
@@ -114,7 +118,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
(1L, 0, 0, createTaskMetrics(accumulatorUpdates))
)))
- assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2))
+ checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
// (task id, stage id, stage attempt, metrics)
@@ -122,7 +126,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
(1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)))
)))
- assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 3))
+ checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
// Retrying a stage should reset the metrics
listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1)))
@@ -133,7 +137,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
(1L, 0, 1, createTaskMetrics(accumulatorUpdates))
)))
- assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2))
+ checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
// Ignore the task end for the first attempt
listener.onTaskEnd(SparkListenerTaskEnd(
@@ -144,7 +148,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
createTaskInfo(0, 0),
createTaskMetrics(accumulatorUpdates.mapValues(_ * 100))))
- assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2))
+ checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
// Finish two tasks
listener.onTaskEnd(SparkListenerTaskEnd(
@@ -162,7 +166,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
createTaskInfo(1, 0),
createTaskMetrics(accumulatorUpdates.mapValues(_ * 3))))
- assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 5))
+ checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 5))
// Submit a new stage
listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0)))
@@ -173,7 +177,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
(1L, 1, 0, createTaskMetrics(accumulatorUpdates))
)))
- assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 7))
+ checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
// Finish two tasks
listener.onTaskEnd(SparkListenerTaskEnd(
@@ -191,7 +195,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
createTaskInfo(1, 0),
createTaskMetrics(accumulatorUpdates.mapValues(_ * 3))))
- assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11))
+ checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11))
assert(executionUIData.runningJobs === Seq(0))
assert(executionUIData.succeededJobs.isEmpty)
@@ -208,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
assert(executionUIData.succeededJobs === Seq(0))
assert(executionUIData.failedJobs.isEmpty)
- assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11))
+ checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 11))
}
test("onExecutionEnd happens before onJobEnd(JobSucceeded)") {