aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2014-11-14 15:09:36 -0800
committerMichael Armbrust <michael@databricks.com>2014-11-14 15:09:55 -0800
commit1cac30083b97c98c3663e2d2cd057124f033eb34 (patch)
treee836fe9ce22b90d20288f9f914a6c33d3d9a4d62 /sql
parent680bc06195ecdc6ff2390c55adeb637649f2c8f3 (diff)
downloadspark-1cac30083b97c98c3663e2d2cd057124f033eb34.tar.gz
spark-1cac30083b97c98c3663e2d2cd057124f033eb34.tar.bz2
spark-1cac30083b97c98c3663e2d2cd057124f033eb34.zip
[SPARK-4322][SQL] Enables struct fields as sub expressions of grouping fields
While resolving struct fields, the resulted `GetField` expression is wrapped with an `Alias` to make it a named expression. Assume `a` is a struct instance with a field `b`, then `"a.b"` will be resolved as `Alias(GetField(a, "b"), "b")`. Thus, for this following SQL query: ```sql SELECT a.b + 1 FROM t GROUP BY a.b + 1 ``` the grouping expression is ```scala Add(GetField(a, "b"), Literal(1, IntegerType)) ``` while the aggregation expression is ```scala Add(Alias(GetField(a, "b"), "b"), Literal(1, IntegerType)) ``` This mismatch makes the above SQL query fail during the both analysis and execution phases. This PR fixes this issue by removing the alias when substituting aggregation expressions. <!-- Reviewable:start --> [<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/3248) <!-- Reviewable:end --> Author: Cheng Lian <lian@databricks.com> Closes #3248 from liancheng/spark-4322 and squashes the following commits: 23a46ea [Cheng Lian] Code simplification dd20a79 [Cheng Lian] Should only trim aliases around `GetField`s 7f46532 [Cheng Lian] Enables struct fields as sub expressions of grouping fields (cherry picked from commit 0c7b66bd449093bb5d2dafaf91d54e63e601e320) Signed-off-by: Michael Armbrust <michael@databricks.com>
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala27
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala15
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala12
3 files changed, 34 insertions, 20 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a448c79421..d3b4cf8e34 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -60,7 +60,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
ResolveFunctions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
- TrimAliases ::
+ TrimGroupingAliases ::
typeCoercionRules ++
extendedRules : _*),
Batch("Check Analysis", Once,
@@ -93,17 +93,10 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
/**
* Removes no-op Alias expressions from the plan.
*/
- object TrimAliases extends Rule[LogicalPlan] {
+ object TrimGroupingAliases extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Aggregate(groups, aggs, child) =>
- Aggregate(
- groups.map {
- _ transform {
- case Alias(c, _) => c
- }
- },
- aggs,
- child)
+ Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child)
}
}
@@ -122,10 +115,15 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
case e => e.children.forall(isValidAggregateExpression)
}
- aggregateExprs.foreach { e =>
- if (!isValidAggregateExpression(e)) {
- throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
- }
+ aggregateExprs.find { e =>
+ !isValidAggregateExpression(e.transform {
+ // Should trim aliases around `GetField`s. These aliases are introduced while
+ // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
+ // (Should we just turn `GetField` into a `NamedExpression`?)
+ case Alias(g: GetField, _) => g
+ })
+ }.foreach { e =>
+ throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
}
aggregatePlan
@@ -328,4 +326,3 @@ object EliminateAnalysisOperators extends Rule[LogicalPlan] {
case Subquery(_, child) => child
}
}
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index f0fd9a8b9a..310d127506 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -151,8 +151,15 @@ object PartialAggregation {
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
partialEvaluations(new TreeNodeRef(e)).finalEvaluation
- case e: Expression if namedGroupingExpressions.contains(e) =>
- namedGroupingExpressions(e).toAttribute
+
+ case e: Expression =>
+ // Should trim aliases around `GetField`s. These aliases are introduced while
+ // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
+ // (Should we just turn `GetField` into a `NamedExpression`?)
+ namedGroupingExpressions
+ .get(e.transform { case Alias(g: GetField, _) => g })
+ .map(_.toAttribute)
+ .getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]
val partialComputation =
@@ -188,7 +195,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
logDebug(s"Considering join on: $condition")
// Find equi-join predicates that can be evaluated before the join, and thus can be used
// as join keys.
- val (joinPredicates, otherPredicates) =
+ val (joinPredicates, otherPredicates) =
condition.map(splitConjunctivePredicates).getOrElse(Nil).partition {
case EqualTo(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
(canEvaluate(l, right) && canEvaluate(r, left)) => true
@@ -203,7 +210,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
val rightKeys = joinKeys.map(_._2)
if (joinKeys.nonEmpty) {
- logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}")
+ logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right))
} else {
None
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5dd777f1fb..ce5672c086 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -551,7 +551,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil)
}
- test("INTERSECT") {
+ test("INTERSECT") {
checkAnswer(
sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"),
(1, "a") ::
@@ -949,4 +949,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"),
(1 to 99).map(i => Seq(i)))
}
+
+ test("SPARK-4322 Grouping field with struct field as sub expression") {
+ jsonRDD(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)).registerTempTable("data")
+ checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), 1)
+ dropTempTable("data")
+
+ jsonRDD(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data")
+ checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), 2)
+ dropTempTable("data")
+ }
}