6 files changed, 350 insertions, 75 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index fb90799534..792fb3e795 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -339,7 +339,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
   override val output: Seq[Attribute] = range.output
 
   override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
+    "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows"))
 
   // output attributes should not affect the results
   override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
@@ -351,24 +352,37 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
 
   protected override def doProduce(ctx: CodegenContext): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
+    val numGenerated = metricTerm(ctx, "numGeneratedRows")
 
     val initTerm = ctx.freshName("initRange")
     ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
-    val partitionEnd = ctx.freshName("partitionEnd")
-    ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
     val number = ctx.freshName("number")
     ctx.addMutableState("long", number, s"$number = 0L;")
-    val overflow = ctx.freshName("overflow")
-    ctx.addMutableState("boolean", overflow, s"$overflow = false;")
 
     val value = ctx.freshName("value")
     val ev = ExprCode("", "false", value)
     val BigInt = classOf[java.math.BigInteger].getName
-    val checkEnd = if (step > 0) {
-      s"$number < $partitionEnd"
-    } else {
-      s"$number > $partitionEnd"
-    }
+
+    // In order to periodically update the metrics without inflicting performance penalty, this
+    // operator produces elements in batches. After a batch is complete, the metrics are updated
+    // and a new batch is started.
+    // In the implementation below, the code in the inner loop is producing all the values
+    // within a batch, while the code in the outer loop is setting batch parameters and updating
+    // the metrics.
+
+    // Once number == batchEnd, it's time to progress to the next batch.
+    val batchEnd = ctx.freshName("batchEnd")
+    ctx.addMutableState("long", batchEnd, s"$batchEnd = 0;")
+
+    // How many values should still be generated by this range operator.
+    val numElementsTodo = ctx.freshName("numElementsTodo")
+    ctx.addMutableState("long", numElementsTodo, s"$numElementsTodo = 0L;")
+
+    // How many values should be generated in the next batch.
+    val nextBatchTodo = ctx.freshName("nextBatchTodo")
+
+    // The default size of a batch.
+    val batchSize = 1000L
 
     ctx.addNewFunction("initRange",
       s"""
@@ -378,6 +392,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
        | $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
        | $BigInt step = $BigInt.valueOf(${step}L);
        | $BigInt start = $BigInt.valueOf(${start}L);
+       | long partitionEnd;
        |
        | $BigInt st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
        | if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
@@ -387,18 +402,26 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
        | } else {
        |   $number = st.longValue();
        | }
+       | $batchEnd = $number;
        |
        | $BigInt end = index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
        |   .multiply(step).add(start);
        | if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-       |   $partitionEnd = Long.MAX_VALUE;
+       |   partitionEnd = Long.MAX_VALUE;
        | } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-       |   $partitionEnd = Long.MIN_VALUE;
+       |   partitionEnd = Long.MIN_VALUE;
        | } else {
-       |   $partitionEnd = end.longValue();
+       |   partitionEnd = end.longValue();
        | }
        |
-       | $numOutput.add(($partitionEnd - $number) / ${step}L);
+       | $BigInt startToEnd = $BigInt.valueOf(partitionEnd).subtract(
+       |   $BigInt.valueOf($number));
+       | $numElementsTodo = startToEnd.divide(step).longValue();
+       | if ($numElementsTodo < 0) {
+       |   $numElementsTodo = 0;
+       | } else if (startToEnd.remainder(step).compareTo($BigInt.valueOf(0L)) != 0) {
+       |   $numElementsTodo++;
+       | }
        | }
      """.stripMargin)
@@ -412,20 +435,34 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       |   initRange(partitionIndex);
       | }
       |
-      | while (!$overflow && $checkEnd) {
-      |   long $value = $number;
-      |   $number += ${step}L;
-      |   if ($number < $value ^ ${step}L < 0) {
-      |     $overflow = true;
-      |   }
-      |   ${consume(ctx, Seq(ev))}
-      |   if (shouldStop()) return;
+      | while (true) {
+      |   while ($number != $batchEnd) {
+      |     long $value = $number;
+      |     $number += ${step}L;
+      |     ${consume(ctx, Seq(ev))}
+      |     if (shouldStop()) return;
+      |   }
+      |
+      |   long $nextBatchTodo;
+      |   if ($numElementsTodo > ${batchSize}L) {
+      |     $nextBatchTodo = ${batchSize}L;
+      |     $numElementsTodo -= ${batchSize}L;
+      |   } else {
+      |     $nextBatchTodo = $numElementsTodo;
+      |     $numElementsTodo = 0;
+      |     if ($nextBatchTodo == 0) break;
+      |   }
+      |   $numOutput.add($nextBatchTodo);
+      |   $numGenerated.add($nextBatchTodo);
+      |
+      |   $batchEnd += $nextBatchTodo * ${step}L;
       | }
     """.stripMargin
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val numGeneratedRows = longMetric("numGeneratedRows")
     sqlContext
       .sparkContext
       .parallelize(0 until numSlices, numSlices)
@@ -469,6 +506,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
           }
 
           numOutputRows += 1
+          numGeneratedRows += 1
          unsafeRow.setLong(0, ret)
          unsafeRow
        }
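Editor's note: the batching logic in doProduce is easier to follow outside the codegen string templates. The sketch below is illustrative only and is not part of the patch (RangeBatchSketch, produce, and the consume callback are made-up stand-ins for the generated code); it renders in plain Scala what the generated inner/outer loops and the once-per-batch metric updates do.

    object RangeBatchSketch {
      // start/partitionEnd/step mirror the per-partition values computed in initRange;
      // consume stands in for the generated consume(...) call that emits one row.
      def produce(start: Long, partitionEnd: Long, step: Long, batchSize: Long = 1000L)
          (consume: Long => Unit): (Long, Long) = {
        var numOutput = 0L
        var numGenerated = 0L
        var number = start
        var batchEnd = start
        // Same ceiling-style division as the BigInteger arithmetic in initRange.
        var numElementsTodo: Long = {
          val startToEnd = BigInt(partitionEnd) - BigInt(number)
          val todo = (startToEnd / step).toLong
          if (todo < 0) 0L
          else if (startToEnd % step != 0) todo + 1
          else todo
        }
        var done = false
        while (!done) {
          // Inner loop: emit every value of the current batch.
          while (number != batchEnd) {
            val value = number
            number += step
            consume(value)
          }
          // Outer loop: size the next batch and update the metrics once per batch.
          val nextBatchTodo =
            if (numElementsTodo > batchSize) { numElementsTodo -= batchSize; batchSize }
            else { val t = numElementsTodo; numElementsTodo = 0; t }
          if (nextBatchTodo == 0) {
            done = true
          } else {
            numOutput += nextBatchTodo
            numGenerated += nextBatchTodo
            batchEnd += nextBatchTodo * step
          }
        }
        (numOutput, numGenerated)
      }
    }

For example, produce(0, 10, 3)(println) prints 0, 3, 6, 9 and returns (4, 4): both counters are bumped once per batch rather than once per row, which is the point of the change.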
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
new file mode 100644
index 0000000000..6d2d776c92
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.math.abs
+import scala.util.Random
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFrameRangeSuite extends QueryTest with SharedSQLContext {
+
+  test("SPARK-7150 range api") {
+    // numSlice is greater than length
+    val res1 = spark.range(0, 10, 1, 15).select("id")
+    assert(res1.count == 10)
+    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res2 = spark.range(3, 15, 3, 2).select("id")
+    assert(res2.count == 4)
+    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+    val res3 = spark.range(1, -2).select("id")
+    assert(res3.count == 0)
+
+    // start is positive, end is negative, step is negative
+    val res4 = spark.range(1, -2, -2, 6).select("id")
+    assert(res4.count == 2)
+    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
+
+    // start, end, step are negative
+    val res5 = spark.range(-3, -8, -2, 1).select("id")
+    assert(res5.count == 3)
+    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
+
+    // start, end are negative, step is positive
+    val res6 = spark.range(-8, -4, 2, 1).select("id")
+    assert(res6.count == 2)
+    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
+
+    val res7 = spark.range(-10, -9, -20, 1).select("id")
+    assert(res7.count == 0)
+
+    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+    assert(res8.count == 3)
+    assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
+
+    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+    assert(res9.count == 2)
+    assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+
+    // only end provided as argument
+    val res10 = spark.range(10).select("id")
+    assert(res10.count == 10)
+    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res11 = spark.range(-1).select("id")
+    assert(res11.count == 0)
+
+    // using the default slice number
+    val res12 = spark.range(3, 15, 3).select("id")
+    assert(res12.count == 4)
+    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
+
+    // difference between range start and end does not fit in a 64-bit integer
+    val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
+    val res13 = spark.range(-n, n, n / 9).select("id")
+    assert(res13.count == 18)
+  }
+
+  test("Range with randomized parameters") {
+    val MAX_NUM_STEPS = 10L * 1000
+
+    val seed = System.currentTimeMillis()
+    val random = new Random(seed)
+
+    def randomBound(): Long = {
+      val n = if (random.nextBoolean()) {
+        random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
+      } else {
+        random.nextLong() / 2
+      }
+      if (random.nextBoolean()) n else -n
+    }
+
+    for (l <- 1 to 10) {
+      val start = randomBound()
+      val end = randomBound()
+      val numSteps = (abs(random.nextLong()) % MAX_NUM_STEPS) + 1
+      val stepAbs = (abs(end - start) / numSteps) + 1
+      val step = if (start < end) stepAbs else -stepAbs
+      val partitions = random.nextInt(20) + 1
+
+      val expCount = (start until end by step).size
+      val expSum = (start until end by step).sum
+
+      for (codegen <- List(false, true)) {
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegen.toString()) {
+          val res = spark.range(start, end, step, partitions).toDF("id").
+            agg(count("id"), sum("id")).collect()
+
+          withClue(s"seed = $seed start = $start end = $end step = $step partitions = " +
+              s"$partitions codegen = $codegen") {
+            assert(!res.isEmpty)
+            assert(res.head.getLong(0) == expCount)
+            if (expCount > 0) {
+              assert(res.head.getLong(1) == expSum)
+            }
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6a190b98ea..e6338ab7cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -979,59 +979,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Seq(Row(2, 1, 2), Row(1, 2, 1), Row(1, 1, 1), Row(2, 2, 2)))
   }
 
-  test("SPARK-7150 range api") {
-    // numSlice is greater than length
-    val res1 = spark.range(0, 10, 1, 15).select("id")
-    assert(res1.count == 10)
-    assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
-    val res2 = spark.range(3, 15, 3, 2).select("id")
-    assert(res2.count == 4)
-    assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-
-    val res3 = spark.range(1, -2).select("id")
-    assert(res3.count == 0)
-
-    // start is positive, end is negative, step is negative
-    val res4 = spark.range(1, -2, -2, 6).select("id")
-    assert(res4.count == 2)
-    assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
-
-    // start, end, step are negative
-    val res5 = spark.range(-3, -8, -2, 1).select("id")
-    assert(res5.count == 3)
-    assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
-
-    // start, end are negative, step is positive
-    val res6 = spark.range(-8, -4, 2, 1).select("id")
-    assert(res6.count == 2)
-    assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
-
-    val res7 = spark.range(-10, -9, -20, 1).select("id")
-    assert(res7.count == 0)
-
-    val res8 = spark.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
-    assert(res8.count == 3)
-    assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
-
-    val res9 = spark.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
-    assert(res9.count == 2)
-    assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
-
-    // only end provided as argument
-    val res10 = spark.range(10).select("id")
-    assert(res10.count == 10)
-    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
-
-    val res11 = spark.range(-1).select("id")
-    assert(res11.count == 0)
-
-    // using the default slice number
-    val res12 = spark.range(3, 15, 3).select("id")
-    assert(res12.count == 4)
-    assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
-  }
-
   test("SPARK-8621: support empty string column name") {
     val df = Seq(Tuple1(1)).toDF("").as("t")
     // We should allow empty string as column name
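Editor's note: the randomized test moved into DataFrameRangeSuite above checks Spark's Range against Scala's own NumericRange, so the expected count and sum can be computed without Spark at all. A standalone example of that expression with illustrative values (not taken from the patch):

    val start = -7L
    val end = 10L
    val step = 4L
    val expCount = (start until end by step).size  // elements -7, -3, 1, 5, 9, so 5
    val expSum = (start until end by step).sum     // -7 + -3 + 1 + 5 + 9 = 5

The same start until end by step expression is what the test compares against for both the codegen and non-codegen paths.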
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
new file mode 100644
index 0000000000..ddd7a03e80
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/InputGeneratedOutputMetricsSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io.File
+
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.Utils
+
+class InputGeneratedOutputMetricsSuite extends QueryTest with SharedSQLContext with Eventually {
+
+  test("Range query input/output/generated metrics") {
+    val numRows = 150L
+    val numSelectedRows = 100L
+    val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, numRows, 1).
+      filter(x => x < numSelectedRows).toDF())
+
+    assert(res.recordsRead.sum === 0)
+    assert(res.shuffleRecordsRead.sum === 0)
+    assert(res.generatedRows === numRows :: Nil)
+    assert(res.outputRows === numSelectedRows :: numRows :: Nil)
+  }
+
+  test("Input/output/generated metrics with repartitioning") {
+    val numRows = 100L
+    val res = MetricsTestHelper.runAndGetMetrics(
+      spark.range(0, numRows).repartition(3).filter(x => x % 5 == 0).toDF())
+
+    assert(res.recordsRead.sum === 0)
+    assert(res.shuffleRecordsRead.sum === numRows)
+    assert(res.generatedRows === numRows :: Nil)
+    assert(res.outputRows === 20 :: numRows :: Nil)
+  }
+
+  test("Input/output/generated metrics with more repartitioning") {
+    withTempDir { tempDir =>
+      val dir = new File(tempDir, "pqS").getCanonicalPath
+
+      spark.range(10).write.parquet(dir)
+      spark.read.parquet(dir).createOrReplaceTempView("pqS")
+
+      val res = MetricsTestHelper.runAndGetMetrics(
+        spark.range(0, 30).repartition(3).crossJoin(sql("select * from pqS")).repartition(2)
+          .toDF()
+      )
+
+      assert(res.recordsRead.sum == 10)
+      assert(res.shuffleRecordsRead.sum == 3 * 10 + 2 * 150)
+      assert(res.generatedRows == 30 :: Nil)
+      assert(res.outputRows == 10 :: 30 :: 300 :: Nil)
+    }
+  }
+}
+
+object MetricsTestHelper {
+  case class AggregatedMetricsResult(
+      recordsRead: List[Long],
+      shuffleRecordsRead: List[Long],
+      generatedRows: List[Long],
+      outputRows: List[Long])
+
+  private[this] def extractMetricValues(
+      df: DataFrame,
+      metricValues: Map[Long, String],
+      metricName: String): List[Long] = {
+    df.queryExecution.executedPlan.collect {
+      case plan if plan.metrics.contains(metricName) =>
+        metricValues(plan.metrics(metricName).id).toLong
+    }.toList.sorted
+  }
+
+  def runAndGetMetrics(df: DataFrame, useWholeStageCodeGen: Boolean = false):
+      AggregatedMetricsResult = {
+    val spark = df.sparkSession
+    val sparkContext = spark.sparkContext
+
+    var recordsRead = List[Long]()
+    var shuffleRecordsRead = List[Long]()
+    val listener = new SparkListener() {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        if (taskEnd.taskMetrics != null) {
+          recordsRead = taskEnd.taskMetrics.inputMetrics.recordsRead ::
+            recordsRead
+          shuffleRecordsRead = taskEnd.taskMetrics.shuffleReadMetrics.recordsRead ::
+            shuffleRecordsRead
+        }
+      }
+    }
+
+    val oldExecutionIds = spark.sharedState.listener.executionIdToData.keySet
+
+    val prevUseWholeStageCodeGen =
+      spark.sessionState.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED)
+    try {
+      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, useWholeStageCodeGen)
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+      sparkContext.addSparkListener(listener)
+      df.collect()
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+    } finally {
+      spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, prevUseWholeStageCodeGen)
+    }
+
+    val executionId = spark.sharedState.listener.executionIdToData.keySet.diff(oldExecutionIds).head
+    val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
+    val outputRes = extractMetricValues(df, metricValues, "numOutputRows")
+    val generatedRes = extractMetricValues(df, metricValues, "numGeneratedRows")
+
+    AggregatedMetricsResult(recordsRead.sorted, shuffleRecordsRead.sorted, generatedRes, outputRes)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 039625421e..14fbe9f443 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.MetricsTestHelper
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -915,4 +916,13 @@ class JDBCSuite extends SparkFunSuite
     }.getMessage
     assert(e2.contains("User specified schema not supported with `jdbc`"))
   }
+
+  test("Input/generated/output metrics on JDBC") {
+    val foobarCnt = spark.table("foobar").count()
+    val res = MetricsTestHelper.runAndGetMetrics(sql("SELECT * FROM foobar").toDF())
+    assert(res.recordsRead === foobarCnt :: Nil)
+    assert(res.shuffleRecordsRead.sum === 0)
+    assert(res.generatedRows.isEmpty)
+    assert(res.outputRows === foobarCnt :: Nil)
+  }
 }
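Editor's note: the MetricsTestHelper object added above is what the new JDBC test reuses, and the Hive test below does the same. A minimal usage sketch (illustrative only; it assumes it runs inside a suite with a SharedSQLContext-style spark session, and the expected numbers are what RangeExec should report after this patch):

    import org.apache.spark.sql.execution.MetricsTestHelper

    val res = MetricsTestHelper.runAndGetMetrics(spark.range(0, 50).toDF())
    assert(res.recordsRead.sum === 0)        // nothing is read from an external source
    assert(res.shuffleRecordsRead.sum === 0) // no exchange in this plan
    assert(res.generatedRows === 50 :: Nil)  // the new numGeneratedRows metric of RangeExec
    assert(res.outputRows === 50 :: Nil)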
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index ec620c2403..35c41b531c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.sql.execution.MetricsTestHelper
 import org.apache.spark.sql.hive.test.TestHive
 
 /**
@@ -47,4 +48,22 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
   createQueryTest("Read with AvroSerDe", "SELECT * FROM episodes")
 
   createQueryTest("Read Partitioned with AvroSerDe", "SELECT * FROM episodes_part")
+
+  test("Test input/generated/output metrics") {
+    import TestHive._
+
+    val episodesCnt = sql("select * from episodes").count()
+    val episodesRes = MetricsTestHelper.runAndGetMetrics(sql("select * from episodes").toDF())
+    assert(episodesRes.recordsRead === episodesCnt :: Nil)
+    assert(episodesRes.shuffleRecordsRead.sum === 0)
+    assert(episodesRes.generatedRows.isEmpty)
+    assert(episodesRes.outputRows === episodesCnt :: Nil)
+
+    val serdeinsCnt = sql("select * from serdeins").count()
+    val serdeinsRes = MetricsTestHelper.runAndGetMetrics(sql("select * from serdeins").toDF())
+    assert(serdeinsRes.recordsRead === serdeinsCnt :: Nil)
+    assert(serdeinsRes.shuffleRecordsRead.sum === 0)
+    assert(serdeinsRes.generatedRows.isEmpty)
+    assert(serdeinsRes.outputRows === serdeinsCnt :: Nil)
+  }
 }
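Editor's note: outside of these suites, the new metric can also be read straight off an executed plan, which is essentially what MetricsTestHelper.extractMetricValues does via the SQL listener. A hedged sketch (assumes an active SparkSession named spark; how metric values surface on the driver can differ between Spark versions):

    val df = spark.range(0, 1000).filter(x => x < 500).toDF()
    df.collect()
    val generated = df.queryExecution.executedPlan.collect {
      case plan if plan.metrics.contains("numGeneratedRows") =>
        plan.metrics("numGeneratedRows").value
    }
    // Expected: Seq(1000) -- RangeExec generates all 1000 rows even though only 500 pass the filter.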