about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: gatorsmile <gatorsmile@gmail.com> 2016-03-25 12:55:58 +0800
committer: Wenchen Fan <wenchen@databricks.com> 2016-03-25 12:55:58 +0800
commit: 05f652d6c2bbd764a1dd5a45301811e14519486f (patch)
tree: 46640388ac61767de19808901ba0b59de456bb92
parent: 0874ff3aade705a97f174b642c5db01711d214b3 (diff)
download: spark-05f652d6c2bbd764a1dd5a45301811e14519486f.tar.gz
spark-05f652d6c2bbd764a1dd5a45301811e14519486f.tar.bz2
spark-05f652d6c2bbd764a1dd5a45301811e14519486f.zip
[SPARK-13957][SQL] Support Group By Ordinal in SQL
#### What changes were proposed in this pull request?

This PR adds support for group by position (ordinal) in SQL. For example, when users input the following query
```SQL
select c1 as a, c2, c3, sum(*) from tbl group by 1, 3, c4
```
the ordinals are recognized as the positions in the select list. Thus, `Analyzer` converts it to
```SQL
select c1, c2, c3, sum(*) from tbl group by c1, c3, c4
```

This is controlled by the config option `spark.sql.groupByOrdinal`.
- When true, the ordinal numbers in group by clauses are treated as positions in the select list.
- When false, the ordinal numbers are ignored.
- Only integer literals are converted (not foldable expressions). If foldable expressions are found, they are ignored.
- When a position specified in the group by clause corresponds to an aggregate function in the select list, an exception is raised.
- Star is not allowed in the select list when users specify ordinals in group by.

Note: This PR is taken from https://github.com/apache/spark/pull/10731. When merging this PR, please give the credit to zhichao-li.

Also cc all the people who were involved in the previous discussion: rxin cloud-fan marmbrus yhuai hvanhovell adrian-wang chenghao-intel tejasapatil

#### How was this patch tested?

Added a few test cases covering both positive and negative scenarios.

Author: gatorsmile <gatorsmile@gmail.com>
Author: xiaoli <lixiao1983@gmail.com>
Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local>

Closes #11846 from gatorsmile/groupByOrdinal.
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala 8
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 72
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala 3
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 92
5 files changed, 156 insertions, 25 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index e10ab9790d..d5ac01500b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -23,6 +23,7 @@ private[spark] trait CatalystConf {
def caseSensitiveAnalysis: Boolean
def orderByOrdinal: Boolean
+ def groupByOrdinal: Boolean
/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
@@ -48,11 +49,16 @@ object EmptyConf extends CatalystConf {
override def orderByOrdinal: Boolean = {
throw new UnsupportedOperationException
}
+ override def groupByOrdinal: Boolean = {
+ throw new UnsupportedOperationException
+ }
}
/** A CatalystConf that can be used for local testing. */
case class SimpleCatalystConf(
caseSensitiveAnalysis: Boolean,
- orderByOrdinal: Boolean = true)
+ orderByOrdinal: Boolean = true,
+ groupByOrdinal: Boolean = true)
+
extends CatalystConf {
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 07b0f5ee70..d0a31e7620 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -85,6 +85,7 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveUpCast ::
+ ResolveOrdinalInOrderByAndGroupBy ::
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
@@ -385,7 +386,13 @@ class Analyzer(
p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
- a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
+ if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
+ failAnalysis(
+ "Group by position: star is not allowed to use in the select list " +
+ "when using ordinals in group by")
+ } else {
+ a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child))
+ }
// If the script transformation input contains Stars, expand it.
case t: ScriptTransformation if containsStar(t.input) =>
t.copy(
@@ -634,21 +641,23 @@ class Analyzer(
}
}
- /**
- * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
- * clause. This rule detects such queries and adds the required attributes to the original
- * projection, so that they will be available during sorting. Another projection is added to
- * remove these attributes after sorting.
- *
- * This rule also resolves the position number in sort references. This support is introduced
- * in Spark 2.0. Before Spark 2.0, the integers in Order By has no effect on output sorting.
- * - When the sort references are not integer but foldable expressions, ignore them.
- * - When spark.sql.orderByOrdinal is set to false, ignore the position numbers too.
- */
- object ResolveSortReferences extends Rule[LogicalPlan] {
+ /**
+ * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by
+ * clauses. This rule is to convert ordinal positions to the corresponding expressions in the
+ * select list. This support is introduced in Spark 2.0.
+ *
+ * - When the sort references or group by expressions are not integer but foldable expressions,
+ * just ignore them.
+ * - When spark.sql.orderByOrdinal/spark.sql.groupByOrdinal is set to false, ignore the position
+ * numbers too.
+ *
+ * Before the release of Spark 2.0, the literals in order/sort by and group by clauses
+ * have no effect on the results.
+ */
+ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case s: Sort if !s.child.resolved => s
- // Replace the index with the related attribute for ORDER BY
+ case p if !p.childrenResolved => p
+ // Replace the index with the related attribute for ORDER BY,
// which is a 1-based position of the projection list.
case s @ Sort(orders, global, child)
if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) =>
@@ -665,10 +674,41 @@ class Analyzer(
}
Sort(newOrders, global, child)
+ // Replace the index with the corresponding expression in aggregateExpressions. The index is
+ // a 1-based position in aggregateExpressions, i.e. the output columns (select expressions).
+ case a @ Aggregate(groups, aggs, child)
+ if conf.groupByOrdinal && aggs.forall(_.resolved) &&
+ groups.exists(IntegerIndex.unapply(_).nonEmpty) =>
+ val newGroups = groups.map {
+ case IntegerIndex(index) if index > 0 && index <= aggs.size =>
+ aggs(index - 1) match {
+ case e if ResolveAggregateFunctions.containsAggregate(e) =>
+ throw new UnresolvedException(a,
+ s"Group by position: the '$index'th column in the select contains an " +
+ s"aggregate function: ${e.sql}. Aggregate functions are not allowed in GROUP BY")
+ case o => o
+ }
+ case IntegerIndex(index) =>
+ throw new UnresolvedException(a,
+ s"Group by position: '$index' exceeds the size of the select list '${aggs.size}'.")
+ case o => o
+ }
+ Aggregate(newGroups, aggs, child)
+ }
+ }
+
+ /**
+ * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
+ * clause. This rule detects such queries and adds the required attributes to the original
+ * projection, so that they will be available during sorting. Another projection is added to
+ * remove these attributes after sorting.
+ */
+ object ResolveSortReferences extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
- case s @ Sort(order, _, child) if !s.resolved =>
+ case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
try {
val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index ada8424771..9c927077d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -210,7 +210,8 @@ object Unions {
object IntegerIndex {
def unapply(a: Any): Option[Int] = a match {
case Literal(a: Int, IntegerType) => Some(a)
- // When resolving ordinal in Sort, negative values are extracted for issuing error messages.
+ // When resolving ordinal in Sort and Group By, negative values are extracted
+ // for issuing error messages.
case UnaryMinus(IntegerLiteral(v)) => Some(-v)
case _ => None
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 863a876afe..77af0e000b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -445,6 +445,11 @@ object SQLConf {
doc = "When true, the ordinal numbers are treated as the position in the select list. " +
"When false, the ordinal numbers in order/sort By clause are ignored.")
+ val GROUP_BY_ORDINAL = booleanConf("spark.sql.groupByOrdinal",
+ defaultValue = Some(true),
+ doc = "When true, the ordinal numbers in group by clauses are treated as the position " +
+ "in the select list. When false, the ordinal numbers are ignored.")
+
// The output committer class used by HadoopFsRelation. The specified class needs to be a
// subclass of org.apache.hadoop.mapreduce.OutputCommitter.
//
@@ -668,6 +673,7 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
+ override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index eb486a135f..61358fda76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -23,6 +23,7 @@ import java.sql.Timestamp
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
import org.apache.spark.sql.functions._
@@ -459,25 +460,103 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
Seq(Row(1, 3), Row(2, 3), Row(3, 3)))
}
- test("literal in agg grouping expressions") {
+ test("Group By Ordinal - basic") {
checkAnswer(
- sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"),
- Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
- checkAnswer(
- sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"),
- Seq(Row(1, 2), Row(2, 2), Row(3, 2)))
+ sql("SELECT a, sum(b) FROM testData2 GROUP BY 1"),
+ sql("SELECT a, sum(b) FROM testData2 GROUP BY a"))
+ // duplicate group-by columns
checkAnswer(
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"),
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+
+ checkAnswer(
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY 1, 2"),
+ sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+ }
+
+ test("Group By Ordinal - non aggregate expressions") {
+ checkAnswer(
+ sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, 2"),
+ sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+
+ checkAnswer(
+ sql("SELECT a, b + 2 as c, count(2) FROM testData2 GROUP BY a, 2"),
+ sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+ }
+
+ test("Group By Ordinal - non-foldable constant expression") {
+ checkAnswer(
+ sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b, 1 + 0"),
+ sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
+
checkAnswer(
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"),
sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a"))
+ }
+
+ test("Group By Ordinal - alias") {
+ checkAnswer(
+ sql("SELECT a, (b + 2) as c, count(2) FROM testData2 GROUP BY a, 2"),
+ sql("SELECT a, b + 2, count(2) FROM testData2 GROUP BY a, b + 2"))
+
+ checkAnswer(
+ sql("SELECT a as b, b as a, sum(b) FROM testData2 GROUP BY 1, 2"),
+ sql("SELECT a, b, sum(b) FROM testData2 GROUP BY a, b"))
+ }
+
+ test("Group By Ordinal - constants") {
checkAnswer(
sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"),
sql("SELECT 1, 2, sum(b) FROM testData2"))
}
+ test("Group By Ordinal - negative cases") {
+ intercept[UnresolvedException[Aggregate]] {
+ sql("SELECT a, b FROM testData2 GROUP BY -1")
+ }
+
+ intercept[UnresolvedException[Aggregate]] {
+ sql("SELECT a, b FROM testData2 GROUP BY 3")
+ }
+
+ var e = intercept[UnresolvedException[Aggregate]](
+ sql("SELECT SUM(a) FROM testData2 GROUP BY 1"))
+ assert(e.getMessage contains
+ "Invalid call to Group by position: the '1'th column in the select contains " +
+ "an aggregate function")
+
+ e = intercept[UnresolvedException[Aggregate]](
+ sql("SELECT SUM(a) + 1 FROM testData2 GROUP BY 1"))
+ assert(e.getMessage contains
+ "Invalid call to Group by position: the '1'th column in the select contains " +
+ "an aggregate function")
+
+ var ae = intercept[AnalysisException](
+ sql("SELECT a, rand(0), sum(b) FROM testData2 GROUP BY a, 2"))
+ assert(ae.getMessage contains
+ "nondeterministic expression rand(0) should not appear in grouping expression")
+
+ ae = intercept[AnalysisException](
+ sql("SELECT * FROM testData2 GROUP BY a, b, 1"))
+ assert(ae.getMessage contains
+ "Group by position: star is not allowed to use in the select list " +
+ "when using ordinals in group by")
+ }
+
+ test("Group By Ordinal: spark.sql.groupByOrdinal=false") {
+ withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") {
+ // If spark.sql.groupByOrdinal=false, ignore the position number.
+ intercept[AnalysisException] {
+ sql("SELECT a, sum(b) FROM testData2 GROUP BY 1")
+ }
+ // With ordinals disabled, the trailing 1 is treated as a literal, so '*' is allowed here.
+ checkAnswer(
+ sql("SELECT * FROM testData2 GROUP BY a, b, 1"),
+ sql("SELECT * FROM testData2 GROUP BY a, b"))
+ }
+ }
+
test("aggregates with nulls") {
checkAnswer(
sql("SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a)," +
@@ -2174,7 +2253,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY 1 + 0 DESC, b ASC"),
sql("SELECT * FROM testData2 ORDER BY b ASC"))
-
checkAnswer(
sql("SELECT * FROM testData2 ORDER BY 1 DESC, b ASC"),
sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"))