diff options
author | Cheng Lian <lian@databricks.com> | 2014-11-14 15:09:36 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-11-14 15:09:36 -0800 |
commit | 0c7b66bd449093bb5d2dafaf91d54e63e601e320 (patch) | |
tree | 598c2985d9281a75fccbfd55e8ca06cd910955c7 /sql/catalyst | |
parent | 4b4b50c9e596673c1534df97effad50d107a8007 (diff) | |
download | spark-0c7b66bd449093bb5d2dafaf91d54e63e601e320.tar.gz spark-0c7b66bd449093bb5d2dafaf91d54e63e601e320.tar.bz2 spark-0c7b66bd449093bb5d2dafaf91d54e63e601e320.zip |
[SPARK-4322][SQL] Enables struct fields as sub expressions of grouping fields
While resolving struct fields, the resulted `GetField` expression is wrapped with an `Alias` to make it a named expression. Assume `a` is a struct instance with a field `b`, then `"a.b"` will be resolved as `Alias(GetField(a, "b"), "b")`. Thus, for this following SQL query:
```sql
SELECT a.b + 1 FROM t GROUP BY a.b + 1
```
the grouping expression is
```scala
Add(GetField(a, "b"), Literal(1, IntegerType))
```
while the aggregation expression is
```scala
Add(Alias(GetField(a, "b"), "b"), Literal(1, IntegerType))
```
This mismatch makes the above SQL query fail during the both analysis and execution phases. This PR fixes this issue by removing the alias when substituting aggregation expressions.
<!-- Reviewable:start -->
[<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/3248)
<!-- Reviewable:end -->
Author: Cheng Lian <lian@databricks.com>
Closes #3248 from liancheng/spark-4322 and squashes the following commits:
23a46ea [Cheng Lian] Code simplification
dd20a79 [Cheng Lian] Should only trim aliases around `GetField`s
7f46532 [Cheng Lian] Enables struct fields as sub expressions of grouping fields
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 27 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 15 |
2 files changed, 23 insertions, 19 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a448c79421..d3b4cf8e34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -60,7 +60,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool ResolveFunctions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - TrimAliases :: + TrimGroupingAliases :: typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, @@ -93,17 +93,10 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool /** * Removes no-op Alias expressions from the plan. */ - object TrimAliases extends Rule[LogicalPlan] { + object TrimGroupingAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Aggregate(groups, aggs, child) => - Aggregate( - groups.map { - _ transform { - case Alias(c, _) => c - } - }, - aggs, - child) + Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child) } } @@ -122,10 +115,15 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool case e => e.children.forall(isValidAggregateExpression) } - aggregateExprs.foreach { e => - if (!isValidAggregateExpression(e)) { - throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") - } + aggregateExprs.find { e => + !isValidAggregateExpression(e.transform { + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + case Alias(g: GetField, _) => g + }) + }.foreach { e => + throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") } aggregatePlan @@ -328,4 +326,3 @@ object EliminateAnalysisOperators extends Rule[LogicalPlan] { case Subquery(_, child) => child } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index f0fd9a8b9a..310d127506 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -151,8 +151,15 @@ object PartialAggregation { val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation - case e: Expression if namedGroupingExpressions.contains(e) => - namedGroupingExpressions(e).toAttribute + + case e: Expression => + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + namedGroupingExpressions + .get(e.transform { case Alias(g: GetField, _) => g }) + .map(_.toAttribute) + .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] val partialComputation = @@ -188,7 +195,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. - val (joinPredicates, otherPredicates) = + val (joinPredicates, otherPredicates) = condition.map(splitConjunctivePredicates).getOrElse(Nil).partition { case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) || (canEvaluate(l, right) && canEvaluate(r, left)) => true @@ -203,7 +210,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None |