diff options
author | Zhenhua Wang <wzh_zju@163.com> | 2017-01-09 11:29:42 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2017-01-09 11:29:42 -0800 |
commit | 15c2bd01b03b1a07f10779f68118cd28f2c62c9a (patch) | |
tree | 733787ab1c0b4c185533c320fe6bd2fff9ab1d98 /sql/catalyst/src | |
parent | 3ccabdfb4d760d684b1e0c0ed448a57331f209f2 (diff) | |
download | spark-15c2bd01b03b1a07f10779f68118cd28f2c62c9a.tar.gz spark-15c2bd01b03b1a07f10779f68118cd28f2c62c9a.tar.bz2 spark-15c2bd01b03b1a07f10779f68118cd28f2c62c9a.zip |
[SPARK-19020][SQL] Cardinality estimation of aggregate operator
## What changes were proposed in this pull request?
Support cardinality estimation of aggregate operator
## How was this patch tested?
Add test cases
Author: Zhenhua Wang <wzh_zju@163.com>
Author: wangzhenhua <wangzhenhua@huawei.com>
Closes #16431 from wzhfy/aggEstimation.
Diffstat (limited to 'sql/catalyst/src')
4 files changed, 198 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 9b52a9cc81..b97c81ce01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ProjectEstimation +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -495,7 +495,7 @@ case class Aggregate( child.constraints.union(getAliasedConstraints(nonAgg)) } - override lazy val statistics: Statistics = { + override lazy val statistics: Statistics = AggregateEstimation.estimate(this).getOrElse { if (groupingExpressions.isEmpty) { super.statistics.copy(sizeInBytes = 1) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala new file mode 100644 index 0000000000..33ebc380d2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.statsEstimation + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics} + + +object AggregateEstimation { + import EstimationUtils._ + + /** + * Estimate the number of output rows based on column stats of group-by columns, and propagate + * column stats for aggregate expressions. + */ + def estimate(agg: Aggregate): Option[Statistics] = { + val childStats = agg.child.statistics + // Check if we have column stats for all group-by columns. + val colStatsExist = agg.groupingExpressions.forall { e => + e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute]) + } + if (rowCountsExist(agg.child) && colStatsExist) { + // Multiply distinct counts of group-by columns. This is an upper bound, which assumes + // the data contains all combinations of distinct values of group-by columns. + var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))( + (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount) + + // Here we set another upper bound for the number of output rows: it must not be larger than + // child's number of rows. + outputRows = outputRows.min(childStats.rowCount.get) + + val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output) + Some(Statistics( + sizeInBytes = outputRows * getRowSize(agg.output, outputAttrStats), + rowCount = Some(outputRows), + attributeStats = outputAttrStats, + isBroadcastable = childStats.isBroadcastable)) + } else { + None + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala new file mode 100644 index 0000000000..42ce2f8c5e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ + + +class AggEstimationSuite extends StatsEstimationTestBase { + + /** Columns for testing */ + private val columnInfo: Map[Attribute, ColumnStat] = + Map( + attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + avgLen = 4, maxLen = 4), + attr("key12") -> ColumnStat(distinctCount = 1, min = Some(10), max = Some(10), nullCount = 0, + avgLen = 4, maxLen = 4), + attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + avgLen = 4, maxLen = 4), + attr("key22") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0, + avgLen = 4, maxLen = 4), + attr("key31") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0, + avgLen = 4, maxLen = 4), + attr("key32") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0, + avgLen = 4, maxLen = 4)) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + test("empty group-by column") { + val colNames = Seq("key11", "key12") + // Suppose table1 has 2 records: (1, 10), (2, 10) + val table1 = StatsTestPlan( + outputList = colNames.map(nameToAttr), + stats = Statistics( + sizeInBytes = 2 * (4 + 4), + rowCount = Some(2), + attributeStats = AttributeMap(colNames.map(nameToColInfo)))) + + checkAggStats( + child = table1, + colNames = Nil, + expectedRowCount = 1) + } + + test("there's a primary key in group-by columns") { + val colNames = Seq("key11", "key12") + // Suppose table1 has 2 records: (1, 10), (2, 10) + val table1 = StatsTestPlan( + outputList = colNames.map(nameToAttr), + stats = Statistics( + sizeInBytes = 2 * (4 + 4), + rowCount = Some(2), + attributeStats = AttributeMap(colNames.map(nameToColInfo)))) + + checkAggStats( + child = table1, + colNames = colNames, + // Column key11 a primary key, so row count = ndv of key11 = child's row count + expectedRowCount = table1.stats.rowCount.get) + } + + test("the product of ndv's of group-by columns is too large") { + val colNames = Seq("key21", "key22") + // Suppose table2 has 4 records: (1, 10), (1, 20), (2, 30), (2, 40) + val table2 = StatsTestPlan( + outputList = colNames.map(nameToAttr), + stats = Statistics( + sizeInBytes = 4 * (4 + 4), + rowCount = Some(4), + attributeStats = AttributeMap(colNames.map(nameToColInfo)))) + + checkAggStats( + child = table2, + colNames = colNames, + // Use child's row count as an upper bound + expectedRowCount = table2.stats.rowCount.get) + } + + test("data contains all combinations of distinct values of group-by columns.") { + val colNames = Seq("key31", "key32") + // Suppose table3 has 6 records: (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10) + val table3 = StatsTestPlan( + outputList = colNames.map(nameToAttr), + stats = Statistics( + sizeInBytes = 6 * (4 + 4), + rowCount = Some(6), + attributeStats = AttributeMap(colNames.map(nameToColInfo)))) + + checkAggStats( + child = table3, + colNames = colNames, + // Row count = product of ndv + expectedRowCount = nameToColInfo("key31")._2.distinctCount * nameToColInfo("key32")._2 + .distinctCount) + } + + private def checkAggStats( + child: LogicalPlan, + colNames: Seq[String], + expectedRowCount: BigInt): Unit = { + + val columns = colNames.map(nameToAttr) + val testAgg = Aggregate( + groupingExpressions = columns, + aggregateExpressions = columns :+ Alias(Count(Literal(1)), "cnt")(), + child = child) + + val expectedAttrStats = AttributeMap(colNames.map(nameToColInfo)) + val expectedStats = Statistics( + sizeInBytes = expectedRowCount * getRowSize(testAgg.output, expectedAttrStats), + rowCount = Some(expectedRowCount), + attributeStats = expectedAttrStats) + + assert(testAgg.statistics == expectedStats) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index fa5b290ecb..0d81aa3f68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -18,12 +18,15 @@ package org.apache.spark.sql.catalyst.statsEstimation import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.types.IntegerType class StatsEstimationTestBase extends SparkFunSuite { + def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)() + /** Convert (column name, column stat) pairs to an AttributeMap based on plan output. */ def toAttributeMap(colStats: Seq[(String, ColumnStat)], plan: LogicalPlan) : AttributeMap[ColumnStat] = { |