aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/test
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2016-03-25 12:55:58 +0800
committerWenchen Fan <wenchen@databricks.com>2016-03-25 12:55:58 +0800
commit05f652d6c2bbd764a1dd5a45301811e14519486f (patch)
tree46640388ac61767de19808901ba0b59de456bb92 /sql/core/src/test
parent0874ff3aade705a97f174b642c5db01711d214b3 (diff)
downloadspark-05f652d6c2bbd764a1dd5a45301811e14519486f.tar.gz
spark-05f652d6c2bbd764a1dd5a45301811e14519486f.tar.bz2
spark-05f652d6c2bbd764a1dd5a45301811e14519486f.zip
[SPARK-13957][SQL] Support Group By Ordinal in SQL
#### What changes were proposed in this pull request? This PR is to support group by position in SQL. For example, when users input the following query ```SQL select c1 as a, c2, c3, sum(*) from tbl group by 1, 3, c4 ``` The ordinals are recognized as the positions in the select list. Thus, `Analyzer` converts it to ```SQL select c1, c2, c3, sum(*) from tbl group by c1, c3, c4 ``` This is controlled by the config option `spark.sql.groupByOrdinal`. - When true, the ordinal numbers in group by clauses are treated as the position in the select list. - When false, the ordinal numbers are ignored. - Only convert integer literals (not foldable expressions). If found foldable expressions, ignore them. - When the positions specified in the group by clauses correspond to the aggregate functions in select list, output an exception message. - star is not allowed to use in the select list when users specify ordinals in group by Note: This PR is taken from https://github.com/apache/spark/pull/10731. When merging this PR, please give the credit to zhichao-li Also cc all the people who are involved in the previous discussion: rxin cloud-fan marmbrus yhuai hvanhovell adrian-wang chenghao-intel tejasapatil #### How was this patch tested? Added a few test cases for both positive and negative test cases. Author: gatorsmile <gatorsmile@gmail.com> Author: xiaoli <lixiao1983@gmail.com> Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local> Closes #11846 from gatorsmile/groupByOrdinal.
Diffstat (limited to 'sql/core/src/test')
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala92
1 files changed, 85 insertions, 7 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index eb486a135f..61358fda76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
import org.apache.spark.sql.functions._
@@ -459,25 +460,103 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
}
- test("literal in agg grouping expressions") {
+ test("Group By Ordinal - basic") {
checkAnswer(
- sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
- Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
- checkAnswer(
- sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
- Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+ sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"),
+ sql("SELECT a, sum(b) FROM testData2 GROUP BY a"))
+ // duplicate group-by columns
checkAnswer(
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+
+ checkAnswer(
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"),
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+ }
+
+ test("Group By Ordinal - non aggregate expressions") {
+ checkAnswer(
+ sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"),
+ sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+
+ checkAnswer(
+ sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"),
+ sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+ }
+
+ test("Group By Ordinal - non-foldable constant expression") {
+ checkAnswer(
+ sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"),
+ sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
+
checkAnswer(
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+ }
+
+ test("Group By Ordinal - alias") {
+ checkAnswer(
+ sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"),
+ sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+
+ checkAnswer(
+ sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"),
+ sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
+ }
+
+ test("Group By Ordinal - constants") {
checkAnswer(
sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
sql("SELECT 1, 2, sum(b) FROM testData2"))
}
+ test("Group By Ordinal - negative cases") {
+ intercept[UnresolvedException[Aggregate]] {
+ sql("SELECT a, b FROM testData2 GROUP BY -1")
+ }
+
+ intercept[UnresolvedException[Aggregate]] {
+ sql("SELECT a, b FROM testData2 GROUP BY 3")
+ }
+
+ var e = intercept[UnresolvedException[Aggregate]](
+ sql("SELECT SUM(a) FROM testData2 GROUP BY 1"))
+ assert(e.getMessage contains
+ "Invalid call to Group by position: the '1'th column in the select contains " +
+ "an aggregate function")
+
+ e = intercept[UnresolvedException[Aggregate]](
+ sql("SELECT SUM(a) + 1 FROM testData2 GROUP BY 1"))
+ assert(e.getMessage contains
+ "Invalid call to Group by position: the '1'th column in the select contains " +
+ "an aggregate function")
+
+ var ae = intercept[AnalysisException](
+ sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2"))
+ assert(ae.getMessage contains
+ "nondeterministic expression rand(0) should not appear in grouping expression")
+
+ ae = intercept[AnalysisException](
+ sql("SELECT * FROM testData2 GROUP BY a, b, 1"))
+ assert(ae.getMessage contains
+ "Group by position: star is not allowed to use in the select list " +
+ "when using ordinals in group by")
+ }
+
+ test("Group By Ordinal: spark.sql.groupByOrdinal=false") {
+ withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") {
+ // If spark.sql.groupByOrdinal=false, ignore the position number.
+ intercept[AnalysisException] {
+ sql("SELECT a, sum(b) FROM testData2 GROUP BY 1")
+ }
+ // '*' is not allowed to use in the select list when users specify ordinals in group by
+ checkAnswer(
+ sql("SELECT * FROM testData2 GROUP BY a, b, 1"),
+ sql("SELECT * FROM testData2 GROUP BY a, b"))
+ }
+ }
+
test("aggregates with nulls") {
checkAnswer(
sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
@@ -2174,7 +2253,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"),
sql("SELECT * FROM testData2 ORDER BY b ASC"))
-
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"),
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"))