-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala | 131
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala | 104
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 11
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala | 18
6 files changed, 125 insertions(+), 155 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 87932e09a1..760ead42c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -31,6 +31,7 @@ import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleComp
import org.codehaus.janino.util.ClassFile
import org.apache.spark.{SparkEnv, TaskContext, TaskKilledException}
+import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.metrics.source.CodegenMetrics
import org.apache.spark.sql.catalyst.InternalRow
@@ -933,7 +934,8 @@ object CodeGenerator extends Logging {
classOf[UnsafeMapData].getName,
classOf[Expression].getName,
classOf[TaskContext].getName,
- classOf[TaskKilledException].getName
+ classOf[TaskKilledException].getName,
+ classOf[InputMetrics].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])
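
The CodeGenerator change above is needed because the generated Java source for RangeExec (next file) now declares a mutable field of type InputMetrics; Janino resolves simple class names through the evaluator's default imports, so the class has to be registered there or compilation of the generated source would fail. A minimal, standalone sketch of that mechanism, assuming only Janino's ClassBodyEvaluator on the classpath (illustrative, not Spark's generated code):

    import org.codehaus.janino.ClassBodyEvaluator

    object JaninoDefaultImportsSketch {
      def main(args: Array[String]): Unit = {
        val evaluator = new ClassBodyEvaluator()
        // Without this registration, "ArrayList" in the class body below would not
        // resolve by its simple name and cook() would throw a CompileException.
        evaluator.setDefaultImports(Array("java.util.ArrayList"))
        evaluator.cook(
          """
            |public Object build() {
            |  ArrayList list = new ArrayList();  // simple name resolved via the default imports
            |  list.add("ok");
            |  return list;
            |}
          """.stripMargin)
        println("compiled: " + evaluator.getClazz.getName)
      }
    }
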
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index c01f9c5e3d..87e90ed685 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -365,6 +365,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val taskContext = ctx.freshName("taskContext")
ctx.addMutableState("TaskContext", taskContext, s"$taskContext = TaskContext.get();")
+ val inputMetrics = ctx.freshName("inputMetrics")
+ ctx.addMutableState("InputMetrics", inputMetrics,
+ s"$inputMetrics = $taskContext.taskMetrics().inputMetrics();")
// In order to periodically update the metrics without inflicting performance penalty, this
// operator produces elements in batches. After a batch is complete, the metrics are updated
@@ -460,7 +463,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| if ($nextBatchTodo == 0) break;
| }
| $numOutput.add($nextBatchTodo);
- | $numGenerated.add($nextBatchTodo);
+ | $inputMetrics.incRecordsRead($nextBatchTodo);
|
| $batchEnd += $nextBatchTodo * ${step}L;
| }
@@ -469,7 +472,6 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val numGeneratedRows = longMetric("numGeneratedRows")
sqlContext
.sparkContext
.parallelize(0 until numSlices, numSlices)
@@ -488,10 +490,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val safePartitionEnd = getSafeMargin(partitionEnd)
val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
+ val taskContext = TaskContext.get()
val iter = new Iterator[InternalRow] {
private[this] var number: Long = safePartitionStart
private[this] var overflow: Boolean = false
+ private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics
override def hasNext =
if (!overflow) {
@@ -513,12 +517,12 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
}
numOutputRows += 1
- numGeneratedRows += 1
+ inputMetrics.incRecordsRead(1)
unsafeRow.setLong(0, ret)
unsafeRow
}
}
- new InterruptibleIterator(TaskContext.get(), iter)
+ new InterruptibleIterator(taskContext, iter)
}
}
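
The operator change replaces the dedicated "number of generated rows" SQL metric with the standard task-level input metric: the whole-stage-codegen template caches an InputMetrics field and calls incRecordsRead once per batch, while the interpreted doExecute path now looks up TaskContext once per partition and increments by one per produced row. A minimal sketch of the per-partition pattern follows; it has to live inside the org.apache.spark namespace (as RangeExec does) because incRecordsRead is package-private to Spark, and it assumes it runs inside a task so TaskContext.get() is non-null:

    package org.apache.spark.sql.execution.sketch  // hypothetical package inside the Spark namespace

    import org.apache.spark.TaskContext

    object RecordsReadSketch {
      // Wrap a per-partition iterator so that every produced row is also counted as a record read.
      def counting[T](iter: Iterator[T]): Iterator[T] = {
        val taskContext = TaskContext.get()                        // non-null inside a running task
        val inputMetrics = taskContext.taskMetrics().inputMetrics  // resolved once, not per row
        iter.map { row =>
          inputMetrics.incRecordsRead(1)
          row
        }
      }
    }
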
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
deleted file mode 100644
index ddd7a03e80..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.io.File
-
-import org.scalatest.concurrent.Eventually
-
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.sql.{DataFrame, QueryTest}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.util.Utils
-
-class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
-
- test("Range query input/output/generated metrics") {
- val numRows = 150L
- val numSelectedRows = 100L
- val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
- filter(x => x < numSelectedRows).toDF())
-
- assert(res.recordsRead.sum === 0)
- assert(res.shuffleRecordsRead.sum === 0)
- assert(res.generatedRows === numRows :: Nil)
- assert(res.outputRows === numSelectedRows :: numRows :: Nil)
- }
-
- test("Input/output/generated metrics with repartitioning") {
- val numRows = 100L
- val res = MetricsTestHelper.runAndGetMetrics(
- spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
-
- assert(res.recordsRead.sum === 0)
- assert(res.shuffleRecordsRead.sum === numRows)
- assert(res.generatedRows === numRows :: Nil)
- assert(res.outputRows === 20 :: numRows :: Nil)
- }
-
- test("Input/output/generated metrics with more repartitioning") {
- withTempDir { tempDir =>
- val dir = new File(tempDir, "pqS").getCanonicalPath
-
- spark.range(10).write.parquet(dir)
- spark.read.parquet(dir).createOrReplaceTempView("pqS")
-
- val res = MetricsTestHelper.runAndGetMetrics(
- spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
- .toDF()
- )
-
- assert(res.recordsRead.sum == 10)
- assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
- assert(res.generatedRows == 30 :: Nil)
- assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
- }
- }
-}
-
-object MetricsTestHelper {
- case class AggregatedMetricsResult(
- recordsRead: List[Long],
- shuffleRecordsRead: List[Long],
- generatedRows: List[Long],
- outputRows: List[Long])
-
- private[this] def extractMetricValues(
- df: DataFrame,
- metricValues: Map[Long, String],
- metricName: String): List[Long] = {
- df.queryExecution.executedPlan.collect {
- case plan if plan.metrics.contains(metricName) =>
- metricValues(plan.metrics(metricName).id).toLong
- }.toList.sorted
- }
-
- def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
- AggregatedMetricsResult = {
- val spark = df.sparkSession
- val sparkContext = spark.sparkContext
-
- var recordsRead = List[Long]()
- var shuffleRecordsRead = List[Long]()
- val listener = new SparkListener() {
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- if (taskEnd.taskMetrics != null) {
- recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
- recordsRead
- shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
- shuffleRecordsRead
- }
- }
- }
-
- val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
-
- val prevUseWholeStageCodeGen =
- spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
- try {
- spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
- sparkContext.listenerBus.waitUntilEmpty(10000)
- sparkContext.addSparkListener(listener)
- df.collect()
- sparkContext.listenerBus.waitUntilEmpty(10000)
- } finally {
- spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
- }
-
- val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
- val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
- val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
- val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
-
- AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 229d8814e0..2ce7db6a22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -17,7 +17,12 @@
package org.apache.spark.sql.execution.metric
+import java.io.File
+
+import scala.collection.mutable.HashMap
+
import org.apache.spark.SparkFunSuite
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.execution.SparkPlanInfo
@@ -309,4 +314,103 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
assert(metricInfoDeser.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
}
+ test("range metrics") {
+ val res1 = InputOutputMetricsHelper.run(
+ spark.range(30).filter(x => x % 3 == 0).toDF()
+ )
+ assert(res1 === (30L, 0L, 30L) :: Nil)
+
+ val res2 = InputOutputMetricsHelper.run(
+ spark.range(150).repartition(4).filter(x => x < 10).toDF()
+ )
+ assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil)
+
+ withTempDir { tempDir =>
+ val dir = new File(tempDir, "pqS").getCanonicalPath
+
+ spark.range(10).write.parquet(dir)
+ spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+ val res3 = InputOutputMetricsHelper.run(
+ spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
+ )
+ // The query above is executed in the following stages:
+ // 1. sql("select * from pqS") => (10, 0, 10)
+ // 2. range(30) => (30, 0, 30)
+ // 3. crossJoin(...) of 1. and 2. => (0, 30, 300)
+ // 4. shuffle & return results => (0, 300, 0)
+ assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
+ }
+ }
+}
+
+object InputOutputMetricsHelper {
+ private class InputOutputMetricsListener extends SparkListener {
+ private case class MetricsResult(
+ var recordsRead: Long = 0L,
+ var shuffleRecordsRead: Long = 0L,
+ var sumMaxOutputRows: Long = 0L)
+
+ private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]
+
+ def reset(): Unit = {
+ stageIdToMetricsResult.clear()
+ }
+
+ /**
+ * Return a list of recorded metrics aggregated per stage.
+ *
+ * The list is sorted in the ascending order on the stageId.
+ * For each recorded stage, the following tuple is returned:
+ * - sum of inputMetrics.recordsRead for all the tasks in the stage
+ * - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
+ * - sum of the highest values of "number of output rows" metric for all the tasks in the stage
+ */
+ def getResults(): List[(Long, Long, Long)] = {
+ stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
+ val res = stageIdToMetricsResult(stageId)
+ (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
+ }
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+ val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult())
+
+ res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
+ res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
+
+ var maxOutputRows = 0L
+ for (accum <- taskEnd.taskMetrics.externalAccums) {
+ val info = accum.toInfo(Some(accum.value), None)
+ if (info.name.toString.contains("number of output rows")) {
+ info.update match {
+ case Some(n: Number) =>
+ if (n.longValue() > maxOutputRows) {
+ maxOutputRows = n.longValue()
+ }
+ case _ => // Ignore.
+ }
+ }
+ }
+ res.sumMaxOutputRows += maxOutputRows
+ }
+ }
+
+ // Run df.collect() and return aggregated metrics for each stage.
+ def run(df: DataFrame): List[(Long, Long, Long)] = {
+ val spark = df.sparkSession
+ val sparkContext = spark.sparkContext
+ val listener = new InputOutputMetricsListener()
+ sparkContext.addSparkListener(listener)
+
+ try {
+ sparkContext.listenerBus.waitUntilEmpty(5000)
+ listener.reset()
+ df.collect()
+ sparkContext.listenerBus.waitUntilEmpty(5000)
+ } finally {
+ sparkContext.removeSparkListener(listener)
+ }
+ listener.getResults()
+ }
}
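
The new InputOutputMetricsHelper registers a listener around a single df.collect() and returns, per stage in stage-id order, the tuple (records read, shuffle records read, sum over tasks of the per-task maximum "number of output rows"). A hypothetical extra test in this suite showing how it is meant to be called (SharedSQLContext provides `spark`; the expected tuple mirrors the res1 assertion above):

    test("range metrics helper usage (sketch)") {
      val perStage = InputOutputMetricsHelper.run(spark.range(100).filter(x => x % 2 == 0).toDF())
      // One whole-stage stage: 100 records read by the range source, no shuffle reads,
      // and the per-task maxima of "number of output rows" (the range output) sum to 100.
      assert(perStage === (100L, 0L, 100L) :: Nil)
    }
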
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 92d3e9519f..5463728ca0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -917,13 +917,10 @@ class JDBCSuite extends SparkFunSuite
assert(e2.contains("User specified schema not supported with `jdbc`"))
}
- test("Input/generated/output metrics on JDBC") {
+ test("Checking metrics correctness with JDBC") {
val foobarCnt = spark.table("foobar").count()
- val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
- assert(res.recordsRead === foobarCnt :: Nil)
- assert(res.shuffleRecordsRead.sum === 0)
- assert(res.generatedRows.isEmpty)
- assert(res.outputRows === foobarCnt :: Nil)
+ val res = InputOutputMetricsHelper.run(sql("SELECT * FROM foobar").toDF())
+ assert(res === (foobarCnt, 0L, foobarCnt) :: Nil)
}
test("SPARK-19318: Connection properties keys should be case-sensitive.") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index 35c41b531c..7803ac39e5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.hive.execution
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.hive.test.TestHive
/**
@@ -49,21 +49,15 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")
- test("Test input/generated/output metrics") {
+ test("Checking metrics correctness") {
import TestHive._
val episodesCnt = sql("select * from episodes").count()
- val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
- assert(episodesRes.recordsRead === episodesCnt :: Nil)
- assert(episodesRes.shuffleRecordsRead.sum === 0)
- assert(episodesRes.generatedRows.isEmpty)
- assert(episodesRes.outputRows === episodesCnt :: Nil)
+ val episodesRes = InputOutputMetricsHelper.run(sql("select * from episodes").toDF())
+ assert(episodesRes === (episodesCnt, 0L, episodesCnt) :: Nil)
val serdeinsCnt = sql("select * from serdeins").count()
- val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
- assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
- assert(serdeinsRes.shuffleRecordsRead.sum === 0)
- assert(serdeinsRes.generatedRows.isEmpty)
- assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+ val serdeinsRes = InputOutputMetricsHelper.run(sql("select * from serdeins").toDF())
+ assert(serdeinsRes === (serdeinsCnt, 0L, serdeinsCnt) :: Nil)
}
}
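
After the rewrite, the JDBC and Hive tests above share one shape: a single-stage full table scan should read exactly as many records as the table has rows, touch no shuffle data, and report that same count as its per-task maximum "number of output rows". If more suites adopt the pattern, a small hypothetical wrapper (assuming DataFrame and InputOutputMetricsHelper are in scope, as they are in these two suites) keeps the assertions uniform:

    // Hypothetical helper, not part of this commit.
    def assertFullScanMetrics(df: DataFrame, expectedRows: Long): Unit = {
      val res = InputOutputMetricsHelper.run(df)
      assert(res === (expectedRows, 0L, expectedRows) :: Nil)
    }
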