diff options
author | Sean Zhong <seanzhong@databricks.com> | 2016-08-16 15:51:30 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-08-16 15:51:30 +0800 |
commit | 7b65030e7a0af3a0bd09370fb069d659b36ff7f0 (patch) | |
tree | c820a00facee5871059e8412c23125823944b838 /sql | |
parent | 7de30d6e9e5d3020d2ba8c2ce08893d9cd822b56 (diff) | |
download | spark-7b65030e7a0af3a0bd09370fb069d659b36ff7f0.tar.gz spark-7b65030e7a0af3a0bd09370fb069d659b36ff7f0.tar.bz2 spark-7b65030e7a0af3a0bd09370fb069d659b36ff7f0.zip |
[SPARK-17034][SQL] adds expression UnresolvedOrdinal to represent the ordinals in GROUP BY or ORDER BY
## What changes were proposed in this pull request?
This PR adds expression `UnresolvedOrdinal` to represent the ordinal in GROUP BY or ORDER BY, and fixes the rules when resolving ordinals.
Ordinals in GROUP BY or ORDER BY like `1` in `order by 1` or `group by 1` should be considered as unresolved before analysis. But in current code, it uses `Literal` expression to store the ordinal. This is inappropriate as `Literal` itself is a resolved expression, it gives the user a wrong message that the ordinals has already been resolved.
### Before this change
Ordinal is stored as `Literal` expression
```
scala> sc.setLogLevel("TRACE")
scala> sql("select a from t group by 1 order by 1")
...
'Sort [1 ASC], true
+- 'Aggregate [1], ['a]
+- 'UnresolvedRelation `t
```
For query:
```
scala> Seq(1).toDF("a").createOrReplaceTempView("t")
scala> sql("select count(a), a from t group by 2 having a > 0").show
```
During analysis, the intermediate plan before applying rule `ResolveAggregateFunctions` is:
```
'Filter ('a > 0)
+- Aggregate [2], [count(1) AS count(1)#83L, a#81]
+- LocalRelation [value#7 AS a#9]
```
Before this PR, rule `ResolveAggregateFunctions` believes all expressions of `Aggregate` have already been resolved, and tries to resolve the expressions in `Filter` directly. But this is wrong, as ordinal `2` in Aggregate is not really resolved!
### After this change
Ordinals are stored as `UnresolvedOrdinal`.
```
scala> sc.setLogLevel("TRACE")
scala> sql("select a from t group by 1 order by 1")
...
'Sort [unresolvedordinal(1) ASC], true
+- 'Aggregate [unresolvedordinal(1)], ['a]
+- 'UnresolvedRelation `t`
```
## How was this patch tested?
Unit tests.
Author: Sean Zhong <seanzhong@databricks.com>
Closes #14616 from clockfly/spark-16955.
Diffstat (limited to 'sql')
7 files changed, 175 insertions, 19 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a2e276e8a2..a2a022c247 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -22,17 +22,16 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogRelation, InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.encoders.OuterScopes import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.NewInstance import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification -import org.apache.spark.sql.catalyst.planning.IntegerIndex import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.trees.TreeNodeRef +import org.apache.spark.sql.catalyst.trees.{TreeNodeRef} import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.types._ @@ -84,7 +83,8 @@ class Analyzer( Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, - EliminateUnions), + EliminateUnions, + new UnresolvedOrdinalSubstitution(conf)), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: @@ -545,7 +545,7 @@ class Analyzer( p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) // If the aggregate function argument contains Stars, expand it. case a: Aggregate if containsStar(a.aggregateExpressions) => - if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) { + if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) { failAnalysis( "Star (*) is not allowed in select list when GROUP BY ordinal position is used") } else { @@ -716,9 +716,9 @@ class Analyzer( // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. case s @ Sort(orders, global, child) - if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) => + if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { - case s @ SortOrder(IntegerIndex(index), direction) => + case s @ SortOrder(UnresolvedOrdinal(index), direction) => if (index > 0 && index <= child.output.size) { SortOrder(child.output(index - 1), direction) } else { @@ -732,11 +732,10 @@ class Analyzer( // Replace the index with the corresponding expression in aggregateExpressions. The index is // a 1-base position of aggregateExpressions, which is output columns (select expression) - case a @ Aggregate(groups, aggs, child) - if conf.groupByOrdinal && aggs.forall(_.resolved) && - groups.exists(IntegerIndex.unapply(_).nonEmpty) => + case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && + groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => val newGroups = groups.map { - case ordinal @ IntegerIndex(index) if index > 0 && index <= aggs.size => + case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => aggs(index - 1) match { case e if ResolveAggregateFunctions.containsAggregate(e) => ordinal.failAnalysis( @@ -744,7 +743,7 @@ class Analyzer( "aggregate functions are not allowed in GROUP BY") case o => o } - case ordinal @ IntegerIndex(index) => + case ordinal @ UnresolvedOrdinal(index) => ordinal.failAnalysis( s"GROUP BY position $index is not in select list " + s"(valid range is [1, ${aggs.size}])") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala new file mode 100644 index 0000000000..e21cd08af8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.planning.IntegerIndex +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin + +/** + * Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression. + */ +class UnresolvedOrdinalSubstitution(conf: CatalystConf) extends Rule[LogicalPlan] { + private def isIntegerLiteral(sorter: Expression) = IntegerIndex.unapply(sorter).nonEmpty + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case s @ Sort(orders, global, child) if conf.orderByOrdinal && + orders.exists(o => isIntegerLiteral(o.child)) => + val newOrders = orders.map { + case order @ SortOrder(ordinal @ IntegerIndex(index: Int), _) => + val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) + withOrigin(order.origin)(order.copy(child = newOrdinal)) + case other => other + } + withOrigin(s.origin)(s.copy(order = newOrders)) + case a @ Aggregate(groups, aggs, child) if conf.groupByOrdinal && + groups.exists(isIntegerLiteral(_)) => + val newGroups = groups.map { + case ordinal @ IntegerIndex(index) => + withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) + case other => other + } + withOrigin(a.origin)(a.copy(groupingExpressions = newGroups)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 609089a302..42e7aae0b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -370,3 +370,21 @@ case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpr override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false } + +/** + * Represents unresolved ordinal used in order by or group by. + * + * For example: + * {{{ + * select a from table order by 1 + * select a from table group by 1 + * }}} + * @param ordinal ordinal starts from 1, instead of 0 + */ +case class UnresolvedOrdinal(ordinal: Int) + extends LeafExpression with Unevaluable with NonSQLExpression { + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 102c78bd72..22e1c9be05 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala new file mode 100644 index 0000000000..23995e96e1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.analysis.TestRelations.testRelation2 +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.SimpleCatalystConf + +class UnresolvedOrdinalSubstitutionSuite extends AnalysisTest { + + test("test rule UnresolvedOrdinalSubstitution, replaces ordinal in order by or group by") { + val a = testRelation2.output(0) + val b = testRelation2.output(1) + val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true) + + // Expression OrderByOrdinal is unresolved. + assert(!UnresolvedOrdinal(0).resolved) + + // Tests order by ordinal, apply single rule. + val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc) + comparePlans( + new UnresolvedOrdinalSubstitution(conf).apply(plan), + testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc)) + + // Tests order by ordinal, do full analysis + checkAnalysis(plan, testRelation2.orderBy(a.asc, b.asc)) + + // order by ordinal can be turned off by config + comparePlans( + new UnresolvedOrdinalSubstitution(conf.copy(orderByOrdinal = false)).apply(plan), + testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) + + + // Tests group by ordinal, apply single rule. + val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b) + comparePlans( + new UnresolvedOrdinalSubstitution(conf).apply(plan2), + testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b)) + + // Tests group by ordinal, do full analysis + checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b)) + + // group by ordinal can be turned off by config + comparePlans( + new UnresolvedOrdinalSubstitution(conf.copy(groupByOrdinal = false)).apply(plan2), + testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) + } +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 36b469c617..9c8d851e36 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -43,6 +43,12 @@ select a, rand(0), sum(b) from data group by a, 2; -- negative case: star select * from data group by a, b, 1; +-- group by ordinal followed by order by +select a, count(a) from (select 1 as a) tmp group by 1 order by 1; + +-- group by ordinal followed by having +select count(a), a from (select 1 as a) tmp group by 2 having a > 0; + -- turn of group by ordinal set spark.sql.groupByOrdinal=false; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index 2f10b7ebc6..9c3a145f3a 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 17 +-- Number of queries: 19 -- !query 0 @@ -153,16 +153,32 @@ Star (*) is not allowed in select list when GROUP BY ordinal position is used; -- !query 15 -set spark.sql.groupByOrdinal=false +select a, count(a) from (select 1 as a) tmp group by 1 order by 1 -- !query 15 schema -struct<key:string,value:string> +struct<a:int,count(a):bigint> -- !query 15 output -spark.sql.groupByOrdinal +1 1 -- !query 16 -select sum(b) from data group by -1 +select count(a), a from (select 1 as a) tmp group by 2 having a > 0 -- !query 16 schema -struct<sum(b):bigint> +struct<count(a):bigint,a:int> -- !query 16 output +1 1 + + +-- !query 17 +set spark.sql.groupByOrdinal=false +-- !query 17 schema +struct<key:string,value:string> +-- !query 17 output +spark.sql.groupByOrdinal + + +-- !query 18 +select sum(b) from data group by -1 +-- !query 18 schema +struct<sum(b):bigint> +-- !query 18 output 9 |