 sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala           |  82
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala                        | 130
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                             |  53
 sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala | 131
 sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala                             |  10
 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala              |  19
 6 files changed, 350 insertions(+), 75 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index fb90799534..792fb3e795 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -339,7 +339,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
override val output: Seq[Attribute] = range.output
override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+ "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows"))
// output attributes should not affect the results
override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
@@ -351,24 +352,37 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
protected override def doProduce(ctx: CodegenContext): String = {
val numOutput = metricTerm(ctx, "numOutputRows")
+ val numGenerated = metricTerm(ctx, "numGeneratedRows")
val initTerm = ctx.freshName("initRange")
ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
- val partitionEnd = ctx.freshName("partitionEnd")
- ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
val number = ctx.freshName("number")
ctx.addMutableState("long", number, s"$number = 0L;")
- val overflow = ctx.freshName("overflow")
- ctx.addMutableState("boolean", overflow, s"$overflow = false;")
val value = ctx.freshName("value")
val ev = ExprCode("", "false", value)
val BigInt = classOf[java.math.BigInteger].getName
- val checkEnd = if (step > 0) {
- s"$number < $partitionEnd"
- } else {
- s"$number > $partitionEnd"
- }
+
+ // To update the metrics periodically without incurring a performance penalty, this
+ // operator produces elements in batches. After a batch is complete, the metrics are
+ // updated and a new batch is started.
+ // In the implementation below, the code in the inner loop is producing all the values
+ // within a batch, while the code in the outer loop is setting batch parameters and updating
+ // the metrics.
+
+ // Once number == batchEnd, it's time to progress to the next batch.
+ val batchEnd = ctx.freshName("batchEnd")
+ ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;")
+
+ // How many values should still be generated by this range operator.
+ val numElementsTodo = ctx.freshName("numElementsTodo")
+ ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;")
+
+ // How many values should be generated in the next batch.
+ val nextBatchTodo = ctx.freshName("nextBatchTodo")
+
+ // The default size of a batch.
+ val batchSize = 1000L
ctx.addNewFunction("initRange",
s"""
@@ -378,6 +392,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
| $BigInt step = $BigInt.valueOf(${step}L);
| $BigInt start = $BigInt.valueOf(${start}L);
+ | long partitionEnd;
|
| $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
| if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
@@ -387,18 +402,26 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| } else {
| $number = st.longValue();
| }
+ | $batchEnd = $number;
|
| $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
| .multiply(step).add(start);
| if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
- | $partitionEnd = Long.MAX_VALUE;
+ | partitionEnd = Long.MAX_VALUE;
| } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
- | $partitionEnd = Long.MIN_VALUE;
+ | partitionEnd = Long.MIN_VALUE;
| } else {
- | $partitionEnd = end.longValue();
+ | partitionEnd = end.longValue();
| }
|
- | $numOutput.add(($partitionEnd - $number) / ${step}L);
+ | $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
+ | $BigInt.valueOf($number));
+ | $numElementsTodo = startToEnd.divide(step).longValue();
+ | if ($numElementsTodo < 0) {
+ | $numElementsTodo = 0;
+ | } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
+ | $numElementsTodo++;
+ | }
| }
""".stripMargin)
@@ -412,20 +435,34 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| initRange(partitionIndex);
| }
|
- | while (!$overflow && $checkEnd) {
- | long $value = $number;
- | $number += ${step}L;
- | if ($number < $value ^ ${step}L < 0) {
- | $overflow = true;
- | }
- | ${consume(ctx, Seq(ev))}
- | if (shouldStop()) return;
+ | while (true) {
+ | while ($number != $batchEnd) {
+ | long $value = $number;
+ | $number += ${step}L;
+ | ${consume(ctx, Seq(ev))}
+ | if (shouldStop()) return;
+ | }
+ |
+ | long $nextBatchTodo;
+ | if ($numElementsTodo > ${batchSize}L) {
+ | $nextBatchTodo = ${batchSize}L;
+ | $numElementsTodo -= ${batchSize}L;
+ | } else {
+ | $nextBatchTodo = $numElementsTodo;
+ | $numElementsTodo = 0;
+ | if ($nextBatchTodo == 0) break;
+ | }
+ | $numOutput.add($nextBatchTodo);
+ | $numGenerated.add($nextBatchTodo);
+ |
+ | $batchEnd += $nextBatchTodo * ${step}L;
| }
""".stripMargin
}
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
+ val numGeneratedRows = longMetric("numGeneratedRows")
sqlContext
.sparkContext
.parallelize(0 until numSlices, numSlices)
@@ -469,6 +506,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
}
numOutputRows += 1
+ numGeneratedRows += 1
unsafeRow.setLong(0, ret)
unsafeRow
}
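
The net effect of the codegen changes above is the batched production loop sketched here in plain Scala (a sketch only; consume stands in for the downstream operator, and the two counters stand in for the numOutputRows / numGeneratedRows SQLMetrics). Metrics are bumped once per batch of at most 1000 values instead of once per row, which keeps the SQL UI reasonably current without slowing the hot inner loop.

// Plain-Scala sketch of the batched loop emitted above (illustrative only).
// The real operator keeps number / batchEnd / numElementsTodo as mutable codegen state so
// it can suspend on shouldStop() and resume later; this sketch runs to completion.
def produceRange(partitionStart: Long, step: Long, totalElements: Long)
    (consume: Long => Unit): Unit = {
  val batchSize = 1000L
  var number = partitionStart   // next value to emit
  var batchEnd = partitionStart // inner loop runs until number == batchEnd
  var numElementsTodo = totalElements
  var numOutputRows = 0L        // stands in for the "number of output rows" metric
  var numGeneratedRows = 0L     // stands in for the "number of generated rows" metric

  while (true) {
    // Inner loop: emit every value of the current batch; no metric updates here.
    while (number != batchEnd) {
      val value = number
      number += step
      consume(value)
    }
    // Outer loop: carve the next batch out of the remaining work, then update metrics.
    var nextBatchTodo = 0L
    if (numElementsTodo > batchSize) {
      nextBatchTodo = batchSize
      numElementsTodo -= batchSize
    } else {
      nextBatchTodo = numElementsTodo
      numElementsTodo = 0
      if (nextBatchTodo == 0) return
    }
    numOutputRows += nextBatchTodo
    numGeneratedRows += nextBatchTodo
    batchEnd += nextBatchTodo * step
  }
}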
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
new file mode 100644
index 0000000000..6d2d776c92
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.math.abs
+import scala.util.Random
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
+
+ test("SPARK-7150 range api") {
+ // numSlice is greater than length
+ val res1 = spark.range(0, 10, 1, 15).select("id")
+ assert(res1.count == 10)
+ assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+ val res2 = spark.range(3, 15, 3, 2).select("id")
+ assert(res2.count == 4)
+ assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+ val res3 = spark.range(1, -2).select("id")
+ assert(res3.count == 0)
+
+ // start is positive, end is negative, step is negative
+ val res4 = spark.range(1, -2, -2, 6).select("id")
+ assert(res4.count == 2)
+ assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
+
+ // start, end, step are negative
+ val res5 = spark.range(-3, -8, -2, 1).select("id")
+ assert(res5.count == 3)
+ assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
+
+ // start, end are negative, step is positive
+ val res6 = spark.range(-8, -4, 2, 1).select("id")
+ assert(res6.count == 2)
+ assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
+
+ val res7 = spark.range(-10, -9, -20, 1).select("id")
+ assert(res7.count == 0)
+
+ val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+ assert(res8.count == 3)
+ assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
+
+ val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+ assert(res9.count == 2)
+ assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+
+ // only end provided as argument
+ val res10 = spark.range(10).select("id")
+ assert(res10.count == 10)
+ assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+ val res11 = spark.range(-1).select("id")
+ assert(res11.count == 0)
+
+ // using the default slice number
+ val res12 = spark.range(3, 15, 3).select("id")
+ assert(res12.count == 4)
+ assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+ // difference between range start and end does not fit in a 64-bit integer
+ val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
+ val res13 = spark.range(-n, n, n / 9).select("id")
+ assert(res13.count == 18)
+ }
+
+ test("Range with randomized parameters") {
+ val MAX_NUM_STEPS = 10L * 1000
+
+ val seed = System.currentTimeMillis()
+ val random = new Random(seed)
+
+ def randomBound(): Long = {
+ val n = if (random.nextBoolean()) {
+ random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
+ } else {
+ random.nextLong() / 2
+ }
+ if (random.nextBoolean()) n else -n
+ }
+
+ for (l <- 1 to 10) {
+ val start = randomBound()
+ val end = randomBound()
+ val numSteps = (abs(random.nextLong()) % MAX_NUM_STEPS) + 1
+ val stepAbs = (abs(end - start) / numSteps) + 1
+ val step = if (start < end) stepAbs else -stepAbs
+ val partitions = random.nextInt(20) + 1
+
+ val expCount = (start until end by step).size
+ val expSum = (start until end by step).sum
+
+ for (codegen <- List(false, true)) {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+ val res = spark.range(start, end, step, partitions).toDF("id").
+ agg(count("id"), sum("id")).collect()
+
+ withClue(s"seed = $seed start = $start end = $end step = $step partitions = " +
+ s"$partitions codegen = $codegen") {
+ assert(!res.isEmpty)
+ assert(res.head.getLong(0) == expCount)
+ if (expCount > 0) {
+ assert(res.head.getLong(1) == expSum)
+ }
+ }
+ }
+ }
+ }
+ }
+}
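
The randomized test above checks the SQL result against a Scala NumericRange used as an oracle. A small hedged illustration of that oracle, with arbitrary values not taken from the suite:

// scala.collection.immutable.NumericRange supplies the reference count and sum that
// spark.range(start, end, step, partitions) must reproduce under both codegen settings.
val start = -7L
val end = 8L
val step = 3L
val reference = start until end by step   // -7, -4, -1, 2, 5
assert(reference.size == 5)
assert(reference.sum == -5L)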
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6a190b98ea..e6338ab7cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -979,59 +979,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
}
- test("SPARK-7150 range api") {
- // numSlice is greater than length
- val res1 = spark.range(0, 10, 1, 15).select("id")
- assert(res1.count == 10)
- assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
- val res2 = spark.range(3, 15, 3, 2).select("id")
- assert(res2.count == 4)
- assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-
- val res3 = spark.range(1, -2).select("id")
- assert(res3.count == 0)
-
- // start is positive, end is negative, step is negative
- val res4 = spark.range(1, -2, -2, 6).select("id")
- assert(res4.count == 2)
- assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
-
- // start, end, step are negative
- val res5 = spark.range(-3, -8, -2, 1).select("id")
- assert(res5.count == 3)
- assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
-
- // start, end are negative, step is positive
- val res6 = spark.range(-8, -4, 2, 1).select("id")
- assert(res6.count == 2)
- assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
-
- val res7 = spark.range(-10, -9, -20, 1).select("id")
- assert(res7.count == 0)
-
- val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
- assert(res8.count == 3)
- assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
-
- val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
- assert(res9.count == 2)
- assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
-
- // only end provided as argument
- val res10 = spark.range(10).select("id")
- assert(res10.count == 10)
- assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
- val res11 = spark.range(-1).select("id")
- assert(res11.count == 0)
-
- // using the default slice number
- val res12 = spark.range(3, 15, 3).select("id")
- assert(res12.count == 4)
- assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
- }
-
test("SPARK-8621: support empty string column name") {
val df = Seq(Tuple1(1)).toDF("").as("t")
// We should allow empty string as column name
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
new file mode 100644
index 0000000000..ddd7a03e80
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.File
+
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.Utils
+
+class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
+
+ test("Range query input/output/generated metrics") {
+ val numRows = 150L
+ val numSelectedRows = 100L
+ val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
+ filter(x => x < numSelectedRows).toDF())
+
+ assert(res.recordsRead.sum === 0)
+ assert(res.shuffleRecordsRead.sum === 0)
+ assert(res.generatedRows === numRows :: Nil)
+ assert(res.outputRows === numSelectedRows :: numRows :: Nil)
+ }
+
+ test("Input/output/generated metrics with repartitioning") {
+ val numRows = 100L
+ val res = MetricsTestHelper.runAndGetMetrics(
+ spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
+
+ assert(res.recordsRead.sum === 0)
+ assert(res.shuffleRecordsRead.sum === numRows)
+ assert(res.generatedRows === numRows :: Nil)
+ assert(res.outputRows === 20 :: numRows :: Nil)
+ }
+
+ test("Input/output/generated metrics with more repartitioning") {
+ withTempDir { tempDir =>
+ val dir = new File(tempDir, "pqS").getCanonicalPath
+
+ spark.range(10).write.parquet(dir)
+ spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+ val res = MetricsTestHelper.runAndGetMetrics(
+ spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
+ .toDF()
+ )
+
+ assert(res.recordsRead.sum == 10)
+ assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
+ assert(res.generatedRows == 30 :: Nil)
+ assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
+ }
+ }
+}
+
+object MetricsTestHelper {
+ case class AggregatedMetricsResult(
+ recordsRead: List[Long],
+ shuffleRecordsRead: List[Long],
+ generatedRows: List[Long],
+ outputRows: List[Long])
+
+ private[this] def extractMetricValues(
+ df: DataFrame,
+ metricValues: Map[Long, String],
+ metricName: String): List[Long] = {
+ df.queryExecution.executedPlan.collect {
+ case plan if plan.metrics.contains(metricName) =>
+ metricValues(plan.metrics(metricName).id).toLong
+ }.toList.sorted
+ }
+
+ def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
+ AggregatedMetricsResult = {
+ val spark = df.sparkSession
+ val sparkContext = spark.sparkContext
+
+ var recordsRead = List[Long]()
+ var shuffleRecordsRead = List[Long]()
+ val listener = new SparkListener() {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ if (taskEnd.taskMetrics != null) {
+ recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
+ recordsRead
+ shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
+ shuffleRecordsRead
+ }
+ }
+ }
+
+ val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
+
+ val prevUseWholeStageCodeGen =
+ spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
+ try {
+ spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
+ sparkContext.listenerBus.waitUntilEmpty(10000)
+ sparkContext.addSparkListener(listener)
+ df.collect()
+ sparkContext.listenerBus.waitUntilEmpty(10000)
+ } finally {
+ spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
+ }
+
+ val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
+ val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
+ val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
+ val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
+
+ AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
+ }
+}
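
For reference, a hedged usage sketch of the helper above. It assumes a SparkSession named spark, as SharedSQLContext provides in these suites, and the expected values mirror the first test in the suite (Range generates every row, Filter outputs the selected subset):

// Usage sketch for MetricsTestHelper.runAndGetMetrics (illustrative only).
val res = MetricsTestHelper.runAndGetMetrics(
  spark.range(0, 1000).filter(x => x % 2 == 0).toDF())

// Task-level metrics gathered by the SparkListener:
println(res.recordsRead.sum)        // 0   -- nothing read from an external source
println(res.shuffleRecordsRead.sum) // 0   -- no exchange in this plan

// Per-operator SQL metrics, collected from the executed plan and sorted ascending:
println(res.generatedRows)          // List(1000)       -- the Range operator
println(res.outputRows)             // List(500, 1000)  -- Filter output, Range output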
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 039625421e..14fbe9f443 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.MetricsTestHelper
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -915,4 +916,13 @@ class JDBCSuite extends SparkFunSuite
}.getMessage
assert(e2.contains("User specified schema not supported with `jdbc`"))
}
+
+ test("Input/generated/output metrics on JDBC") {
+ val foobarCnt = spark.table("foobar").count()
+ val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
+ assert(res.recordsRead === foobarCnt :: Nil)
+ assert(res.shuffleRecordsRead.sum === 0)
+ assert(res.generatedRows.isEmpty)
+ assert(res.outputRows === foobarCnt :: Nil)
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index ec620c2403..35c41b531c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.execution.MetricsTestHelper
import org.apache.spark.sql.hive.test.TestHive
/**
@@ -47,4 +48,22 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes")
createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")
+
+ test("Test input/generated/output metrics") {
+ import TestHive._
+
+ val episodesCnt = sql("select * from episodes").count()
+ val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
+ assert(episodesRes.recordsRead === episodesCnt :: Nil)
+ assert(episodesRes.shuffleRecordsRead.sum === 0)
+ assert(episodesRes.generatedRows.isEmpty)
+ assert(episodesRes.outputRows === episodesCnt :: Nil)
+
+ val serdeinsCnt = sql("select * from serdeins").count()
+ val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
+ assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
+ assert(serdeinsRes.shuffleRecordsRead.sum === 0)
+ assert(serdeinsRes.generatedRows.isEmpty)
+ assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+ }
}