about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author    wangzhenhua <wangzhenhua@huawei.com>    2017-01-19 22:18:47 -0800
committer gatorsmile <gatorsmile@gmail.com>       2017-01-19 22:18:47 -0800
commit    039ed9fe8a2fdcd99e0561af64cda8fe3406bc12 (patch)
tree      a3fe8043551f0ae9ee88cc1f1b9df8b9bcaf92fa /sql
parent    0bf605c2c67ca361cd4aa3a3b4492bef4aef76b9 (diff)
download  spark-039ed9fe8a2fdcd99e0561af64cda8fe3406bc12.tar.gz
spark-039ed9fe8a2fdcd99e0561af64cda8fe3406bc12.tar.bz2
spark-039ed9fe8a2fdcd99e0561af64cda8fe3406bc12.zip
[SPARK-19271][SQL] Change non-cbo estimation of aggregate
## What changes were proposed in this pull request?

Change the non-cbo estimation behavior of Aggregate:
- If `groupingExpressions` is empty, we know the row count (= 1) and the corresponding size;
- otherwise, estimation falls back to UnaryNode's `computeStats` method, which should not propagate `rowCount` and `attributeStats` in `Statistics` because they are not estimated in that method.

## How was this patch tested?

Added a test case.

Author: wangzhenhua <wangzhenhua@huawei.com>

Closes #16631 from wzhfy/aggNoCbo.
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala                                    3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala                          7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala            2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala                4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala              2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala                    24
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala                      7
7 files changed, 38 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 0587a59214..93550e1fc3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -344,7 +344,8 @@ abstract class UnaryNode extends LogicalPlan {
sizeInBytes = 1
}
- child.stats(conf).copy(sizeInBytes = sizeInBytes)
+ // Don't propagate rowCount and attributeStats, since they are not estimated here.
+ Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 3bd314315d..432097d621 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation}
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, ProjectEstimation}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -541,7 +541,10 @@ case class Aggregate(
override def computeStats(conf: CatalystConf): Statistics = {
def simpleEstimation: Statistics = {
if (groupingExpressions.isEmpty) {
- super.computeStats(conf).copy(sizeInBytes = 1)
+ Statistics(
+ sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1),
+ rowCount = Some(1),
+ isBroadcastable = child.stats(conf).isBroadcastable)
} else {
super.computeStats(conf)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
index 21e94fc941..ce74554c17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala
@@ -53,7 +53,7 @@ object AggregateEstimation {
val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output)
Some(Statistics(
- sizeInBytes = getOutputSize(agg.output, outputAttrStats, outputRows),
+ sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats),
rowCount = Some(outputRows),
attributeStats = outputAttrStats,
isBroadcastable = childStats.isBroadcastable))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index cf4452d0fd..e8b794212c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -37,8 +37,8 @@ object EstimationUtils {
def getOutputSize(
attributes: Seq[Attribute],
- attrStats: AttributeMap[ColumnStat],
- outputRowCount: BigInt): BigInt = {
+ outputRowCount: BigInt,
+ attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = {
// We assign a generic overhead for a Row object, the actual overhead is different for different
// Row format.
val sizePerRow = 8 + attributes.map { attr =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
index 50b869ab3a..e9084ad8b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala
@@ -36,7 +36,7 @@ object ProjectEstimation {
val outputAttrStats =
getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output)
Some(childStats.copy(
- sizeInBytes = getOutputSize(project.output, outputAttrStats, childStats.rowCount.get),
+ sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats),
attributeStats = outputAttrStats))
} else {
None
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
index 41a4bc359e..c0b9515ca7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala
@@ -90,6 +90,28 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
expectedOutputRowCount = 0)
}
+ test("non-cbo estimation") {
+ val attributes = Seq("key12").map(nameToAttr)
+ val child = StatsTestPlan(
+ outputList = attributes,
+ rowCount = 4,
+ // rowCount * (overhead + column size)
+ size = Some(4 * (8 + 4)),
+ attributeStats = AttributeMap(Seq("key12").map(nameToColInfo)))
+
+ val noGroupAgg = Aggregate(groupingExpressions = Nil,
+ aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child)
+ assert(noGroupAgg.stats(conf.copy(cboEnabled = false)) ==
+ // overhead + count result size
+ Statistics(sizeInBytes = 8 + 8, rowCount = Some(1)))
+
+ val hasGroupAgg = Aggregate(groupingExpressions = attributes,
+ aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child)
+ assert(hasGroupAgg.stats(conf.copy(cboEnabled = false)) ==
+ // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize
+ Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4)))
+ }
+
private def checkAggStats(
tableColumns: Seq[String],
tableRowCount: BigInt,
@@ -107,7 +129,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase {
val expectedAttrStats = AttributeMap(groupByColumns.map(nameToColInfo))
val expectedStats = Statistics(
- sizeInBytes = getOutputSize(testAgg.output, expectedAttrStats, expectedOutputRowCount),
+ sizeInBytes = getOutputSize(testAgg.output, expectedOutputRowCount, expectedAttrStats),
rowCount = Some(expectedOutputRowCount),
attributeStats = expectedAttrStats)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index e6adb6700a..a5fac4ba6f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -45,11 +45,12 @@ class StatsEstimationTestBase extends SparkFunSuite {
protected case class StatsTestPlan(
outputList: Seq[Attribute],
rowCount: BigInt,
- attributeStats: AttributeMap[ColumnStat]) extends LeafNode {
+ attributeStats: AttributeMap[ColumnStat],
+ size: Option[BigInt] = None) extends LeafNode {
override def output: Seq[Attribute] = outputList
override def computeStats(conf: CatalystConf): Statistics = Statistics(
- // sizeInBytes in stats of StatsTestPlan is useless in cbo estimation, we just use a fake value
- sizeInBytes = Int.MaxValue,
+ // If sizeInBytes is useless in testing, we just use a fake value
+ sizeInBytes = size.getOrElse(Int.MaxValue),
rowCount = Some(rowCount),
attributeStats = attributeStats)
}