diff options
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 181 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 4 |
2 files changed, 129 insertions, 56 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bc8cf4e78a..7bcba421fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -87,7 +87,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: - ResolveSortReferences :: + ResolveMissingReferences :: ResolveGenerate :: ResolveFunctions :: ResolveAliases :: @@ -228,21 +228,56 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - private def hasGroupingId(expr: Seq[Expression]): Boolean = { - expr.exists(_.collectFirst { - case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u - }.isDefined) + private def hasGroupingAttribute(expr: Expression): Boolean = { + expr.collectFirst { + case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u + }.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + private def hasGroupingFunction(e: Expression): Boolean = { + e.collectFirst { + case g: Grouping => g + case g: GroupingID => g + }.isDefined + } + + private def replaceGroupingFunc( + expr: Expression, + groupByExprs: Seq[Expression], + gid: Expression): Expression = { + expr transform { + case e: GroupingID => + if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) { + gid + } else { + throw new AnalysisException( + s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + + s"grouping columns (${groupByExprs.mkString(",")})") + } + case Grouping(col: Expression) => + val idx = groupByExprs.indexOf(col) + if (idx >= 0) { + Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), + Literal(1)), ByteType) + } else { + throw new AnalysisException(s"Column of grouping ($col) can't be found " + + s"in grouping columns ${groupByExprs.mkString(",")}") + } + } + } + + // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. + case p if p.expressions.exists(hasGroupingAttribute) => + failAnalysis( + s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") + case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions) case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions) - case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) => - failAnalysis( - s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead") + // Ensure all the expressions have been resolved. case x: GroupingSets if x.expressions.forall(_.resolved) => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() @@ -270,7 +305,7 @@ class Analyzer( def isPartOfAggregation(e: Expression): Boolean = { aggsBuffer.exists(a => a.find(_ eq e).isDefined) } - expr.transformDown { + replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown { // AggregateExpression should be computed on the unmodified value of its argument // expressions, so we should not replace any references to grouping expression // inside it. @@ -278,23 +313,6 @@ class Analyzer( aggsBuffer += e e case e if isPartOfAggregation(e) => e - case e: GroupingID => - if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) { - gid - } else { - throw new AnalysisException( - s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + - s"grouping columns (${x.groupByExprs.mkString(",")})") - } - case Grouping(col: Expression) => - val idx = x.groupByExprs.indexOf(col) - if (idx >= 0) { - Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)), - Literal(1)), ByteType) - } else { - throw new AnalysisException(s"Column of grouping ($col) can't be found " + - s"in grouping columns ${x.groupByExprs.mkString(",")}") - } case e => val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) if (index == -1) { @@ -306,9 +324,37 @@ class Analyzer( } Aggregate( - groupByAttributes :+ VirtualColumn.groupingIdAttribute, + groupByAttributes :+ gid, aggregations, Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child)) + + case f @ Filter(cond, child) if hasGroupingFunction(cond) => + val groupingExprs = findGroupingExprs(child) + // The unresolved grouping id will be resolved by ResolveMissingReferences + val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute) + f.copy(condition = newCond) + + case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) => + val groupingExprs = findGroupingExprs(child) + val gid = VirtualColumn.groupingIdAttribute + // The unresolved grouping id will be resolved by ResolveMissingReferences + val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder]) + s.copy(order = newOrder) + } + + private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = { + plan.collectFirst { + case a: Aggregate => + // this Aggregate should have grouping id as the last grouping key. + val gid = a.groupingExpressions.last + if (!gid.isInstanceOf[AttributeReference] + || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) { + failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } + a.groupingExpressions.take(a.groupingExpressions.length - 1) + }.getOrElse { + failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } } } @@ -663,13 +709,15 @@ class Analyzer( * clause. This rule detects such queries and adds the required attributes to the original * projection, so that they will be available during sorting. Another projection is added to * remove these attributes after sorting. + * + * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ - object ResolveSortReferences extends Rule[LogicalPlan] { + object ResolveMissingReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, child) if child.resolved => try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) @@ -689,6 +737,26 @@ class Analyzer( // in Sort case ae: AnalysisException => s } + + case f @ Filter(cond, child) if child.resolved => + try { + val newCond = resolveExpressionRecursively(cond, child) + val requiredAttrs = newCond.references.filter(_.resolved) + val missingAttrs = requiredAttrs -- child.outputSet + if (missingAttrs.nonEmpty) { + // Add missing attributes and then project them away. + Project(child.output, + Filter(newCond, addMissingAttr(child, missingAttrs))) + } else if (newCond != cond) { + f.copy(condition = newCond) + } else { + f + } + } catch { + // Attempting to resolve it might fail. When this happens, return the original plan. + // Users will see an AnalysisException for resolution failure of missing attributes + case ae: AnalysisException => f + } } /** @@ -843,27 +911,33 @@ class Analyzer( if aggregate.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause - val aggregatedCondition = - Aggregate( - grouping, - Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, - child) - val resolvedOperator = execute(aggregatedCondition) - def resolvedAggregateFilter = - resolvedOperator - .asInstanceOf[Aggregate] - .aggregateExpressions.head - - // If resolution was successful and we see the filter has an aggregate in it, add it to - // the original aggregate operator. - if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { - val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs - - Project(aggregate.output, - Filter(resolvedAggregateFilter.toAttribute, - aggregate.copy(aggregateExpressions = aggExprsWithHaving))) - } else { - filter + try { + val aggregatedCondition = + Aggregate( + grouping, + Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, + child) + val resolvedOperator = execute(aggregatedCondition) + def resolvedAggregateFilter = + resolvedOperator + .asInstanceOf[Aggregate] + .aggregateExpressions.head + + // If resolution was successful and we see the filter has an aggregate in it, add it to + // the original aggregate operator. + if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { + val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs + + Project(aggregate.output, + Filter(resolvedAggregateFilter.toAttribute, + aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + } else { + filter + } + } catch { + // Attempting to resolve in the aggregate can result in ambiguity. When this happens, + // just return the original plan. + case ae: AnalysisException => filter } case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => @@ -927,11 +1001,8 @@ class Analyzer( } } - private def isAggregateExpression(e: Expression): Boolean = { - e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID] - } def containsAggregate(condition: Expression): Boolean = { - condition.find(isAggregateExpression).isDefined + condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 2307122ea1..78310fb2f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -333,6 +333,8 @@ case class PrettyAttribute( } object VirtualColumn { - val groupingIdName: String = "grouping__id" + // The attribute name used by Hive, which has different result than Spark, deprecated. + val hiveGroupingIdName: String = "grouping__id" + val groupingIdName: String = "spark_grouping_id" val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName) } |