diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-04-19 21:53:19 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-04-19 21:53:19 -0700 |
commit | 856bc465d53ccfdfda75c82c85d7f318a5158088 (patch) | |
tree | 4d168f7380313ceaf3bc56249113dfae004150ee /sql/hive/src/main/scala/org | |
parent | 85d759ca3aebb7d60b963207dcada83c75502e52 (diff) | |
download | spark-856bc465d53ccfdfda75c82c85d7f318a5158088.tar.gz spark-856bc465d53ccfdfda75c82c85d7f318a5158088.tar.bz2 spark-856bc465d53ccfdfda75c82c85d7f318a5158088.zip |
[SPARK-14600] [SQL] Push predicates through Expand
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-14600
This PR makes `Expand.output` have different attributes from the grouping attributes produced by the underlying `Project`, as they have different meanings, so that we can safely push down a filter through `Expand`.
## How was this patch tested?
Existing tests.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #12496 from cloud-fan/expand.
Diffstat (limited to 'sql/hive/src/main/scala/org')
-rw-r--r-- | sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala | 14 |
1 file changed, 9 insertions, 5 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index e54358e657..2d44813f0e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -288,8 +288,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = { assert(a.child == e && e.child == p) - a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && - sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute])) + a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput( + e.output.drop(p.child.output.length), + a.groupingExpressions.map(_.asInstanceOf[Attribute])) } private def groupingSetToSQL( @@ -303,25 +304,28 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val numOriginalOutput = project.child.output.length // Assumption: Aggregate's groupingExpressions is composed of - // 1) the attributes of aliased group by expressions + // 1) the grouping attributes // 2) gid, which is always the last one val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) // Assumption: Project's projectList is composed of // 1) the original output (Project's child.output), // 2) the aliased group by expressions. + val expandedAttributes = project.output.drop(numOriginalOutput) val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child) val groupingSQL = groupByExprs.map(_.sql).mkString(", ") // a map from group by attributes to the original group by expressions. val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) + // a map from expanded attributes to the original group by expressions. 
+ val expandedAttrMap = AttributeMap(expandedAttributes.zip(groupByExprs)) val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project => // Assumption: expand.projections is composed of // 1) the original output (Project's child.output), - // 2) group by attributes(or null literal) + // 2) expanded attributes(or null literal) // 3) gid, which is always the last one in each project in Expand project.drop(numOriginalOutput).dropRight(1).collect { - case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) + case attr: Attribute if expandedAttrMap.contains(attr) => expandedAttrMap(attr) } } val groupingSetSQL = "GROUPING SETS(" + |