aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-04-07 11:51:34 -0700
committerDavies Liu <davies.liu@gmail.com>2016-04-07 11:51:34 -0700
commitaa852215f82876977d164f371627e894e86baacc (patch)
treeb62a001aa6d74fda188a12f96dde53230b94ed93
parent8dcb0c7c974e9707933ac2ae6ce837e765a5e81a (diff)
downloadspark-aa852215f82876977d164f371627e894e86baacc.tar.gz
spark-aa852215f82876977d164f371627e894e86baacc.tar.bz2
spark-aa852215f82876977d164f371627e894e86baacc.zip
[SPARK-12740] [SPARK-13932] support grouping()/grouping_id() in having/order clause
## What changes were proposed in this pull request? This PR adds support for using grouping()/grouping_id() in HAVING/ORDER BY clauses. The resolved grouping()/grouping_id() will be replaced by the unresolved "spark_grouping_id" virtual attribute, which is then resolved by ResolveMissingReferences. This PR also fixes a HAVING clause that accesses a grouping column that is not present in the SELECT clause, for example: ```sql select count(1) from (select 1 as a) t group by a having a > 0 ``` ## How was this patch tested? Added new tests. Author: Davies Liu <davies@databricks.com> Closes #12235 from davies/grouping_having.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala181
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala82
3 files changed, 211 insertions, 56 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bc8cf4e78a..7bcba421fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -87,7 +87,7 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolvePivot ::
ResolveOrdinalInOrderByAndGroupBy ::
- ResolveSortReferences ::
+ ResolveMissingReferences ::
ResolveGenerate ::
ResolveFunctions ::
ResolveAliases ::
@@ -228,21 +228,56 @@ class Analyzer(
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
}
- private def hasGroupingId(expr: Seq[Expression]): Boolean = {
- expr.exists(_.collectFirst {
- case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u
- }.isDefined)
+ private def hasGroupingAttribute(expr: Expression): Boolean = {
+ expr.collectFirst {
+ case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u
+ }.isDefined
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ private def hasGroupingFunction(e: Expression): Boolean = {
+ e.collectFirst {
+ case g: Grouping => g
+ case g: GroupingID => g
+ }.isDefined
+ }
+
+ private def replaceGroupingFunc(
+ expr: Expression,
+ groupByExprs: Seq[Expression],
+ gid: Expression): Expression = {
+ expr transform {
+ case e: GroupingID =>
+ if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
+ gid
+ } else {
+ throw new AnalysisException(
+ s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
+ s"grouping columns (${groupByExprs.mkString(",")})")
+ }
+ case Grouping(col: Expression) =>
+ val idx = groupByExprs.indexOf(col)
+ if (idx >= 0) {
+ Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
+ Literal(1)), ByteType)
+ } else {
+ throw new AnalysisException(s"Column of grouping ($col) can't be found " +
+ s"in grouping columns ${groupByExprs.mkString(",")}")
+ }
+ }
+ }
+
+ // This requires transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a if !a.childrenResolved => a // be sure all of the children are resolved.
+ case p if p.expressions.exists(hasGroupingAttribute) =>
+ failAnalysis(
+ s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")
+
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
- case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) =>
- failAnalysis(
- s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead")
+
// Ensure all the expressions have been resolved.
case x: GroupingSets if x.expressions.forall(_.resolved) =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
@@ -270,7 +305,7 @@ class Analyzer(
def isPartOfAggregation(e: Expression): Boolean = {
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
}
- expr.transformDown {
+ replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown {
// AggregateExpression should be computed on the unmodified value of its argument
// expressions, so we should not replace any references to grouping expression
// inside it.
@@ -278,23 +313,6 @@ class Analyzer(
aggsBuffer += e
e
case e if isPartOfAggregation(e) => e
- case e: GroupingID =>
- if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) {
- gid
- } else {
- throw new AnalysisException(
- s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
- s"grouping columns (${x.groupByExprs.mkString(",")})")
- }
- case Grouping(col: Expression) =>
- val idx = x.groupByExprs.indexOf(col)
- if (idx >= 0) {
- Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)),
- Literal(1)), ByteType)
- } else {
- throw new AnalysisException(s"Column of grouping ($col) can't be found " +
- s"in grouping columns ${x.groupByExprs.mkString(",")}")
- }
case e =>
val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
if (index == -1) {
@@ -306,9 +324,37 @@ class Analyzer(
}
Aggregate(
- groupByAttributes :+ VirtualColumn.groupingIdAttribute,
+ groupByAttributes :+ gid,
aggregations,
Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
+
+ case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
+ val groupingExprs = findGroupingExprs(child)
+ // The unresolved grouping id will be resolved by ResolveMissingReferences
+ val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
+ f.copy(condition = newCond)
+
+ case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) =>
+ val groupingExprs = findGroupingExprs(child)
+ val gid = VirtualColumn.groupingIdAttribute
+ // The unresolved grouping id will be resolved by ResolveMissingReferences
+ val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
+ s.copy(order = newOrder)
+ }
+
+ private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
+ plan.collectFirst {
+ case a: Aggregate =>
+ // this Aggregate should have grouping id as the last grouping key.
+ val gid = a.groupingExpressions.last
+ if (!gid.isInstanceOf[AttributeReference]
+ || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) {
+ failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ }
+ a.groupingExpressions.take(a.groupingExpressions.length - 1)
+ }.getOrElse {
+ failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ }
}
}
@@ -663,13 +709,15 @@ class Analyzer(
* clause. This rule detects such queries and adds the required attributes to the original
* projection, so that they will be available during sorting. Another projection is added to
* remove these attributes after sorting.
+ *
+ * The HAVING clause could also use a grouping column that is not present in the SELECT.
*/
- object ResolveSortReferences extends Rule[LogicalPlan] {
+ object ResolveMissingReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
- case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+ case s @ Sort(order, _, child) if child.resolved =>
try {
val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
@@ -689,6 +737,26 @@ class Analyzer(
// in Sort
case ae: AnalysisException => s
}
+
+ case f @ Filter(cond, child) if child.resolved =>
+ try {
+ val newCond = resolveExpressionRecursively(cond, child)
+ val requiredAttrs = newCond.references.filter(_.resolved)
+ val missingAttrs = requiredAttrs -- child.outputSet
+ if (missingAttrs.nonEmpty) {
+ // Add missing attributes and then project them away.
+ Project(child.output,
+ Filter(newCond, addMissingAttr(child, missingAttrs)))
+ } else if (newCond != cond) {
+ f.copy(condition = newCond)
+ } else {
+ f
+ }
+ } catch {
+ // Attempting to resolve it might fail. When this happens, return the original plan.
+ // Users will see an AnalysisException for resolution failure of missing attributes
+ case ae: AnalysisException => f
+ }
}
/**
@@ -843,27 +911,33 @@ class Analyzer(
if aggregate.resolved =>
// Try resolving the condition of the filter as though it is in the aggregate clause
- val aggregatedCondition =
- Aggregate(
- grouping,
- Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
- child)
- val resolvedOperator = execute(aggregatedCondition)
- def resolvedAggregateFilter =
- resolvedOperator
- .asInstanceOf[Aggregate]
- .aggregateExpressions.head
-
- // If resolution was successful and we see the filter has an aggregate in it, add it to
- // the original aggregate operator.
- if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
- val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
-
- Project(aggregate.output,
- Filter(resolvedAggregateFilter.toAttribute,
- aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
- } else {
- filter
+ try {
+ val aggregatedCondition =
+ Aggregate(
+ grouping,
+ Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
+ child)
+ val resolvedOperator = execute(aggregatedCondition)
+ def resolvedAggregateFilter =
+ resolvedOperator
+ .asInstanceOf[Aggregate]
+ .aggregateExpressions.head
+
+ // If resolution was successful and we see the filter has an aggregate in it, add it to
+ // the original aggregate operator.
+ if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
+ val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
+
+ Project(aggregate.output,
+ Filter(resolvedAggregateFilter.toAttribute,
+ aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
+ } else {
+ filter
+ }
+ } catch {
+ // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+ // just return the original plan.
+ case ae: AnalysisException => filter
}
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
@@ -927,11 +1001,8 @@ class Analyzer(
}
}
- private def isAggregateExpression(e: Expression): Boolean = {
- e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID]
- }
def containsAggregate(condition: Expression): Boolean = {
- condition.find(isAggregateExpression).isDefined
+ condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 2307122ea1..78310fb2f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -333,6 +333,8 @@ case class PrettyAttribute(
}
object VirtualColumn {
- val groupingIdName: String = "grouping__id"
+ // The attribute name used by Hive (deprecated); it yields a different result than Spark's.
+ val hiveGroupingIdName: String = "grouping__id"
+ val groupingIdName: String = "spark_grouping_id"
val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 2ab7c1581c..dd648cdb81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2230,6 +2230,88 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
}
+ test("grouping and grouping_id in having") {
+ checkAnswer(
+ sql("select course, year from courseSales group by cube(course, year)" +
+ " having grouping(year) = 1 and grouping_id(course, year) > 0"),
+ Row("Java", null) ::
+ Row("dotNET", null) ::
+ Row(null, null) :: Nil
+ )
+
+ var error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by course, year" +
+ " having grouping(course) > 0")
+ }
+ assert(error.getMessage contains
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by course, year" +
+ " having grouping_id(course, year) > 0")
+ }
+ assert(error.getMessage contains
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by cube(course, year)" +
+ " having grouping__id > 0")
+ }
+ assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
+ }
+
+ test("grouping and grouping_id in sort") {
+ checkAnswer(
+ sql("select course, year, grouping(course), grouping(year) from courseSales" +
+ " group by cube(course, year) order by grouping_id(course, year), course, year"),
+ Row("Java", 2012, 0, 0) ::
+ Row("Java", 2013, 0, 0) ::
+ Row("dotNET", 2012, 0, 0) ::
+ Row("dotNET", 2013, 0, 0) ::
+ Row("Java", null, 0, 1) ::
+ Row("dotNET", null, 0, 1) ::
+ Row(null, 2012, 1, 0) ::
+ Row(null, 2013, 1, 0) ::
+ Row(null, null, 1, 1) :: Nil
+ )
+
+ checkAnswer(
+ sql("select course, year, grouping_id(course, year) from courseSales" +
+ " group by cube(course, year) order by grouping(course), grouping(year), course, year"),
+ Row("Java", 2012, 0) ::
+ Row("Java", 2013, 0) ::
+ Row("dotNET", 2012, 0) ::
+ Row("dotNET", 2013, 0) ::
+ Row("Java", null, 1) ::
+ Row("dotNET", null, 1) ::
+ Row(null, 2012, 2) ::
+ Row(null, 2013, 2) ::
+ Row(null, null, 3) :: Nil
+ )
+
+ var error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by course, year" +
+ " order by grouping(course)")
+ }
+ assert(error.getMessage contains
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by course, year" +
+ " order by grouping_id(course, year)")
+ }
+ assert(error.getMessage contains
+ "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ error = intercept[AnalysisException] {
+ sql("select course, year from courseSales group by cube(course, year)" +
+ " order by grouping__id")
+ }
+ assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead")
+ }
+
+ test("filter on a grouping column that is not presented in SELECT") {
+ checkAnswer(
+ sql("select count(1) from (select 1 as a) t group by a having a > 0"),
+ Row(1) :: Nil)
+ }
+
test("SPARK-13056: Null in map value causes NPE") {
val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value")
withTempTable("maptest") {