aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala38
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala122
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala64
3 files changed, 145 insertions, 79 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ccebae3cc2..4d27ff2acd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -752,14 +752,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN
}
override def computeStats(conf: CatalystConf): Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
- val sizeInBytes = if (limit == 0) {
- // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
- // (product of children).
- 1
- } else {
- (limit: Long) * output.map(a => a.dataType.defaultSize).sum
- }
- child.stats(conf).copy(sizeInBytes = sizeInBytes)
+ val childStats = child.stats(conf)
+ val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit)
+ // Don't propagate column stats, because we don't know the distribution after a limit operation
+ Statistics(
+ sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats),
+ rowCount = Some(rowCount),
+ isBroadcastable = childStats.isBroadcastable)
}
}
@@ -773,14 +772,21 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo
}
override def computeStats(conf: CatalystConf): Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
- val sizeInBytes = if (limit == 0) {
+ val childStats = child.stats(conf)
+ if (limit == 0) {
// sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero
// (product of children).
- 1
+ Statistics(
+ sizeInBytes = 1,
+ rowCount = Some(0),
+ isBroadcastable = childStats.isBroadcastable)
} else {
- (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+ // The output row count of LocalLimit should be the sum of row counts from each partition.
+ // However, since the number of partitions is not available here, we just use statistics of
+ // the child. Because the distribution after a limit operation is unknown, we do not propagate
+ // the column stats.
+ childStats.copy(attributeStats = AttributeMap(Nil))
}
- child.stats(conf).copy(sizeInBytes = sizeInBytes)
}
}
@@ -816,12 +822,14 @@ case class Sample(
override def computeStats(conf: CatalystConf): Statistics = {
val ratio = upperBound - lowerBound
- // BigInt can't multiply with Double
- var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100
+ val childStats = child.stats(conf)
+ var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio)
if (sizeInBytes == 0) {
sizeInBytes = 1
}
- child.stats(conf).copy(sizeInBytes = sizeInBytes)
+ val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio))
+ // Don't propagate column stats, because we don't know the distribution after a sample operation
+ Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable)
}
override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
new file mode 100644
index 0000000000..e5dc811c8b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.IntegerType
+
+
+class BasicStatsEstimationSuite extends StatsEstimationTestBase {
+ val attribute = attr("key")
+ val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+ nullCount = 0, avgLen = 4, maxLen = 4)
+
+ val plan = StatsTestPlan(
+ outputList = Seq(attribute),
+ attributeStats = AttributeMap(Seq(attribute -> colStat)),
+ rowCount = 10,
+ // row count * (overhead + column size)
+ size = Some(10 * (8 + 4)))
+
+ test("limit estimation: limit < child's rowCount") {
+ val localLimit = LocalLimit(Literal(2), plan)
+ val globalLimit = GlobalLimit(Literal(2), plan)
+ // LocalLimit's stats is just its child's stats except column stats
+ checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+ checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2)))
+ }
+
+ test("limit estimation: limit > child's rowCount") {
+ val localLimit = LocalLimit(Literal(20), plan)
+ val globalLimit = GlobalLimit(Literal(20), plan)
+ checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+ // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats.
+ checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil)))
+ }
+
+ test("limit estimation: limit = 0") {
+ val localLimit = LocalLimit(Literal(0), plan)
+ val globalLimit = GlobalLimit(Literal(0), plan)
+ val stats = Statistics(sizeInBytes = 1, rowCount = Some(0))
+ checkStats(localLimit, stats)
+ checkStats(globalLimit, stats)
+ }
+
+ test("sample estimation") {
+ val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)()
+ checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5)))
+
+ // Child doesn't have rowCount in stats
+ val childStats = Statistics(sizeInBytes = 120)
+ val childPlan = DummyLogicalPlan(childStats, childStats)
+ val sample2 =
+ Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)()
+ checkStats(sample2, Statistics(sizeInBytes = 14))
+ }
+
+ test("estimate statistics when the conf changes") {
+ val expectedDefaultStats =
+ Statistics(
+ sizeInBytes = 40,
+ rowCount = Some(10),
+ attributeStats = AttributeMap(Seq(
+ AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))),
+ isBroadcastable = false)
+ val expectedCboStats =
+ Statistics(
+ sizeInBytes = 4,
+ rowCount = Some(1),
+ attributeStats = AttributeMap(Seq(
+ AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))),
+ isBroadcastable = false)
+
+ val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats)
+ checkStats(
+ plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats)
+ }
+
+ /** Check estimated stats when cbo is turned on/off. */
+ private def checkStats(
+ plan: LogicalPlan,
+ expectedStatsCboOn: Statistics,
+ expectedStatsCboOff: Statistics): Unit = {
+ assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
+ // Invalidate statistics
+ plan.invalidateStatsCache()
+ assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff)
+ }
+
+ /** Check estimated stats when it's the same whether cbo is turned on or off. */
+ private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit =
+ checkStats(plan, expectedStats, expectedStats)
+}
+
+/**
+ * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes
+ * a simple statistics or a cbo estimated statistics based on the conf.
+ */
+private case class DummyLogicalPlan(
+ defaultStats: Statistics,
+ cboStats: Statistics) extends LogicalPlan {
+ override def output: Seq[Attribute] = Nil
+ override def children: Seq[LogicalPlan] = Nil
+ override def computeStats(conf: CatalystConf): Statistics =
+ if (conf.cboEnabled) cboStats else defaultStats
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala
deleted file mode 100644
index 212d57a9bc..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.statsEstimation
-
-import org.apache.spark.sql.catalyst.CatalystConf
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics}
-import org.apache.spark.sql.types.IntegerType
-
-
-class StatsConfSuite extends StatsEstimationTestBase {
- test("estimate statistics when the conf changes") {
- val expectedDefaultStats =
- Statistics(
- sizeInBytes = 40,
- rowCount = Some(10),
- attributeStats = AttributeMap(Seq(
- AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))),
- isBroadcastable = false)
- val expectedCboStats =
- Statistics(
- sizeInBytes = 4,
- rowCount = Some(1),
- attributeStats = AttributeMap(Seq(
- AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))),
- isBroadcastable = false)
-
- val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats)
- // Return the statistics estimated by cbo
- assert(plan.stats(conf.copy(cboEnabled = true)) == expectedCboStats)
- // Invalidate statistics
- plan.invalidateStatsCache()
- // Return the simple statistics
- assert(plan.stats(conf.copy(cboEnabled = false)) == expectedDefaultStats)
- }
-}
-
-/**
- * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes
- * a simple statistics or a cbo estimated statistics based on the conf.
- */
-private case class DummyLogicalPlan(
- defaultStats: Statistics,
- cboStats: Statistics) extends LogicalPlan {
- override def output: Seq[Attribute] = Nil
- override def children: Seq[LogicalPlan] = Nil
- override def computeStats(conf: CatalystConf): Statistics =
- if (conf.cboEnabled) cboStats else defaultStats
-}