author    Wenchen Fan <wenchen@databricks.com>  2016-04-19 21:53:19 -0700
committer Davies Liu <davies.liu@gmail.com>     2016-04-19 21:53:19 -0700
commit    856bc465d53ccfdfda75c82c85d7f318a5158088
tree      4d168f7380313ceaf3bc56249113dfae004150ee /sql
parent    85d759ca3aebb7d60b963207dcada83c75502e52
[SPARK-14600] [SQL] Push predicates through Expand
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-14600

This PR makes `Expand.output` carry attributes that are distinct from the grouping attributes produced by the underlying `Project`, since the two have different meanings. This makes it safe to push filters down through `Expand`.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12496 from cloud-fan/expand.
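As a quick illustration of the effect (a hypothetical example, not part of the patch; the data and column names are invented), a predicate on a cube query can now be moved further down the optimized plan:

// Sketch against the Spark 2.x DataFrame API.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum

val spark = SparkSession.builder().master("local[*]").appName("spark-14600-demo").getOrCreate()
import spark.implicits._

val df = Seq((1, "x", 10), (2, "y", 20), (3, "y", 30)).toDF("a", "b", "c")

// Before this change, PushDownPredicate stopped at Expand; now the Filter can
// sink below the Aggregate, and predicates that only touch columns the Expand
// passes through may sink below the Expand as well.
df.cube($"a", $"b").agg(sum($"c")).where($"a" > 1).explain(true)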
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala              | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala            |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala   |  5
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala  | 15
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala                             | 14
5 files changed, 34 insertions(+), 14 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 236476900a..8595762988 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -296,10 +296,13 @@ class Analyzer(
val nonNullBitmask = x.bitmasks.reduce(_ & _)
- val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
+ val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0)
}
+ val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child)
+ val groupingAttrs = expand.output.drop(x.child.output.length)
+
val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
// collect all the found AggregateExpression, so we can check an expression is part of
// any AggregateExpression or not.
@@ -321,15 +324,12 @@ class Analyzer(
if (index == -1) {
e
} else {
- groupByAttributes(index)
+ groupingAttrs(index)
}
}.asInstanceOf[NamedExpression]
}
- Aggregate(
- groupByAttributes :+ gid,
- aggregations,
- Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
+ Aggregate(groupingAttrs, aggregations, expand)
case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
val groupingExprs = findGroupingExprs(child)
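For intuition about the hunk above (a toy sketch with invented names, not Catalyst code): Expand's output is laid out as the child's columns, then the expanded grouping attributes, then gid, so dropping the first child.output.length entries leaves exactly the attributes the rewritten Aggregate should group on.

// Toy model of the layout assumed by `expand.output.drop(x.child.output.length)`.
case class ToyExpand(childOutput: Seq[String], expandedAttrs: Seq[String], gid: String) {
  def output: Seq[String] = childOutput ++ expandedAttrs :+ gid
}

val e = ToyExpand(childOutput = Seq("a", "b", "c"), expandedAttrs = Seq("a'", "b'"), gid = "gid")
val groupingAttrs = e.output.drop(e.childOutput.length)
assert(groupingAttrs == Seq("a'", "b'", "gid")) // the grouping attributes plus gid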
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index ecc2d773e7..e6d554565d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1020,8 +1020,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
case filter @ Filter(_, f: Filter) => filter
// should not push predicates through sample, or will generate different results.
case filter @ Filter(_, s: Sample) => filter
- // TODO: push predicates through expand
- case filter @ Filter(_, e: Expand) => filter
case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
pushDownPredicate(filter, u.child) { predicate =>
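With the special case removed, Filter-over-Expand falls through to the generic UnaryNode branch below it. Roughly speaking (a simplified standalone sketch, not the rule's actual code, which works on Catalyst Expressions and AttributeSets), that branch keeps only the conjuncts whose references the child already produces and pushes those down:

// Toy predicate: the SQL text plus the set of columns it references.
final case class Pred(sql: String, references: Set[String])

def partitionPushable(conjuncts: Seq[Pred], childOutput: Set[String]): (Seq[Pred], Seq[Pred]) =
  conjuncts.partition(_.references.subsetOf(childOutput))

val (pushed, kept) = partitionPushable(
  Seq(Pred("a > 1", Set("a")), Pred("g > 2", Set("g"))),
  childOutput = Set("a", "b", "c"))
// pushed: Seq(Pred("a > 1", ...)) -- safe below Expand, `a` is passed through.
// kept:   Seq(Pred("g > 2", ...)) -- `g` is created by Expand itself, stays above.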
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d4fc9e4da9..a445ce6947 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -516,7 +516,10 @@ private[sql] object Expand {
// groupingId is the last output, here we use the bit mask as the concrete value for it.
} :+ Literal.create(bitmask, IntegerType)
}
- val output = child.output ++ groupByAttrs :+ gid
+
+ // `groupByAttrs` have a different meaning in `Expand.output`: there they may be the
+ // original grouping expressions or null, so we create new instances of them here.
+ val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid
Expand(projections, output, Project(child.output ++ groupByAliases, child))
}
}
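The crux of this hunk is that newInstance gives each grouping attribute a fresh expression id, so the attributes Expand outputs are no longer identical (by identity) to the ones the Project underneath produces. A minimal sketch of why that matters, using toy types rather than Catalyst's AttributeReference:

import java.util.concurrent.atomic.AtomicLong

object ExprIds { private val counter = new AtomicLong(0L); def next(): Long = counter.incrementAndGet() }

final case class Attr(name: String, exprId: Long = ExprIds.next()) {
  def newInstance(): Attr = copy(exprId = ExprIds.next())
}

val fromProject = Attr("a")                 // produced by the Project under Expand
val fromExpand  = fromProject.newInstance() // what Expand now outputs instead
assert(fromProject.exprId != fromExpand.exprId)
// Same name, different identity: a predicate over fromExpand no longer resolves
// against fromProject, so the optimizer cannot wrongly push it below Expand,
// where the column may have been replaced by null for some grouping sets.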
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index df7529d83f..9174b4e649 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -743,4 +743,19 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("expand") {
+ val agg = testRelation
+ .groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c))
+ .analyze
+ .asInstanceOf[Aggregate]
+
+ val a = agg.output(0)
+ val b = agg.output(1)
+
+ val query = agg.where(a > 1 && b > 2)
+ val optimized = Optimize.execute(query)
+ val correctedAnswer = agg.copy(child = agg.child.where(a > 1 && b > 2)).analyze
+ comparePlans(optimized, correctedAnswer)
+ }
}
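The new test pins down the expected shape: the Filter ends up between the Aggregate and the Expand, because the predicates reference grouping attributes that Expand itself produces, so they stop above it rather than sinking into the Project.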
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index e54358e657..2d44813f0e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -288,8 +288,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
assert(a.child == e && e.child == p)
- a.groupingExpressions.forall(_.isInstanceOf[Attribute]) &&
- sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
+ a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput(
+ e.output.drop(p.child.output.length),
+ a.groupingExpressions.map(_.asInstanceOf[Attribute]))
}
private def groupingSetToSQL(
@@ -303,25 +304,28 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
val numOriginalOutput = project.child.output.length
// Assumption: Aggregate's groupingExpressions is composed of
- // 1) the attributes of aliased group by expressions
+ // 1) the grouping attributes
// 2) gid, which is always the last one
val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
// Assumption: Project's projectList is composed of
// 1) the original output (Project's child.output),
// 2) the aliased group by expressions.
+ val expandedAttributes = project.output.drop(numOriginalOutput)
val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
// a map from group by attributes to the original group by expressions.
val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
+ // a map from expanded attributes to the original group by expressions.
+ val expandedAttrMap = AttributeMap(expandedAttributes.zip(groupByExprs))
val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
// Assumption: expand.projections is composed of
// 1) the original output (Project's child.output),
- // 2) group by attributes(or null literal)
+ // 2) expanded attributes (or null literal)
// 3) gid, which is always the last one in each project in Expand
project.drop(numOriginalOutput).dropRight(1).collect {
- case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr)
+ case attr: Attribute if expandedAttrMap.contains(attr) => expandedAttrMap(attr)
}
}
val groupingSetSQL = "GROUPING SETS(" +
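Correspondingly, when SQLBuilder reconstructs a GROUPING SETS clause it must now translate the fresh attributes found in Expand's projections, not the Aggregate's grouping attributes, back to the original group-by expressions. A toy sketch of that lookup (invented names, plain strings instead of Catalyst attributes):

// expandedAttrMap: expanded attribute -> SQL of the original group-by expression.
val expandedAttrs = Seq("a#10", "b#11") // fresh instances appearing in Expand.output
val groupByExprs  = Seq("x + 1", "y")   // children of the aliases in the Project
val expandedAttrMap = expandedAttrs.zip(groupByExprs).toMap

// Recovering one grouping set from an Expand projection (original output and gid
// already dropped); null literals simply do not match and fall out of the collect.
val projection = Seq("a#10", "null")
val groupingSet = projection.collect {
  case attr if expandedAttrMap.contains(attr) => expandedAttrMap(attr)
}
assert(groupingSet == Seq("x + 1"))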