aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorwangzhenhua <wangzhenhua@huawei.com>2017-04-01 22:19:08 +0800
committerWenchen Fan <wenchen@databricks.com>2017-04-01 22:19:08 +0800
commit2287f3d0b85730995bedc489a017de5700d6e1e4 (patch)
tree5146b42f6f31a66c5e98486c38a9155de8b580a7 /sql/catalyst
parent89d6822f722912d2b05571a95a539092091650b5 (diff)
downloadspark-2287f3d0b85730995bedc489a017de5700d6e1e4.tar.gz
spark-2287f3d0b85730995bedc489a017de5700d6e1e4.tar.bz2
spark-2287f3d0b85730995bedc489a017de5700d6e1e4.zip
[SPARK-20186][SQL] BroadcastHint should use child's stats
## What changes were proposed in this pull request? `BroadcastHint` should use child's statistics and set `isBroadcastable` to true. ## How was this patch tested? Added a new stats estimation test for `BroadcastHint`. Author: wangzhenhua <wangzhenhua@huawei.com> Closes #17504 from wzhfy/broadcastHintEstimation.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala21
2 files changed, 21 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 5cbf263d1c..19db42c808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -383,7 +383,7 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
// set isBroadcastable to true so the child will be broadcasted
override def computeStats(conf: CatalystConf): Statistics =
- super.computeStats(conf).copy(isBroadcastable = true)
+ child.stats(conf).copy(isBroadcastable = true)
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
index e5dc811c8b..0d92c1e355 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -35,6 +35,23 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
// row count * (overhead + column size)
size = Some(10 * (8 + 4)))
+ test("BroadcastHint estimation") {
+ val filter = Filter(Literal(true), plan)
+ val filterStatsCboOn = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false,
+ rowCount = Some(10), attributeStats = AttributeMap(Seq(attribute -> colStat)))
+ val filterStatsCboOff = Statistics(sizeInBytes = 10 * (8 +4), isBroadcastable = false)
+ checkStats(
+ filter,
+ expectedStatsCboOn = filterStatsCboOn,
+ expectedStatsCboOff = filterStatsCboOff)
+
+ val broadcastHint = BroadcastHint(filter)
+ checkStats(
+ broadcastHint,
+ expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true),
+ expectedStatsCboOff = filterStatsCboOff.copy(isBroadcastable = true))
+ }
+
test("limit estimation: limit < child's rowCount") {
val localLimit = LocalLimit(Literal(2), plan)
val globalLimit = GlobalLimit(Literal(2), plan)
@@ -97,9 +114,11 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase {
plan: LogicalPlan,
expectedStatsCboOn: Statistics,
expectedStatsCboOff: Statistics): Unit = {
- assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
// Invalidate statistics
plan.invalidateStatsCache()
+ assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
+
+ plan.invalidateStatsCache()
assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff)
}