From 56102dc2d849c221f325a7888cd66abb640ec887 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 13 Oct 2014 13:36:39 -0700 Subject: [SPARK-2066][SQL] Adds checks for non-aggregate attributes with aggregation This PR adds a new rule `CheckAggregation` to the analyzer to provide better error message for non-aggregate attributes with aggregation. Author: Cheng Lian Closes #2774 from liancheng/non-aggregate-attr and squashes the following commits: 5246004 [Cheng Lian] Passes test suites bf1878d [Cheng Lian] Adds checks for non-aggregate attributes with aggregation --- .../spark/sql/catalyst/analysis/Analyzer.scala | 36 +++++++++++++++++++--- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 26 ++++++++++++++++ 2 files changed, 57 insertions(+), 5 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fe83eb1250..8255306314 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool typeCoercionRules ++ extendedRules : _*), Batch("Check Analysis", Once, - CheckResolution), + CheckResolution, + CheckAggregation), Batch("AnalysisOperators", fixedPoint, EliminateAnalysisOperators) ) @@ -88,6 +89,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool } } + /** + * Checks for non-aggregated attributes with aggregation + */ + object CheckAggregation extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan.transform { + case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) => + def isValidAggregateExpression(expr: Expression): Boolean = expr match { + case _: AggregateExpression => true + case e: Attribute => groupingExprs.contains(e) + case e if groupingExprs.contains(e) => true + case e if e.references.isEmpty => true + case e => e.children.forall(isValidAggregateExpression) + } + + aggregateExprs.foreach { e => + if (!isValidAggregateExpression(e)) { + throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e") + } + } + + aggregatePlan + } + } + } + /** * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ @@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool */ object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) + case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _)) if aggregate.resolved && containsAggregate(havingCondition) => { val evaluatedCondition = Alias(havingCondition, "havingCondition")() val aggExprsWithHaving = evaluatedCondition +: originalAggExprs - + Project(aggregate.output, Filter(evaluatedCondition.toAttribute, aggregate.copy(aggregateExpressions = aggExprsWithHaving))) } - } - + protected def containsAggregate(condition: Expression): Boolean = condition .collect { case ae: AggregateExpression => ae } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a94022c0cf..15f6ba4f72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.test._ import org.scalatest.BeforeAndAfterAll @@ -694,4 +695,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer( sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1) } + + test("throw errors for non-aggregate attributes with aggregation") { + def checkAggregation(query: String, isInvalidQuery: Boolean = true) { + val logicalPlan = sql(query).queryExecution.logical + + if (isInvalidQuery) { + val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed) + assert( + e.getMessage.startsWith("Expression not in GROUP BY"), + "Non-aggregate attribute(s) not detected\n" + logicalPlan) + } else { + // Should not throw + sql(query).queryExecution.analyzed + } + } + + checkAggregation("SELECT key, COUNT(*) FROM testData") + checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false) + + checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key") + checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false) + + checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1") + checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false) + } } -- cgit v1.2.3