aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-09-02 11:32:27 -0700
committerMichael Armbrust <michael@databricks.com>2015-09-02 11:32:27 -0700
commitfc48307797912dc1d53893dce741ddda8630957b (patch)
tree1478760123a5e4919bf4ab0f6333693dbd690a54 /sql
parent56c4c172e99a5e14f4bc3308e7ff36d94113b63e (diff)
downloadspark-fc48307797912dc1d53893dce741ddda8630957b.tar.gz
spark-fc48307797912dc1d53893dce741ddda8630957b.tar.bz2
spark-fc48307797912dc1d53893dce741ddda8630957b.zip
[SPARK-10389] [SQL] support order by non-attribute grouping expression on Aggregate
For example, we can write `SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1` in PostgreSQL, and we should support this in Spark SQL. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #8548 from cloud-fan/support-order-by-non-attribute.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala72
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala19
2 files changed, 52 insertions, 39 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1a5de15c61..591747b45c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -560,43 +560,47 @@ class Analyzer(
filter
}
- case sort @ Sort(sortOrder, global,
- aggregate @ Aggregate(grouping, originalAggExprs, child))
+ case sort @ Sort(sortOrder, global, aggregate: Aggregate)
if aggregate.resolved && !sort.resolved =>
// Try resolving the ordering as though it is in the aggregate clause.
try {
- val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
- val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
- val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
- def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
-
- // Expressions that have an aggregate can be pushed down.
- val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)
-
- // Attribute references, that are missing from the order but are present in the grouping
- // expressions can also be pushed down.
- val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
- val missingAttributes = requiredAttributes -- aggregate.outputSet
- val validPushdownAttributes =
- missingAttributes.filter(a => grouping.exists(a.semanticEquals))
-
- // If resolution was successful and we see the ordering either has an aggregate in it or
- // it is missing something that is projected away by the aggregate, add the ordering
- // the original aggregate operator.
- if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
- val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map {
- case (order, evaluated) => order.copy(child = evaluated.toAttribute)
- }
- val aggExprsWithOrdering: Seq[NamedExpression] =
- resolvedAggregateOrdering ++ originalAggExprs
-
- Project(aggregate.output,
- Sort(evaluatedOrderings, global,
- aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
- } else {
- sort
+ val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")())
+ val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
+ val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
+ val resolvedAliasedOrdering: Seq[Alias] =
+ resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
+
+ // If we pass the analysis check, then the ordering expressions should only reference to
+ // aggregate expressions or grouping expressions, and it's safe to push them down to
+ // Aggregate.
+ checkAnalysis(resolvedAggregate)
+
+ val originalAggExprs = aggregate.aggregateExpressions.map(
+ CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+
+ // If the ordering expression is same with original aggregate expression, we don't need
+ // to push down this ordering expression and can reference the original aggregate
+ // expression instead.
+ val needsPushDown = ArrayBuffer.empty[NamedExpression]
+ val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
+ case (evaluated, order) =>
+ val index = originalAggExprs.indexWhere {
+ case Alias(child, _) => child semanticEquals evaluated.child
+ case other => other semanticEquals evaluated.child
+ }
+
+ if (index == -1) {
+ needsPushDown += evaluated
+ order.copy(child = evaluated.toAttribute)
+ } else {
+ order.copy(child = originalAggExprs(index).toAttribute)
+ }
}
+
+ Project(aggregate.output,
+ Sort(evaluatedOrderings, global,
+ aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
@@ -605,9 +609,7 @@ class Analyzer(
}
protected def containsAggregate(condition: Expression): Boolean = {
- condition
- .collect { case ae: AggregateExpression => ae }
- .nonEmpty
+ condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 28201073a2..0ef25fe0fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1722,9 +1722,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("SPARK-10130 type coercion for IF should have children resolved first") {
- val df = Seq((1, 1), (-1, 1)).toDF("key", "value")
- df.registerTempTable("src")
- checkAnswer(
- sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
+ withTempTable("src") {
+ Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
+ checkAnswer(
+ sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
+ }
+ }
+
+ test("SPARK-10389: order by non-attribute grouping expression on Aggregate") {
+ withTempTable("src") {
+ Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
+ checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"),
+ Seq(Row(1), Row(1)))
+ checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"),
+ Seq(Row(1), Row(1)))
+ }
}
}