From d1e8108854deba3de8e2d87eb4389d11fb17ee57 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 29 Jun 2016 19:08:36 +0800 Subject: [SPARK-16291][SQL] CheckAnalysis should capture nested aggregate functions that reference no input attributes ## What changes were proposed in this pull request? `MAX(COUNT(*))` is invalid since aggregate expression can't be nested within another aggregate expression. This case should be captured at analysis phase, but somehow sneaks off to runtime. The reason is that when checking aggregate expressions in `CheckAnalysis`, a checking branch treats all expressions that reference no input attributes as valid ones. However, `MAX(COUNT(*))` is translated into `MAX(COUNT(1))` at analysis phase and also references no input attribute. This PR fixes this issue by removing the aforementioned branch. ## How was this patch tested? New test case added in `AnalysisErrorSuite`. Author: Cheng Lian Closes #13968 from liancheng/spark-16291-nested-agg-functions. --- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 1 - .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 12 +++++++++++- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 +--- 3 files changed, 12 insertions(+), 5 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index ac9693e079..7b30fcc6c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -206,7 +206,6 @@ trait CheckAnalysis extends PredicateHelper { "Add to group by or wrap in first() (or first_value) if you don't care " + "which value you get.") case e if groupingExprs.exists(_.semanticEquals(e)) => // OK - case e if e.references.isEmpty => // OK case e => e.children.foreach(checkValidAggregateExpression) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index a41383fbf6..a9cde1e19e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max} import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} @@ -162,6 +162,16 @@ class AnalysisErrorSuite extends AnalysisTest { UnspecifiedFrame)).as('window)), "Distinct window functions are not supported" :: Nil) + errorTest( + "nested aggregate functions", + testRelation.groupBy('a)( + AggregateExpression( + Max(AggregateExpression(Count(Literal(1)), Complete, isDistinct = false)), + Complete, + isDistinct = false)), + "not allowed to use an aggregate function in the argument of another aggregate function." :: Nil + ) + errorTest( "offset window function", testRelation2.select( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b1dbf21d4b..084ba9b78e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,10 +21,8 @@ import java.math.MathContext import java.sql.Timestamp import org.apache.spark.AccumulatorSuite -import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.catalog.{CatalogTestUtils, ExternalCatalog, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{ExpressionDescription, SortOrder} +import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate -- cgit v1.2.3