author:    Ala Luszczak <ala@databricks.com>    2017-02-18 07:51:41 -0800
committer: Reynold Xin <rxin@databricks.com>    2017-02-18 07:51:41 -0800
commit:    b486ffc86d8ad6c303321dcf8514afee723f61f8 (patch)
tree:      090b1eeb158c80cd51e6670d997351516bf22e15 /sql/core/src/test
parent:    729ce3703257aa34c00c5c8253e6971faf6a0c8d (diff)
[SPARK-19447] Make Range operator generate "recordsRead" metric
## What changes were proposed in this pull request?

The Range operator was modified to produce the "recordsRead" metric instead of "generated rows". The tests were updated and partially moved to SQLMetricsSuite.

## How was this patch tested?

Unit tests.

Author: Ala Luszczak <ala@databricks.com>

Closes #16960 from ala/range-records-read.
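As a hedged illustration of the behaviour described above (not part of the patch): after this change, rows produced by the Range operator are expected to show up as inputMetrics.recordsRead on the tasks of the stage that runs the range, rather than as a separate "generated rows" SQL metric. The sketch below mirrors the listener pattern used by the helper added to SQLMetricsSuite in this commit; the object name RangeRecordsReadSketch, the local master setting, and the sleep-based wait are illustrative assumptions (the test suites drain the listener bus via the package-private listenerBus.waitUntilEmpty instead).

import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.SparkSession

// Hypothetical standalone sketch: sums inputMetrics.recordsRead over finished tasks
// while a simple range query runs, which this commit makes visible in that counter.
object RangeRecordsReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("range-records-read")
      .getOrCreate()

    val recordsRead = new AtomicLong(0L)
    val listener = new SparkListener {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
        if (taskEnd.taskMetrics != null) {
          recordsRead.addAndGet(taskEnd.taskMetrics.inputMetrics.recordsRead)
        }
      }
    }
    spark.sparkContext.addSparkListener(listener)

    spark.range(0, 100).collect()

    // Crude wait for the listener bus to deliver task-end events; the suites use
    // listenerBus.waitUntilEmpty, which is not accessible outside Spark's own packages.
    Thread.sleep(2000)
    println(s"recordsRead = ${recordsRead.get()}") // expected: 100 after this change
    spark.stop()
  }
}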
Diffstat (limited to 'sql/core/src/test')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala  131
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala  104
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala  11
3 files changed, 108 insertions(+), 138 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
deleted file mode 100644
index ddd7a03e80..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import java.io.File
-
-import org.scalatest.concurrent.Eventually
-
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.sql.{DataFrame, QueryTest}
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.util.Utils
-
-class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
-
- test("Range query input/output/generated metrics") {
- val numRows = 150L
- val numSelectedRows = 100L
- val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
- filter(x => x < numSelectedRows).toDF())
-
- assert(res.recordsRead.sum === 0)
- assert(res.shuffleRecordsRead.sum === 0)
- assert(res.generatedRows === numRows :: Nil)
- assert(res.outputRows === numSelectedRows :: numRows :: Nil)
- }
-
- test("Input/output/generated metrics with repartitioning") {
- val numRows = 100L
- val res = MetricsTestHelper.runAndGetMetrics(
- spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
-
- assert(res.recordsRead.sum === 0)
- assert(res.shuffleRecordsRead.sum === numRows)
- assert(res.generatedRows === numRows :: Nil)
- assert(res.outputRows === 20 :: numRows :: Nil)
- }
-
- test("Input/output/generated metrics with more repartitioning") {
- withTempDir { tempDir =>
- val dir = new File(tempDir, "pqS").getCanonicalPath
-
- spark.range(10).write.parquet(dir)
- spark.read.parquet(dir).createOrReplaceTempView("pqS")
-
- val res = MetricsTestHelper.runAndGetMetrics(
- spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
- .toDF()
- )
-
- assert(res.recordsRead.sum == 10)
- assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
- assert(res.generatedRows == 30 :: Nil)
- assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
- }
- }
-}
-
-object MetricsTestHelper {
- case class AggregatedMetricsResult(
- recordsRead: List[Long],
- shuffleRecordsRead: List[Long],
- generatedRows: List[Long],
- outputRows: List[Long])
-
- private[this] def extractMetricValues(
- df: DataFrame,
- metricValues: Map[Long, String],
- metricName: String): List[Long] = {
- df.queryExecution.executedPlan.collect {
- case plan if plan.metrics.contains(metricName) =>
- metricValues(plan.metrics(metricName).id).toLong
- }.toList.sorted
- }
-
- def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
- AggregatedMetricsResult = {
- val spark = df.sparkSession
- val sparkContext = spark.sparkContext
-
- var recordsRead = List[Long]()
- var shuffleRecordsRead = List[Long]()
- val listener = new SparkListener() {
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
- if (taskEnd.taskMetrics != null) {
- recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
- recordsRead
- shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
- shuffleRecordsRead
- }
- }
- }
-
- val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
-
- val prevUseWholeStageCodeGen =
- spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
- try {
- spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
- sparkContext.listenerBus.waitUntilEmpty(10000)
- sparkContext.addSparkListener(listener)
- df.collect()
- sparkContext.listenerBus.waitUntilEmpty(10000)
- } finally {
- spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
- }
-
- val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
- val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
- val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
- val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
-
- AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 229d8814e0..2ce7db6a22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -17,7 +17,12 @@
package org.apache.spark.sql.execution.metric
+import java.io.File
+
+import scala.collection.mutable.HashMap
+
import org.apache.spark.SparkFunSuite
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.execution.SparkPlanInfo
@@ -309,4 +314,103 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
assert(metricInfoDeser.metadata === Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER))
}
+ test("range metrics") {
+ val res1 = InputOutputMetricsHelper.run(
+ spark.range(30).filter(x => x % 3 == 0).toDF()
+ )
+ assert(res1 === (30L, 0L, 30L) :: Nil)
+
+ val res2 = InputOutputMetricsHelper.run(
+ spark.range(150).repartition(4).filter(x => x < 10).toDF()
+ )
+ assert(res2 === (150L, 0L, 150L) :: (0L, 150L, 10L) :: Nil)
+
+ withTempDir { tempDir =>
+ val dir = new File(tempDir, "pqS").getCanonicalPath
+
+ spark.range(10).write.parquet(dir)
+ spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+ val res3 = InputOutputMetricsHelper.run(
+ spark.range(30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2).toDF()
+ )
+ // The query above is executed in the following stages:
+ // 1. sql("select * from pqS") => (10, 0, 10)
+ // 2. range(30) => (30, 0, 30)
+ // 3. crossJoin(...) of 1. and 2. => (0, 30, 300)
+ // 4. shuffle & return results => (0, 300, 0)
+ assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
+ }
+ }
+}
+
+object InputOutputMetricsHelper {
+ private class InputOutputMetricsListener extends SparkListener {
+ private case class MetricsResult(
+ var recordsRead: Long = 0L,
+ var shuffleRecordsRead: Long = 0L,
+ var sumMaxOutputRows: Long = 0L)
+
+ private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]
+
+ def reset(): Unit = {
+ stageIdToMetricsResult.clear()
+ }
+
+ /**
+ * Return a list of recorded metrics aggregated per stage.
+ *
+ * The list is sorted in the ascending order on the stageId.
+ * For each recorded stage, the following tuple is returned:
+ * - sum of inputMetrics.recordsRead for all the tasks in the stage
+ * - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
+ * - sum of the highest values of "number of output rows" metric for all the tasks in the stage
+ */
+ def getResults(): List[(Long, Long, Long)] = {
+ stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
+ val res = stageIdToMetricsResult(stageId)
+ (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
+ }
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+ val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult())
+
+ res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
+ res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
+
+ var maxOutputRows = 0L
+ for (accum <- taskEnd.taskMetrics.externalAccums) {
+ val info = accum.toInfo(Some(accum.value), None)
+ if (info.name.toString.contains("number of output rows")) {
+ info.update match {
+ case Some(n: Number) =>
+ if (n.longValue() > maxOutputRows) {
+ maxOutputRows = n.longValue()
+ }
+ case _ => // Ignore.
+ }
+ }
+ }
+ res.sumMaxOutputRows += maxOutputRows
+ }
+ }
+
+ // Run df.collect() and return aggregated metrics for each stage.
+ def run(df: DataFrame): List[(Long, Long, Long)] = {
+ val spark = df.sparkSession
+ val sparkContext = spark.sparkContext
+ val listener = new InputOutputMetricsListener()
+ sparkContext.addSparkListener(listener)
+
+ try {
+ sparkContext.listenerBus.waitUntilEmpty(5000)
+ listener.reset()
+ df.collect()
+ sparkContext.listenerBus.waitUntilEmpty(5000)
+ } finally {
+ sparkContext.removeSparkListener(listener)
+ }
+ listener.getResults()
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 92d3e9519f..5463728ca0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
-import org.apache.spark.sql.execution.MetricsTestHelper
+import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -917,13 +917,10 @@ class JDBCSuite extends SparkFunSuite
assert(e2.contains("User specified schema not supported with `jdbc`"))
}
- test("Input/generated/output metrics on JDBC") {
+ test("Checking metrics correctness with JDBC") {
val foobarCnt = spark.table("foobar").count()
- val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
- assert(res.recordsRead === foobarCnt :: Nil)
- assert(res.shuffleRecordsRead.sum === 0)
- assert(res.generatedRows.isEmpty)
- assert(res.outputRows === foobarCnt :: Nil)
+ val res = InputOutputMetricsHelper.run(sql("SELECT * FROM foobar").toDF())
+ assert(res === (foobarCnt, 0L, foobarCnt) :: Nil)
}
test("SPARK-19318: Connection properties keys should be case-sensitive.") {