aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
authorZhenhua Wang <wzh_zju@163.com>2017-01-09 11:29:42 -0800
committerReynold Xin <rxin@databricks.com>2017-01-09 11:29:42 -0800
commit15c2bd01b03b1a07f10779f68118cd28f2c62c9a (patch)
tree733787ab1c0b4c185533c320fe6bd2fff9ab1d98 /sql/catalyst/src
parent3ccabdfb4d760d684b1e0c0ed448a57331f209f2 (diff)
downloadspark-15c2bd01b03b1a07f10779f68118cd28f2c62c9a.tar.gz
spark-15c2bd01b03b1a07f10779f68118cd28f2c62c9a.tar.bz2
spark-15c2bd01b03b1a07f10779f68118cd28f2c62c9a.zip
[SPARK-19020][SQL] Cardinality estimation of aggregate operator
## What changes were proposed in this pull request? Support cardinality estimation of aggregate operator ## How was this patch tested? Add test cases Author: Zhenhua Wang <wzh_zju@163.com> Author: wangzhenhua <wangzhenhua@huawei.com> Closes #16431 from wzhfy/aggEstimation.
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala57
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala135
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala5
4 files changed, 198 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 9b52a9cc81..b97c81ce01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.ProjectEstimation
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -495,7 +495,7 @@ case class Aggregate(
child.constraints.union(getAliasedConstraints(nonAgg))
}
- override lazy val statistics: Statistics = {
+ override lazy val statistics: Statistics = AggregateEstimation.estimate(this).getOrElse {
if (groupingExpressions.isEmpty) {
super.statistics.copy(sizeInBytes = 1)
} else {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
new file mode 100644
index 0000000000..33ebc380d2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.statsEstimation
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Statistics}
+
+
+object AggregateEstimation {
+ import EstimationUtils._
+
+ /**
+ * Estimate the number of output rows based on column stats of group-by columns, and propagate
+ * column stats for aggregate expressions.
+ */
+ def estimate(agg: Aggregate): Option[Statistics] = {
+ val childStats = agg.child.statistics
+ // Check if we have column stats for all group-by columns.
+ val colStatsExist = agg.groupingExpressions.forall { e =>
+ e.isInstanceOf[Attribute] && childStats.attributeStats.contains(e.asInstanceOf[Attribute])
+ }
+ if (rowCountsExist(agg.child) && colStatsExist) {
+ // Multiply distinct counts of group-by columns. This is an upper bound, which assumes
+ // the data contains all combinations of distinct values of group-by columns.
+ var outputRows: BigInt = agg.groupingExpressions.foldLeft(BigInt(1))(
+ (res, expr) => res * childStats.attributeStats(expr.asInstanceOf[Attribute]).distinctCount)
+
+ // Here we set another upper bound for the number of output rows: it must not be larger than
+ // child's number of rows.
+ outputRows = outputRows.min(childStats.rowCount.get)
+
+ val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
+ Some(Statistics(
+ sizeInBytes = outputRows * getRowSize(agg.output, outputAttrStats),
+ rowCount = Some(outputRows),
+ attributeStats = outputAttrStats,
+ isBroadcastable = childStats.isBroadcastable))
+ } else {
+ None
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
new file mode 100644
index 0000000000..42ce2f8c5e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggEstimationSuite.scala
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Count
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
+
+
+class AggEstimationSuite extends StatsEstimationTestBase {
+
+ /** Columns for testing */
+ private val columnInfo: Map[Attribute, ColumnStat] =
+ Map(
+ attr("key11") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+ avgLen = 4, maxLen = 4),
+ attr("key12") -> ColumnStat(distinctCount = 1, min = Some(10), max = Some(10), nullCount = 0,
+ avgLen = 4, maxLen = 4),
+ attr("key21") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+ avgLen = 4, maxLen = 4),
+ attr("key22") -> ColumnStat(distinctCount = 4, min = Some(10), max = Some(40), nullCount = 0,
+ avgLen = 4, maxLen = 4),
+ attr("key31") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), nullCount = 0,
+ avgLen = 4, maxLen = 4),
+ attr("key32") -> ColumnStat(distinctCount = 2, min = Some(10), max = Some(20), nullCount = 0,
+ avgLen = 4, maxLen = 4))
+
+ private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1)
+ private val nameToColInfo: Map[String, (Attribute, ColumnStat)] =
+ columnInfo.map(kv => kv._1.name -> kv)
+
+ test("empty group-by column") {
+ val colNames = Seq("key11", "key12")
+ // Suppose table1 has 2 records: (1, 10), (2, 10)
+ val table1 = StatsTestPlan(
+ outputList = colNames.map(nameToAttr),
+ stats = Statistics(
+ sizeInBytes = 2 * (4 + 4),
+ rowCount = Some(2),
+ attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
+ checkAggStats(
+ child = table1,
+ colNames = Nil,
+ expectedRowCount = 1)
+ }
+
+ test("there's a primary key in group-by columns") {
+ val colNames = Seq("key11", "key12")
+ // Suppose table1 has 2 records: (1, 10), (2, 10)
+ val table1 = StatsTestPlan(
+ outputList = colNames.map(nameToAttr),
+ stats = Statistics(
+ sizeInBytes = 2 * (4 + 4),
+ rowCount = Some(2),
+ attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
+ checkAggStats(
+ child = table1,
+ colNames = colNames,
+ // Column key11 a primary key, so row count = ndv of key11 = child's row count
+ expectedRowCount = table1.stats.rowCount.get)
+ }
+
+ test("the product of ndv's of group-by columns is too large") {
+ val colNames = Seq("key21", "key22")
+ // Suppose table2 has 4 records: (1, 10), (1, 20), (2, 30), (2, 40)
+ val table2 = StatsTestPlan(
+ outputList = colNames.map(nameToAttr),
+ stats = Statistics(
+ sizeInBytes = 4 * (4 + 4),
+ rowCount = Some(4),
+ attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
+ checkAggStats(
+ child = table2,
+ colNames = colNames,
+ // Use child's row count as an upper bound
+ expectedRowCount = table2.stats.rowCount.get)
+ }
+
+ test("data contains all combinations of distinct values of group-by columns.") {
+ val colNames = Seq("key31", "key32")
+ // Suppose table3 has 6 records: (1, 10), (1, 10), (1, 20), (2, 20), (2, 10), (2, 10)
+ val table3 = StatsTestPlan(
+ outputList = colNames.map(nameToAttr),
+ stats = Statistics(
+ sizeInBytes = 6 * (4 + 4),
+ rowCount = Some(6),
+ attributeStats = AttributeMap(colNames.map(nameToColInfo))))
+
+ checkAggStats(
+ child = table3,
+ colNames = colNames,
+ // Row count = product of ndv
+ expectedRowCount = nameToColInfo("key31")._2.distinctCount * nameToColInfo("key32")._2
+ .distinctCount)
+ }
+
+ private def checkAggStats(
+ child: LogicalPlan,
+ colNames: Seq[String],
+ expectedRowCount: BigInt): Unit = {
+
+ val columns = colNames.map(nameToAttr)
+ val testAgg = Aggregate(
+ groupingExpressions = columns,
+ aggregateExpressions = columns :+ Alias(Count(Literal(1)), "cnt")(),
+ child = child)
+
+ val expectedAttrStats = AttributeMap(colNames.map(nameToColInfo))
+ val expectedStats = Statistics(
+ sizeInBytes = expectedRowCount * getRowSize(testAgg.output, expectedAttrStats),
+ rowCount = Some(expectedRowCount),
+ attributeStats = expectedAttrStats)
+
+ assert(testAgg.statistics == expectedStats)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index fa5b290ecb..0d81aa3f68 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -18,12 +18,15 @@
package org.apache.spark.sql.catalyst.statsEstimation
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.types.IntegerType
class StatsEstimationTestBase extends SparkFunSuite {
+ def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()
+
/** Convert (column name, column stat) pairs to an AttributeMap based on plan output. */
def toAttributeMap(colStats: Seq[(String, ColumnStat)], plan: LogicalPlan)
: AttributeMap[ColumnStat] = {