aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorCheng Lian <lian.cs.zju@gmail.com>2014-10-13 13:36:39 -0700
committerMichael Armbrust <michael@databricks.com>2014-10-13 13:36:39 -0700
commit56102dc2d849c221f325a7888cd66abb640ec887 (patch)
tree1094decdc8659750a4f3e5b31b7aa9abed43a323 /sql
parent2ac40da3f9fa6d45a59bb45b41606f1931ac5e81 (diff)
downloadspark-56102dc2d849c221f325a7888cd66abb640ec887.tar.gz
spark-56102dc2d849c221f325a7888cd66abb640ec887.tar.bz2
spark-56102dc2d849c221f325a7888cd66abb640ec887.zip
[SPARK-2066][SQL] Adds checks for non-aggregate attributes with aggregation
This PR adds a new rule `CheckAggregation` to the analyzer to provide a better error message when non-aggregate attributes are used with aggregation. Author: Cheng Lian <lian.cs.zju@gmail.com> Closes #2774 from liancheng/non-aggregate-attr and squashes the following commits: 5246004 [Cheng Lian] Passes test suites bf1878d [Cheng Lian] Adds checks for non-aggregate attributes with aggregation
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala36
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala26
2 files changed, 57 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fe83eb1250..8255306314 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -63,7 +63,8 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
typeCoercionRules ++
extendedRules : _*),
Batch("Check Analysis", Once,
- CheckResolution),
+ CheckResolution,
+ CheckAggregation),
Batch("AnalysisOperators", fixedPoint,
EliminateAnalysisOperators)
)
@@ -89,6 +90,32 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
}
/**
+ * Checks that non-aggregated attributes used with aggregation are covered by GROUP BY.
+ */
+  object CheckAggregation extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = {
+      plan.transform {  // used only to visit each Aggregate; matched nodes are returned unchanged
+        case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, _) =>
+          def isValidAggregateExpression(expr: Expression): Boolean = expr match {
+            case _: AggregateExpression => true  // aggregate functions are always accepted
+            case e: Attribute => groupingExprs.contains(e)  // bare attributes must be grouped on
+            case e if groupingExprs.contains(e) => true  // the expression itself is a grouping key
+            case e if e.references.isEmpty => true  // references nothing, so nothing to group on
+            case e => e.children.forall(isValidAggregateExpression)  // otherwise check operands
+          }
+          // Report the offending Aggregate node itself, not the whole plan passed to apply().
+          aggregateExprs.foreach { e =>
+            if (!isValidAggregateExpression(e)) {
+              throw new TreeNodeException[LogicalPlan](
+                aggregatePlan, s"Expression not in GROUP BY: $e")
+            }
+          }
+
+          aggregatePlan
+      }
+    }
+  }
+
+ /**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
object ResolveRelations extends Rule[LogicalPlan] {
@@ -204,18 +231,17 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
*/
object UnresolvedHavingClauseAttributes extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
+ case filter @ Filter(havingCondition, aggregate @ Aggregate(_, originalAggExprs, _))
if aggregate.resolved && containsAggregate(havingCondition) => {
val evaluatedCondition = Alias(havingCondition, "havingCondition")()
val aggExprsWithHaving = evaluatedCondition +: originalAggExprs
-
+
Project(aggregate.output,
Filter(evaluatedCondition.toAttribute,
aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
}
-
}
-
+
protected def containsAggregate(condition: Expression): Boolean =
condition
.collect { case ae: AggregateExpression => ae }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a94022c0cf..15f6ba4f72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.test._
import org.scalatest.BeforeAndAfterAll
@@ -694,4 +695,29 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
checkAnswer(
sql("SELECT CASE WHEN key = 1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
}
+
+  test("throw errors for non-aggregate attributes with aggregation") {
+    // Analyzes `query`; by default the analyzer must reject it with "Expression not in GROUP BY".
+    def checkAggregation(query: String, isInvalidQuery: Boolean = true): Unit = {
+      val logicalPlan = sql(query).queryExecution.logical  // only used in the failure message
+      if (isInvalidQuery) {
+        val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
+        assert(
+          e.getMessage.startsWith("Expression not in GROUP BY"),
+          "Non-aggregate attribute(s) not detected\n" + logicalPlan)
+      } else {
+        // Should not throw
+        sql(query).queryExecution.analyzed
+      }
+    }
+    // No GROUP BY clause: a bare attribute next to an aggregate is rejected.
+    checkAggregation("SELECT key, COUNT(*) FROM testData")
+    checkAggregation("SELECT COUNT(key), COUNT(*) FROM testData", false)
+    // GROUP BY an attribute: other attributes may only appear inside aggregates.
+    checkAggregation("SELECT value, COUNT(*) FROM testData GROUP BY key")
+    checkAggregation("SELECT COUNT(value), SUM(key) FROM testData GROUP BY key", false)
+    // GROUP BY an expression: outputs must be built from that exact expression.
+    checkAggregation("SELECT key + 2, COUNT(*) FROM testData GROUP BY key + 1")
+    checkAggregation("SELECT key + 1 + 1, COUNT(*) FROM testData GROUP BY key + 1", false)
+  }
}