author     Andrew Or <andrew@databricks.com>        2016-01-27 11:15:48 -0800
committer  Josh Rosen <joshrosen@databricks.com>    2016-01-27 11:15:48 -0800
commit     87abcf7df921a5937fdb2bae8bfb30bfabc4970a (patch)
tree       74b5a1cb19f06c40bd99a85feee3f35efbb9a496 /sql
parent     edd473751b59b55fa3daede5ed7bc19ea8bd7170 (diff)
[SPARK-12895][SPARK-12896] Migrate TaskMetrics to accumulators
The high-level idea is that instead of having the executors send both accumulator updates and TaskMetrics, we should have them send only accumulator updates. This eliminates the need to maintain both code paths, since one can be implemented in terms of the other. This effort is split into two parts:

**SPARK-12895: Implement TaskMetrics using accumulators.** TaskMetrics is basically just a bunch of accumulable fields. This patch makes TaskMetrics a syntactic wrapper around a collection of accumulators so we don't need to send TaskMetrics from the executors to the driver.

**SPARK-12896: Send only accumulator updates to the driver.** Now that TaskMetrics are expressed in terms of accumulators, we can capture all TaskMetrics values by sending only accumulator updates from the executors to the driver. This completes the parent issue SPARK-10620.

While an effort has been made to preserve as much of the public API as possible, there were a few known breaking DeveloperApi changes that would be very awkward to maintain. I will gather the full list shortly and post it here.

Note: This was once part of #10717. It is split out into its own patch to make it easier for others to review. Other smaller pieces have already been merged into master.

Author: Andrew Or <andrew@databricks.com>

Closes #10835 from andrewor14/task-metrics-use-accums.
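To make the migration concrete, the following is a minimal, self-contained Scala sketch of the idea (illustration only: LongAccum and TaskMetricsSketch are invented names, not Spark classes). Each metric field is backed by a named accumulator, so shipping the accumulator updates alone lets the driver reconstruct the metrics:

final class LongAccum(val name: String) {
  private var _value: Long = 0L
  def add(v: Long): Unit = _value += v
  def value: Long = _value
}

final class TaskMetricsSketch {
  // Every "metric" is just a named accumulator.
  private val memoryBytesSpilled = new LongAccum("internal.metrics.memoryBytesSpilled")
  private val peakExecutionMemory = new LongAccum("internal.metrics.peakExecutionMemory")

  def incMemoryBytesSpilled(v: Long): Unit = memoryBytesSpilled.add(v)
  def incPeakExecutionMemory(v: Long): Unit = peakExecutionMemory.add(v)

  // The executor only needs to send these (name, partial value) pairs;
  // no separate TaskMetrics payload is required.
  def accumulatorUpdates(): Seq[(String, Long)] =
    Seq(memoryBytesSpilled, peakExecutionMemory).map(a => a.name -> a.value)
}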
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala | 22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala | 38
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala | 34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala | 2
15 files changed, 92 insertions, 66 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 73dc8cb984..75cb6d1137 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -79,17 +79,17 @@ case class Sort(
sorter.setTestSpillFrequency(testSpillFrequency)
}
+ val metrics = TaskContext.get().taskMetrics()
// Remember spill data size of this task before execute this operator so that we can
// figure out how many bytes we spilled for this operator.
- val spillSizeBefore = TaskContext.get().taskMetrics().memoryBytesSpilled
+ val spillSizeBefore = metrics.memoryBytesSpilled
val sortedIterator = sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
dataSize += sorter.getPeakMemoryUsage
- spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
+ spillSize += metrics.memoryBytesSpilled - spillSizeBefore
+ metrics.incPeakExecutionMemory(sorter.getPeakMemoryUsage)
- TaskContext.get().internalMetricsToAccumulators(
- InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.getPeakMemoryUsage)
sortedIterator
}
}
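The pattern this hunk introduces, and which the join and aggregation hunks below repeat, is to fetch TaskMetrics once per partition and report peak memory on it directly, instead of looking up the PEAK_EXECUTION_MEMORY internal accumulator by name. A minimal sketch, assuming the TaskMetrics API as it stands after this commit (reportOperatorMemory is a hypothetical helper and only makes sense inside a running Spark task):

import org.apache.spark.TaskContext

def reportOperatorMemory(peakMemoryUsed: Long, spillSizeBefore: Long): Long = {
  val metrics = TaskContext.get().taskMetrics()
  // Report peak memory directly on TaskMetrics, replacing the old
  // internalMetricsToAccumulators(InternalAccumulator.PEAK_EXECUTION_MEMORY) lookup.
  metrics.incPeakExecutionMemory(peakMemoryUsed)
  // Spill attributable to this operator, to be added to its own SQLMetric.
  metrics.memoryBytesSpilled - spillSizeBefore
}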
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 41799c596b..001e9c306a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -418,10 +418,10 @@ class TungstenAggregationIterator(
val mapMemory = hashMap.getPeakMemoryUsedBytes
val sorterMemory = Option(externalSorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
val peakMemory = Math.max(mapMemory, sorterMemory)
+ val metrics = TaskContext.get().taskMetrics()
dataSize += peakMemory
- spillSize += TaskContext.get().taskMetrics().memoryBytesSpilled - spillSizeBefore
- TaskContext.get().internalMetricsToAccumulators(
- InternalAccumulator.PEAK_EXECUTION_MEMORY).add(peakMemory)
+ spillSize += metrics.memoryBytesSpilled - spillSizeBefore
+ metrics.incPeakExecutionMemory(peakMemory)
}
numOutputRows += 1
res
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
index 8222b84d33..edd87c2d8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala
@@ -136,14 +136,17 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
- split.serializableHadoopSplit.value match {
- case _: FileSplit | _: CombineFileSplit =>
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
- case _ => None
+ val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match {
+ case _: FileSplit | _: CombineFileSplit =>
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+ case _ => None
+ }
+
+ def updateBytesRead(): Unit = {
+ getBytesReadCallback.foreach { getBytesRead =>
+ inputMetrics.setBytesRead(getBytesRead())
}
}
- inputMetrics.setBytesReadCallback(bytesReadCallback)
val format = inputFormatClass.newInstance
format match {
@@ -208,6 +211,9 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
if (!finished) {
inputMetrics.incRecordsRead(1)
}
+ if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+ updateBytesRead()
+ }
reader.getCurrentValue
}
@@ -228,8 +234,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag](
} finally {
reader = null
}
- if (bytesReadCallback.isDefined) {
- inputMetrics.updateBytesRead()
+ if (getBytesReadCallback.isDefined) {
+ updateBytesRead()
} else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
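With the callback field gone from InputMetrics, the RDD now polls the file-system counters itself: updateBytesRead() is called every SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS records and once more when the reader is closed. A self-contained sketch of that polling pattern (InputMetricsSketch and UpdateIntervalRecords are simplified stand-ins, not Spark's own types or constant):

final class InputMetricsSketch {
  var bytesRead: Long = 0L
  var recordsRead: Long = 0L
}

object PeriodicBytesRead {
  // Stand-in for SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS; the real value may differ.
  val UpdateIntervalRecords = 1000L

  def consume(
      records: Iterator[Any],
      getBytesRead: Option[() => Long],
      metrics: InputMetricsSketch): Unit = {
    def updateBytesRead(): Unit =
      getBytesRead.foreach(cb => metrics.bytesRead = cb())

    records.foreach { _ =>
      metrics.recordsRead += 1
      if (metrics.recordsRead % UpdateIntervalRecords == 0) {
        updateBytesRead()
      }
    }
    // Final refresh after the last record, mirroring the close() path above.
    updateBytesRead()
  }
}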
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index c9ea579b5e..04640711d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -111,8 +111,7 @@ case class BroadcastHashJoin(
val hashedRelation = broadcastRelation.value
hashedRelation match {
case unsafe: UnsafeHashedRelation =>
- TaskContext.get().internalMetricsToAccumulators(
- InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
case _ =>
}
hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
index 6c7fa2eee5..db8edd169d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -119,8 +119,7 @@ case class BroadcastHashOuterJoin(
hashTable match {
case unsafe: UnsafeHashedRelation =>
- TaskContext.get().internalMetricsToAccumulators(
- InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
case _ =>
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 004407b2e6..8929dc3af1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -66,8 +66,7 @@ case class BroadcastLeftSemiJoinHash(
val hashedRelation = broadcastedRelation.value
hashedRelation match {
case unsafe: UnsafeHashedRelation =>
- TaskContext.get().internalMetricsToAccumulators(
- InternalAccumulator.PEAK_EXECUTION_MEMORY).add(unsafe.getUnsafeSize)
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(unsafe.getUnsafeSize)
case _ =>
}
hashSemiJoin(streamIter, numLeftRows, hashedRelation, numOutputRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 52735c9d7f..950dc78162 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.metric
-import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}
+import org.apache.spark.{Accumulable, AccumulableParam, Accumulators, SparkContext}
import org.apache.spark.util.Utils
/**
@@ -28,7 +28,7 @@ import org.apache.spark.util.Utils
*/
private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
name: String, val param: SQLMetricParam[R, T])
- extends Accumulable[R, T](param.zero, param, Some(name), true) {
+ extends Accumulable[R, T](param.zero, param, Some(name), internal = true) {
def reset(): Unit = {
this.value = param.zero
@@ -131,6 +131,8 @@ private[sql] object SQLMetrics {
name: String,
param: LongSQLMetricParam): LongSQLMetric = {
val acc = new LongSQLMetric(name, param)
+ // This is an internal accumulator so we need to register it explicitly.
+ Accumulators.register(acc)
sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
acc
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index 83c64f755f..544606f116 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -139,9 +139,8 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
override def onExecutorMetricsUpdate(
executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized {
- for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) {
- updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics.accumulatorUpdates(),
- finishTask = false)
+ for ((taskId, stageId, stageAttemptID, accumUpdates) <- executorMetricsUpdate.accumUpdates) {
+ updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, accumUpdates, finishTask = false)
}
}
@@ -177,7 +176,7 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
taskId: Long,
stageId: Int,
stageAttemptID: Int,
- accumulatorUpdates: Map[Long, Any],
+ accumulatorUpdates: Seq[AccumulableInfo],
finishTask: Boolean): Unit = {
_stageIdToStageMetrics.get(stageId) match {
@@ -289,8 +288,10 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi
for (stageId <- executionUIData.stages;
stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable;
taskMetrics <- stageMetrics.taskIdToMetricUpdates.values;
- accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield {
- accumulatorUpdate
+ accumulatorUpdate <- taskMetrics.accumulatorUpdates) yield {
+ assert(accumulatorUpdate.update.isDefined, s"accumulator update from " +
+ s"task did not have a partial value: ${accumulatorUpdate.name}")
+ (accumulatorUpdate.id, accumulatorUpdate.update.get)
}
}.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
@@ -328,9 +329,10 @@ private[spark] class SQLHistoryListener(conf: SparkConf, sparkUI: SparkUI)
taskEnd.taskInfo.taskId,
taskEnd.stageId,
taskEnd.stageAttemptId,
- taskEnd.taskInfo.accumulables.map { acc =>
- (acc.id, new LongSQLMetricValue(acc.update.getOrElse("0").toLong))
- }.toMap,
+ taskEnd.taskInfo.accumulables.map { a =>
+ val newValue = new LongSQLMetricValue(a.update.map(_.asInstanceOf[Long]).getOrElse(0L))
+ a.copy(update = Some(newValue))
+ },
finishTask = true)
}
@@ -406,4 +408,4 @@ private[ui] class SQLStageMetrics(
private[ui] class SQLTaskMetrics(
val attemptId: Long, // TODO not used yet
var finished: Boolean,
- var accumulatorUpdates: Map[Long, Any])
+ var accumulatorUpdates: Seq[AccumulableInfo])
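The listener now receives Seq[AccumulableInfo] rather than Map[Long, Any], so it must unwrap each optional partial value itself, as the assert in the hunk above shows. A simplified sketch of that conversion (AccumulableInfoSketch is a stand-in for org.apache.spark.scheduler.AccumulableInfo after this commit, reduced to the fields used here):

case class AccumulableInfoSketch(
    id: Long,
    name: Option[String],
    update: Option[Any], // partial value reported by the executor for one task
    value: Option[Any])  // total value, known only on the driver

def toIdAndPartialValue(updates: Seq[AccumulableInfoSketch]): Seq[(Long, Any)] =
  updates.map { info =>
    assert(info.update.isDefined,
      s"accumulator update from task did not have a partial value: ${info.name}")
    info.id -> info.update.get
  }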
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 47308966e9..10ccd4b8f6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1648,7 +1648,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("external sorting updates peak execution memory") {
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
- sortTest()
+ sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
index 9575d26fd1..273937fa8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReferenceSort.scala
@@ -49,8 +49,7 @@ case class ReferenceSort(
val context = TaskContext.get()
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
- context.internalMetricsToAccumulators(
- InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
+ context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
CompletionIterator[InternalRow, Iterator[InternalRow]](baseIterator, sorter.stop())
}, preservesPartitioning = true)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 9c258cb31f..c7df8b51e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -71,8 +71,7 @@ class UnsafeFixedWidthAggregationMapSuite
taskAttemptId = Random.nextInt(10000),
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
- metricsSystem = null,
- internalAccumulators = Seq.empty))
+ metricsSystem = null))
try {
f
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 8a95359d9d..e03bd6a3e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -117,8 +117,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
taskAttemptId = 98456,
attemptNumber = 0,
taskMemoryManager = taskMemMgr,
- metricsSystem = null,
- internalAccumulators = Seq.empty))
+ metricsSystem = null))
val sorter = new UnsafeKVExternalSorter(
keySchema, valueSchema, SparkEnv.get.blockManager, pageSize)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
index 647a7e9a4e..86c2c25c2c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/PartitionBatchPruningSuite.scala
@@ -17,12 +17,19 @@
package org.apache.spark.sql.execution.columnar
+import org.scalatest.BeforeAndAfterEach
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
-class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
+
+class PartitionBatchPruningSuite
+ extends SparkFunSuite
+ with BeforeAndAfterEach
+ with SharedSQLContext {
+
import testImplicits._
private lazy val originalColumnBatchSize = sqlContext.conf.columnBatchSize
@@ -32,30 +39,41 @@ class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
super.beforeAll()
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
-
- val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
- val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
- TestData(key, string)
- }, 5).toDF()
- pruningData.registerTempTable("pruningData")
-
// Enable in-memory partition pruning
sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
// Enable in-memory table scan accumulators
sqlContext.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
- sqlContext.cacheTable("pruningData")
}
override protected def afterAll(): Unit = {
try {
sqlContext.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
sqlContext.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
- sqlContext.uncacheTable("pruningData")
} finally {
super.afterAll()
}
}
+ override protected def beforeEach(): Unit = {
+ super.beforeEach()
+ // This creates accumulators, which get cleaned up after every single test,
+ // so we need to do this before every test.
+ val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
+ val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
+ TestData(key, string)
+ }, 5).toDF()
+ pruningData.registerTempTable("pruningData")
+ sqlContext.cacheTable("pruningData")
+ }
+
+ override protected def afterEach(): Unit = {
+ try {
+ sqlContext.uncacheTable("pruningData")
+ } finally {
+ super.afterEach()
+ }
+ }
+
// Comparisons
checkBatchPruning("SELECT key FROM pruningData WHERE key = 1", 1, 1)(Seq(1))
checkBatchPruning("SELECT key FROM pruningData WHERE 1 = key", 1, 1)(Seq(1))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 81a159d542..2c408c8878 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.ui
import java.util.Properties
+import org.mockito.Mockito.{mock, when}
+
import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler._
@@ -67,9 +69,11 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
)
private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = {
- val metrics = new TaskMetrics
- metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_)))
- metrics.updateAccumulators()
+ val metrics = mock(classOf[TaskMetrics])
+ when(metrics.accumulatorUpdates()).thenReturn(accumulatorUpdates.map { case (id, update) =>
+ new AccumulableInfo(id, Some(""), Some(new LongSQLMetricValue(update)),
+ value = None, internal = true, countFailedValues = true)
+ }.toSeq)
metrics
}
@@ -114,17 +118,17 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
assert(listener.getExecutionMetrics(0).isEmpty)
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
- // (task id, stage id, stage attempt, metrics)
- (0L, 0, 0, createTaskMetrics(accumulatorUpdates)),
- (1L, 0, 0, createTaskMetrics(accumulatorUpdates))
+ // (task id, stage id, stage attempt, accum updates)
+ (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
+ (1L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
- // (task id, stage id, stage attempt, metrics)
- (0L, 0, 0, createTaskMetrics(accumulatorUpdates)),
- (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)))
+ // (task id, stage id, stage attempt, accum updates)
+ (0L, 0, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
+ (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)).accumulatorUpdates())
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 3))
@@ -133,9 +137,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1)))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
- // (task id, stage id, stage attempt, metrics)
- (0L, 0, 1, createTaskMetrics(accumulatorUpdates)),
- (1L, 0, 1, createTaskMetrics(accumulatorUpdates))
+ // (task id, stage id, stage attempt, accum updates)
+ (0L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
+ (1L, 0, 1, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 2))
@@ -173,9 +177,9 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0)))
listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq(
- // (task id, stage id, stage attempt, metrics)
- (0L, 1, 0, createTaskMetrics(accumulatorUpdates)),
- (1L, 1, 0, createTaskMetrics(accumulatorUpdates))
+ // (task id, stage id, stage attempt, accum updates)
+ (0L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates()),
+ (1L, 1, 0, createTaskMetrics(accumulatorUpdates).accumulatorUpdates())
)))
checkAnswer(listener.getExecutionMetrics(0), accumulatorUpdates.mapValues(_ * 7))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index b46b0d2f60..9a24a2487a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -140,7 +140,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
.filter(_._2.name == InternalAccumulator.PEAK_EXECUTION_MEMORY)
assert(peakMemoryAccumulator.size == 1)
- peakMemoryAccumulator.head._2.value.toLong
+ peakMemoryAccumulator.head._2.value.get.asInstanceOf[Long]
}
assert(sparkListener.getCompletedStageInfos.length == 2)