aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorAla Luszczak <ala@databricks.com>2017-02-07 14:21:30 +0100
committerReynold Xin <rxin@databricks.com>2017-02-07 14:21:30 +0100
commit6ed285c68fee451c45db7b01ca8ec1dea2efd479 (patch)
tree341cff8e7b21f695ccb90c825440bd82b15210ea /sql
parente99e34d0f370211a7c7b96d144cc932b2fc71d10 (diff)
downloadspark-6ed285c68fee451c45db7b01ca8ec1dea2efd479.tar.gz
spark-6ed285c68fee451c45db7b01ca8ec1dea2efd479.tar.bz2
spark-6ed285c68fee451c45db7b01ca8ec1dea2efd479.zip
[SPARK-19447] Fixing input metrics for range operator.
## What changes were proposed in this pull request? This change introduces a new metric "number of generated rows". It is used exclusively for Range, which is a leaf in the query tree, yet doesn't read any input data, and therefore cannot report "recordsRead". Additionally the way in which the metrics are reported by the JIT-compiled version of Range was changed. Previously, it was immediately reported that all the records were produced. This could be confusing for a user monitoring execution progress in the UI. Now, the metric is updated gradually. In order to avoid negative impact on Range performance, the code generation was reworked. The values are now produced in batches in the tighter inner loop, while the metrics are updated in the outer loop. The change also contains a number of unit tests, which should help ensure the correctness of metrics for various input sources. ## How was this patch tested? Unit tests. Author: Ala Luszczak <ala@databricks.com> Closes #16829 from ala/SPARK-19447.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala82
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala130
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala53
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala131
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala10
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala19
6 files changed, 350 insertions, 75 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index fb90799534..792fb3e795 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -339,7 +339,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
override val output: Seq[Attribute] = range.output
override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+ "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows"))
// output attributes should not affect the results
override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
@@ -351,24 +352,37 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
protected override def doProduce(ctx: CodegenContext): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
+ val numGenerated = metricTerm(ctx, "numGeneratedRows")
val initTerm = ctx.freshName("initRange")
ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
- val partitionEnd = ctx.freshName("partitionEnd")
- ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
val number = ctx.freshName("number")
ctx.addMutableState("long", number, s"$number = 0L;")
- val overflow = ctx.freshName("overflow")
- ctx.addMutableState("boolean", overflow, s"$overflow = false;")
val value = ctx.freshName("value")
val ev = ExprCode("", "false", value)
val BigInt = classOf[java.math.BigInteger].getName
- val checkEnd = if (step > 0) {
- s"$number < $partitionEnd"
- } else {
- s"$number > $partitionEnd"
- }
+
+ // In order to periodically update the metrics without inflicting performance penalty, this
+ // operator produces elements in batches. After a batch is complete, the metrics are updated
+ // and a new batch is started.
+ // In the implementation below, the code in the inner loop is producing all the values
+ // within a batch, while the code in the outer loop is setting batch parameters and updating
+ // the metrics.
+
+ // Once number == batchEnd, it's time to progress to the next batch.
+ val batchEnd = ctx.freshName("batchEnd")
+ ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;")
+
+ // How many values should still be generated by this range operator.
+ val numElementsTodo = ctx.freshName("numElementsTodo")
+ ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;")
+
+ // How many values should be generated in the next batch.
+ val nextBatchTodo = ctx.freshName("nextBatchTodo")
+
+ // The default size of a batch.
+ val batchSize = 1000L
ctx.addNewFunction("initRange",
s"""
@@ -378,6 +392,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
| $BigInt step = $BigInt.valueOf(${step}L);
| $BigInt start = $BigInt.valueOf(${start}L);
+ | long partitionEnd;
|
| $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
| if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
@@ -387,18 +402,26 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| } else {
| $number = st.longValue();
| }
+ | $batchEnd = $number;
|
| $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
| .multiply(step).add(start);
| if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
- | $partitionEnd = Long.MAX_VALUE;
+ | partitionEnd = Long.MAX_VALUE;
| } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
- | $partitionEnd = Long.MIN_VALUE;
+ | partitionEnd = Long.MIN_VALUE;
| } else {
- | $partitionEnd = end.longValue();
+ | partitionEnd = end.longValue();
| }
|
- | $numOutput.add(($partitionEnd - $number) / ${step}L);
+ | $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
+ | $BigInt.valueOf($number));
+ | $numElementsTodo = startToEnd.divide(step).longValue();
+ | if ($numElementsTodo < 0) {
+ | $numElementsTodo = 0;
+ | } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
+ | $numElementsTodo++;
+ | }
| }
""".stripMargin)
@@ -412,20 +435,34 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| initRange(partitionIndex);
| }
|
- | while (!$overflow && $checkEnd) {
- | long $value = $number;
- | $number += ${step}L;
- | if ($number < $value ^ ${step}L < 0) {
- | $overflow = true;
- | }
- | ${consume(ctx, Seq(ev))}
- | if (shouldStop()) return;
+ | while (true) {
+ | while ($number != $batchEnd) {
+ | long $value = $number;
+ | $number += ${step}L;
+ | ${consume(ctx, Seq(ev))}
+ | if (shouldStop()) return;
+ | }
+ |
+ | long $nextBatchTodo;
+ | if ($numElementsTodo > ${batchSize}L) {
+ | $nextBatchTodo = ${batchSize}L;
+ | $numElementsTodo -= ${batchSize}L;
+ | } else {
+ | $nextBatchTodo = $numElementsTodo;
+ | $numElementsTodo = 0;
+ | if ($nextBatchTodo == 0) break;
+ | }
+ | $numOutput.add($nextBatchTodo);
+ | $numGenerated.add($nextBatchTodo);
+ |
+ | $batchEnd += $nextBatchTodo * ${step}L;
| }
""".stripMargin
}
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
+ val numGeneratedRows = longMetric("numGeneratedRows")
sqlContext
.sparkContext
.parallelize(0 until numSlices, numSlices)
@@ -469,6 +506,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
}
numOutputRows += 1
+ numGeneratedRows += 1
unsafeRow.setLong(0, ret)
unsafeRow
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
new file mode 100644
index 0000000000..6d2d776c92
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.math.abs
+import scala.util.Random
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
+
+ test("SPARK-7150 range api") {
+ // numSlice is greater than length
+ val res1 = spark.range(0, 10, 1, 15).select("id")
+ assert(res1.count == 10)
+ assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+ val res2 = spark.range(3, 15, 3, 2).select("id")
+ assert(res2.count == 4)
+ assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+ val res3 = spark.range(1, -2).select("id")
+ assert(res3.count == 0)
+
+ // start is positive, end is negative, step is negative
+ val res4 = spark.range(1, -2, -2, 6).select("id")
+ assert(res4.count == 2)
+ assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
+
+ // start, end, step are negative
+ val res5 = spark.range(-3, -8, -2, 1).select("id")
+ assert(res5.count == 3)
+ assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
+
+ // start, end are negative, step is positive
+ val res6 = spark.range(-8, -4, 2, 1).select("id")
+ assert(res6.count == 2)
+ assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
+
+ val res7 = spark.range(-10, -9, -20, 1).select("id")
+ assert(res7.count == 0)
+
+ val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+ assert(res8.count == 3)
+ assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
+
+ val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+ assert(res9.count == 2)
+ assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+
+ // only end provided as argument
+ val res10 = spark.range(10).select("id")
+ assert(res10.count == 10)
+ assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+ val res11 = spark.range(-1).select("id")
+ assert(res11.count == 0)
+
+ // using the default slice number
+ val res12 = spark.range(3, 15, 3).select("id")
+ assert(res12.count == 4)
+ assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+ // difference between range start and end does not fit in a 64-bit integer
+ val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
+ val res13 = spark.range(-n, n, n / 9).select("id")
+ assert(res13.count == 18)
+ }
+
+ test("Range with randomized parameters") {
+ val MAX_NUM_STEPS = 10L * 1000
+
+ val seed = System.currentTimeMillis()
+ val random = new Random(seed)
+
+ def randomBound(): Long = {
+ val n = if (random.nextBoolean()) {
+ random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
+ } else {
+ random.nextLong() / 2
+ }
+ if (random.nextBoolean()) n else -n
+ }
+
+ for (l <- 1 to 10) {
+ val start = randomBound()
+ val end = randomBound()
+ val numSteps = (abs(random.nextLong()) % MAX_NUM_STEPS) + 1
+ val stepAbs = (abs(end - start) / numSteps) + 1
+ val step = if (start < end) stepAbs else -stepAbs
+ val partitions = random.nextInt(20) + 1
+
+ val expCount = (start until end by step).size
+ val expSum = (start until end by step).sum
+
+ for (codegen <- List(false, true)) {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+ val res = spark.range(start, end, step, partitions).toDF("id").
+ agg(count("id"), sum("id")).collect()
+
+ withClue(s"seed = $seed start = $start end = $end step = $step partitions = " +
+ s"$partitions codegen = $codegen") {
+ assert(!res.isEmpty)
+ assert(res.head.getLong(0) == expCount)
+ if (expCount > 0) {
+ assert(res.head.getLong(1) == expSum)
+ }
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6a190b98ea..e6338ab7cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -979,59 +979,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
}
- test("SPARK-7150 range api") {
- // numSlice is greater than length
- val res1 = spark.range(0, 10, 1, 15).select("id")
- assert(res1.count == 10)
- assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
- val res2 = spark.range(3, 15, 3, 2).select("id")
- assert(res2.count == 4)
- assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-
- val res3 = spark.range(1, -2).select("id")
- assert(res3.count == 0)
-
- // start is positive, end is negative, step is negative
- val res4 = spark.range(1, -2, -2, 6).select("id")
- assert(res4.count == 2)
- assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
-
- // start, end, step are negative
- val res5 = spark.range(-3, -8, -2, 1).select("id")
- assert(res5.count == 3)
- assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
-
- // start, end are negative, step is positive
- val res6 = spark.range(-8, -4, 2, 1).select("id")
- assert(res6.count == 2)
- assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
-
- val res7 = spark.range(-10, -9, -20, 1).select("id")
- assert(res7.count == 0)
-
- val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
- assert(res8.count == 3)
- assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
-
- val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
- assert(res9.count == 2)
- assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
-
- // only end provided as argument
- val res10 = spark.range(10).select("id")
- assert(res10.count == 10)
- assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
- val res11 = spark.range(-1).select("id")
- assert(res11.count == 0)
-
- // using the default slice number
- val res12 = spark.range(3, 15, 3).select("id")
- assert(res12.count == 4)
- assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
- }
-
test("SPARK-8621: support empty string column name") {
val df = Seq(Tuple1(1)).toDF("").as("t")
// We should allow empty string as column name
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
new file mode 100644
index 0000000000..ddd7a03e80
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.File
+
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.Utils
+
+class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
+
+ test("Range query input/output/generated metrics") {
+ val numRows = 150L
+ val numSelectedRows = 100L
+ val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
+ filter(x => x < numSelectedRows).toDF())
+
+ assert(res.recordsRead.sum === 0)
+ assert(res.shuffleRecordsRead.sum === 0)
+ assert(res.generatedRows === numRows :: Nil)
+ assert(res.outputRows === numSelectedRows :: numRows :: Nil)
+ }
+
+ test("Input/output/generated metrics with repartitioning") {
+ val numRows = 100L
+ val res = MetricsTestHelper.runAndGetMetrics(
+ spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
+
+ assert(res.recordsRead.sum === 0)
+ assert(res.shuffleRecordsRead.sum === numRows)
+ assert(res.generatedRows === numRows :: Nil)
+ assert(res.outputRows === 20 :: numRows :: Nil)
+ }
+
+ test("Input/output/generated metrics with more repartitioning") {
+ withTempDir { tempDir =>
+ val dir = new File(tempDir, "pqS").getCanonicalPath
+
+ spark.range(10).write.parquet(dir)
+ spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+ val res = MetricsTestHelper.runAndGetMetrics(
+ spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
+ .toDF()
+ )
+
+ assert(res.recordsRead.sum == 10)
+ assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
+ assert(res.generatedRows == 30 :: Nil)
+ assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
+ }
+ }
+}
+
+object MetricsTestHelper {
+ case class AggregatedMetricsResult(
+ recordsRead: List[Long],
+ shuffleRecordsRead: List[Long],
+ generatedRows: List[Long],
+ outputRows: List[Long])
+
+ private[this] def extractMetricValues(
+ df: DataFrame,
+ metricValues: Map[Long, String],
+ metricName: String): List[Long] = {
+ df.queryExecution.executedPlan.collect {
+ case plan if plan.metrics.contains(metricName) =>
+ metricValues(plan.metrics(metricName).id).toLong
+ }.toList.sorted
+ }
+
+ def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
+ AggregatedMetricsResult = {
+ val spark = df.sparkSession
+ val sparkContext = spark.sparkContext
+
+ var recordsRead = List[Long]()
+ var shuffleRecordsRead = List[Long]()
+ val listener = new SparkListener() {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ if (taskEnd.taskMetrics != null) {
+ recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
+ recordsRead
+ shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
+ shuffleRecordsRead
+ }
+ }
+ }
+
+ val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
+
+ val prevUseWholeStageCodeGen =
+ spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
+ try {
+ spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
+ sparkContext.listenerBus.waitUntilEmpty(10000)
+ sparkContext.addSparkListener(listener)
+ df.collect()
+ sparkContext.listenerBus.waitUntilEmpty(10000)
+ } finally {
+ spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
+ }
+
+ val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
+ val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
+ val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
+ val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
+
+ AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 039625421e..14fbe9f443 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.MetricsTestHelper
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -915,4 +916,13 @@ class JDBCSuite extends SparkFunSuite
}.getMessage
assert(e2.contains("User specified schema not supported with `jdbc`"))
}
+
+ test("Input/generated/output metrics on JDBC") {
+ val foobarCnt = spark.table("foobar").count()
+ val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
+ assert(res.recordsRead === foobarCnt :: Nil)
+ assert(res.shuffleRecordsRead.sum === 0)
+ assert(res.generatedRows.isEmpty)
+ assert(res.outputRows === foobarCnt :: Nil)
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index ec620c2403..35c41b531c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.execution.MetricsTestHelper
import org.apache.spark.sql.hive.test.TestHive
/**
@@ -47,4 +48,22 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes")
createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")
+
+ test("Test input/generated/output metrics") {
+ import TestHive._
+
+ val episodesCnt = sql("select * from episodes").count()
+ val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
+ assert(episodesRes.recordsRead === episodesCnt :: Nil)
+ assert(episodesRes.shuffleRecordsRead.sum === 0)
+ assert(episodesRes.generatedRows.isEmpty)
+ assert(episodesRes.outputRows === episodesCnt :: Nil)
+
+ val serdeinsCnt = sql("select * from serdeins").count()
+ val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
+ assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
+ assert(serdeinsRes.shuffleRecordsRead.sum === 0)
+ assert(serdeinsRes.generatedRows.isEmpty)
+ assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+ }
}