| author | wangzhenhua <wangzhenhua@huawei.com> | 2017-01-19 22:18:47 -0800 |
|---|---|---|
| committer | gatorsmile <gatorsmile@gmail.com> | 2017-01-19 22:18:47 -0800 |
| commit | 039ed9fe8a2fdcd99e0561af64cda8fe3406bc12 (patch) | |
| tree | a3fe8043551f0ae9ee88cc1f1b9df8b9bcaf92fa /sql | |
| parent | 0bf605c2c67ca361cd4aa3a3b4492bef4aef76b9 (diff) | |
[SPARK-19271][SQL] Change non-cbo estimation of aggregate
## What changes were proposed in this pull request?
Change the non-CBO estimation behavior of Aggregate (see the sketch after this list):
- If `groupingExpressions` is empty, the aggregate outputs exactly one row, so we know the row count (= 1) and can compute the corresponding output size;
- otherwise, estimation falls back to `UnaryNode`'s `computeStats` method, which should not propagate `rowCount` and `attributeStats` in `Statistics`, because that method does not estimate them.
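For illustration, here is a minimal, self-contained sketch of the two paths — not the Spark source itself. `SimpleStats`, `globalAggStats`, and `fallbackStats` are hypothetical names; only the 8-byte row overhead and the size formulas mirror `EstimationUtils.getOutputSize` and `UnaryNode.computeStats` as changed in this patch:

```scala
// Hypothetical sketch of the two non-CBO aggregate estimation paths.
object NonCboAggregateSketch {
  final case class SimpleStats(sizeInBytes: BigInt, rowCount: Option[BigInt] = None)

  // Path 1: no grouping expressions => exactly one output row, so both the
  // row count and the output size are known.
  def globalAggStats(outputColumnSizes: Seq[Int]): SimpleStats = {
    val sizePerRow = 8 + outputColumnSizes.sum // generic row overhead + column sizes
    SimpleStats(sizeInBytes = sizePerRow, rowCount = Some(1))
  }

  // Path 2: grouping expressions present => size-only fallback; rowCount is
  // deliberately left unset because this path does not estimate it.
  def fallbackStats(childSize: BigInt, childRowSize: Int, outputRowSize: Int): SimpleStats = {
    val size = childSize * outputRowSize / childRowSize
    SimpleStats(sizeInBytes = if (size == 0) 1 else size)
  }

  def main(args: Array[String]): Unit = {
    // e.g. SELECT count(*) FROM t: one row with a single 8-byte long column.
    println(globalAggStats(Seq(8)))    // SimpleStats(16,Some(1))
    // e.g. the grouped case in the new test: 48-byte child, 12-byte rows in, 20-byte rows out.
    println(fallbackStats(48, 12, 20)) // SimpleStats(80,None)
  }
}
```

The two printed values match the expectations in the test case added below (8 + 8 = 16 and 48 * 20 / 12 = 80).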
## How was this patch tested?
Added a test case ("non-cbo estimation") in `AggregateEstimationSuite`.
Author: wangzhenhua <wangzhenhua@huawei.com>
Closes #16631 from wzhfy/aggNoCbo.
Diffstat (limited to 'sql')
7 files changed, 38 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 0587a59214..93550e1fc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -344,7 +344,8 @@ abstract class UnaryNode extends LogicalPlan {
       sizeInBytes = 1
     }
 
-    child.stats(conf).copy(sizeInBytes = sizeInBytes)
+    // Don't propagate rowCount and attributeStats, since they are not estimated here.
+    Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 3bd314315d..432097d621 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, ProjectEstimation}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -541,7 +541,10 @@ case class Aggregate(
   override def computeStats(conf: CatalystConf): Statistics = {
     def simpleEstimation: Statistics = {
       if (groupingExpressions.isEmpty) {
-        super.computeStats(conf).copy(sizeInBytes = 1)
+        Statistics(
+          sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
+          rowCount = Some(1),
+          isBroadcastable = child.stats(conf).isBroadcastable)
       } else {
         super.computeStats(conf)
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index 21e94fc941..ce74554c17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -53,7 +53,7 @@ object AggregateEstimation {
 
       val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
       Some(Statistics(
-        sizeInBytes = getOutputSize(agg.output, outputAttrStats, outputRows),
+        sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
         rowCount = Some(outputRows),
         attributeStats = outputAttrStats,
         isBroadcastable = childStats.isBroadcastable))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index cf4452d0fd..e8b794212c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -37,8 +37,8 @@ object EstimationUtils {
 
   def getOutputSize(
       attributes: Seq[Attribute],
-      attrStats: AttributeMap[ColumnStat],
-      outputRowCount: BigInt): BigInt = {
+      outputRowCount: BigInt,
+      attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
     // We assign a generic overhead for a Row object, the actual overhead is different for different
     // Row format.
     val sizePerRow = 8 + attributes.map { attr =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
index 50b869ab3a..e9084ad8b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
@@ -36,7 +36,7 @@ object ProjectEstimation {
       val outputAttrStats =
         getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
       Some(childStats.copy(
-        sizeInBytes = getOutputSize(project.output, outputAttrStats, childStats.rowCount.get),
+        sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats),
         attributeStats = outputAttrStats))
     } else {
       None
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
index 41a4bc359e..c0b9515ca7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
@@ -90,6 +90,28 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
       expectedOutputRowCount = 0)
   }
 
+  test("non-cbo estimation") {
+    val attributes = Seq("key12").map(nameToAttr)
+    val child = StatsTestPlan(
+      outputList = attributes,
+      rowCount = 4,
+      // rowCount * (overhead + column size)
+      size = Some(4 * (8 + 4)),
+      attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
+
+    val noGroupAgg = Aggregate(groupingExpressions = Nil,
+      aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
+    assert(noGroupAgg.stats(conf.copy(cboEnabled = false)) ==
+      // overhead + count result size
+      Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
+
+    val hasGroupAgg = Aggregate(groupingExpressions = attributes,
+      aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
+    assert(hasGroupAgg.stats(conf.copy(cboEnabled = false)) ==
+      // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
+      Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
+  }
+
   private def checkAggStats(
       tableColumns: Seq[String],
       tableRowCount: BigInt,
@@ -107,7 +129,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
 
     val expectedAttrStats = AttributeMap(groupByColumns.map(nameToColInfo))
     val expectedStats = Statistics(
-      sizeInBytes = getOutputSize(testAgg.output, expectedAttrStats, expectedOutputRowCount),
+      sizeInBytes = getOutputSize(testAgg.output, expectedOutputRowCount, expectedAttrStats),
       rowCount = Some(expectedOutputRowCount),
       attributeStats = expectedAttrStats)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index e6adb6700a..a5fac4ba6f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -45,11 +45,12 @@ class StatsEstimationTestBase extends SparkFunSuite {
   protected case class StatsTestPlan(
       outputList: Seq[Attribute],
       rowCount: BigInt,
-      attributeStats: AttributeMap[ColumnStat]) extends LeafNode {
+      attributeStats: AttributeMap[ColumnStat],
+      size: Option[BigInt] = None) extends LeafNode {
     override def output: Seq[Attribute] = outputList
     override def computeStats(conf: CatalystConf): Statistics = Statistics(
-      // sizeInBytes in stats of StatsTestPlan is useless in cbo estimation, we just use a fake value
-      sizeInBytes = Int.MaxValue,
+      // If sizeInBytes is useless in testing, we just use a fake value
+      sizeInBytes = size.getOrElse(Int.MaxValue),
       rowCount = Some(rowCount),
       attributeStats = attributeStats)
   }
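As a quick usage sketch of the reordered `getOutputSize` signature introduced above (assuming spark-catalyst of this era is on the classpath, e.g. in a Scala REPL; the attribute names are illustrative): the column-statistics argument can now be omitted and defaults to an empty `AttributeMap`, in which case each column contributes its type's `defaultSize`.

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
import org.apache.spark.sql.types.{IntegerType, LongType}

// Two output columns: a 4-byte int key and an 8-byte long count.
val output: Seq[Attribute] = Seq(
  AttributeReference("key12", IntegerType)(),
  AttributeReference("cnt", LongType)())

// Without column stats, each row costs the generic 8-byte overhead plus the
// columns' defaultSize: (8 + 4 + 8) * 4 rows = 80 bytes.
val size = EstimationUtils.getOutputSize(output, outputRowCount = 4)
```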