From b91873db0930c6fe885c27936e1243d5fabd03ed Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 20 Apr 2017 16:59:38 +0200 Subject: [SPARK-20409][SQL] fail early if aggregate function in GROUP BY ## What changes were proposed in this pull request? It's illegal to have aggregate function in GROUP BY, and we should fail at analysis phase, if this happens. ## How was this patch tested? new regression test Author: Wenchen Fan Closes #17704 from cloud-fan/minor. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 14 ++++---------- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 7 ++++++- .../resources/sql-tests/results/group-by-ordinal.sql.out | 4 ++-- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 7 +++++++ 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d9f36f7f87..175bfb3e80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -966,7 +966,7 @@ class Analyzer( case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. - case s @ Sort(orders, global, child) + case Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => @@ -983,17 +983,11 @@ class Analyzer( // Replace the index with the corresponding expression in aggregateExpressions. The index is // a 1-base position of aggregateExpressions, which is output columns (select expression) - case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && + case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => val newGroups = groups.map { - case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => - aggs(index - 1) match { - case e if ResolveAggregateFunctions.containsAggregate(e) => - ordinal.failAnalysis( - s"GROUP BY position $index is an aggregate function, and " + - "aggregate functions are not allowed in GROUP BY") - case o => o - } + case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => + aggs(index - 1) case ordinal @ UnresolvedOrdinal(index) => ordinal.failAnalysis( s"GROUP BY position $index is not in select list " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index da0c6b098f..61797bc34d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -254,6 +254,11 @@ trait CheckAnalysis extends PredicateHelper { } def checkValidGroupingExprs(expr: Expression): Unit = { + if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) { + failAnalysis( + "aggregate functions are not allowed in GROUP BY, but found " + expr.sql) + } + // Check if the data type of expr is orderable. if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( @@ -271,8 +276,8 @@ trait CheckAnalysis extends PredicateHelper { } } - aggregateExprs.foreach(checkValidAggregateExpression) groupingExprs.foreach(checkValidGroupingExprs) + aggregateExprs.foreach(checkValidAggregateExpression) case Sort(orders, _, _) => orders.foreach { order => diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index c0930bbde6..d03681d0ea 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39 +aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT)); -- !query 12 @@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43 +aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT)); -- !query 13 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7079120bb..8569c2d76b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -538,4 +538,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) ) } + + test("aggregate function in GROUP BY") { + val e = intercept[AnalysisException] { + testData.groupBy(sum($"key")).count() + } + assert(e.message.contains("aggregate functions are not allowed in GROUP BY")) + } } -- cgit v1.2.3