author     Wenchen Fan <wenchen@databricks.com>           2017-04-20 16:59:38 +0200
committer  Herman van Hovell <hvanhovell@databricks.com>  2017-04-20 16:59:38 +0200
commit     b91873db0930c6fe885c27936e1243d5fabd03ed (patch)
tree       45459b54f7e1c9b502d3e957c9fe0290a4614e06
parent     c6f62c5b8106534007df31ca8c460064b89b450b (diff)
[SPARK-20409][SQL] fail early if aggregate function in GROUP BY
## What changes were proposed in this pull request?

It's illegal to have an aggregate function in GROUP BY, and if this happens we should fail at the analysis phase.

## How was this patch tested?

New regression test.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17704 from cloud-fan/minor.
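For context, a minimal sketch of the behavior this patch enforces (not part of the commit itself). It assumes a local SparkSession, and the `data` table with columns `a` and `b` is illustrative, mirroring the table used in the updated group-by-ordinal.sql output; the first query is the one from that test, and the DataFrame variant mirrors the new regression test added below.

```scala
// Rough sketch only: the table/column names and the local SparkSession setup
// here are illustrative, not part of the patch.
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.functions.sum

object GroupByAggregateExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("group-by-aggregate-example")
      .getOrCreate()
    import spark.implicits._

    val data = Seq((1, 10L), (1, 20L), (2, 30L)).toDF("a", "b")
    data.createOrReplaceTempView("data")

    // GROUP BY ordinal 3 points at sum(b). With this patch the query is
    // rejected during analysis instead of resolving the ordinal to the
    // aggregate expression.
    try {
      spark.sql("select a, b, sum(b) from data group by 3").collect()
    } catch {
      case e: AnalysisException =>
        // Message is roughly: "aggregate functions are not allowed in GROUP BY,
        // but found sum(CAST(data.`b` AS BIGINT))"
        println(e.getMessage)
    }

    // Same failure through the DataFrame API, mirroring the new regression test.
    try {
      data.groupBy(sum($"b")).count()
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }

    spark.stop()
  }
}
```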
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala       | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala  |  7
-rw-r--r--  sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out                  |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala              |  7
4 files changed, 19 insertions, 13 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index d9f36f7f87..175bfb3e80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -966,7 +966,7 @@ class Analyzer(
case p if !p.childrenResolved => p
// Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list.
- case s @ Sort(orders, global, child)
+ case Sort(orders, global, child)
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
val newOrders = orders map {
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
@@ -983,17 +983,11 @@ class Analyzer(
// Replace the index with the corresponding expression in aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns (select expression)
- case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
+ case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
val newGroups = groups.map {
- case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
- aggs(index - 1) match {
- case e if ResolveAggregateFunctions.containsAggregate(e) =>
- ordinal.failAnalysis(
- s"GROUP BY position $index is an aggregate function, and " +
- "aggregate functions are not allowed in GROUP BY")
- case o => o
- }
+ case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
+ aggs(index - 1)
case ordinal @ UnresolvedOrdinal(index) =>
ordinal.failAnalysis(
s"GROUP BY position $index is not in select list " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index da0c6b098f..61797bc34d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -254,6 +254,11 @@ trait CheckAnalysis extends PredicateHelper {
}
def checkValidGroupingExprs(expr: Expression): Unit = {
+ if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
+ failAnalysis(
+ "aggregate functions are not allowed in GROUP BY, but found " + expr.sql)
+ }
+
// Check if the data type of expr is orderable.
if (!RowOrdering.isOrderable(expr.dataType)) {
failAnalysis(
@@ -271,8 +276,8 @@ trait CheckAnalysis extends PredicateHelper {
}
}
- aggregateExprs.foreach(checkValidAggregateExpression)
groupingExprs.foreach(checkValidGroupingExprs)
+ aggregateExprs.foreach(checkValidAggregateExpression)
case Sort(orders, _, _) =>
orders.foreach { order =>
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
index c0930bbde6..d03681d0ea 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out
@@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3
struct<>
-- !query 11 output
org.apache.spark.sql.AnalysisException
-GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39
+aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT));
-- !query 12
@@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3
struct<>
-- !query 12 output
org.apache.spark.sql.AnalysisException
-GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43
+aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT));
-- !query 13
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index e7079120bb..8569c2d76b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -538,4 +538,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0))
)
}
+
+ test("aggregate function in GROUP BY") {
+ val e = intercept[AnalysisException] {
+ testData.groupBy(sum($"key")).count()
+ }
+ assert(e.message.contains("aggregate functions are not allowed in GROUP BY"))
+ }
}