From b60b8137992641b9193e57061aa405f908b0f267 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 2 Mar 2016 20:18:57 -0800 Subject: [SPARK-13617][SQL] remove unnecessary GroupingAnalytics trait ## What changes were proposed in this pull request? The `trait GroupingAnalytics` has only one implementation, so it is an unnecessary abstraction. This PR removes it, and does some code simplification when resolving `GroupingSets`. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #11469 from cloud-fan/groupingset. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 22 ++++++++++----------- .../catalyst/plans/logical/basicOperators.scala | 23 ++++++---------------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 876aa0eae0..36eb59ef5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -181,8 +181,8 @@ class Analyzer( case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) - case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) => - g.withNewAggs(assignAliases(g.aggregations)) + case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) => + g.copy(aggregations = assignAliases(g.aggregations)) case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) if child.resolved && hasUnresolvedAlias(groupByExprs) => @@ -250,13 +250,9 @@ class Analyzer( val nonNullBitmask = x.bitmasks.reduce(_ & _) - val attributeMap = groupByAliases.zipWithIndex.map { case (a, idx) => - if ((nonNullBitmask & 1 << idx) == 0) { - (a -> a.toAttribute.withNullability(true)) - } else { - (a -> a.toAttribute) - } - }.toMap + val 
groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => + a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0) + } val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => // collect all the found AggregateExpression, so we can check an expression is part of @@ -292,12 +288,16 @@ class Analyzer( s"in grouping columns ${x.groupByExprs.mkString(",")}") } case e => - groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e) + val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) + if (index == -1) { + e + } else { + groupByAttributes(index) + } }.asInstanceOf[NamedExpression] } val child = Project(x.child.output ++ groupByAliases, x.child) - val groupByAttributes = groupByAliases.map(attributeMap(_)) Aggregate( groupByAttributes :+ VirtualColumn.groupingIdAttribute, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e81a0f9487..522348735a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -533,20 +533,6 @@ case class Expand( } } -trait GroupingAnalytics extends UnaryNode { - - def groupByExprs: Seq[Expression] - def aggregations: Seq[NamedExpression] - - override def output: Seq[Attribute] = aggregations.map(_.toAttribute) - - // Needs to be unresolved before its translated to Aggregate + Expand because output attributes - // will change in analysis. - override lazy val resolved: Boolean = false - - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics -} - /** * A GROUP BY clause with GROUPING SETS can generate a result set equivalent * to generated by a UNION ALL of multiple simple GROUP BY clauses. 
@@ -565,10 +551,13 @@ case class GroupingSets( bitmasks: Seq[Int], groupByExprs: Seq[Expression], child: LogicalPlan, - aggregations: Seq[NamedExpression]) extends GroupingAnalytics { + aggregations: Seq[NamedExpression]) extends UnaryNode { + + override def output: Seq[Attribute] = aggregations.map(_.toAttribute) - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = - this.copy(aggregations = aggs) + // Needs to be unresolved before it's translated to Aggregate + Expand because output attributes + // will change in analysis. + override lazy val resolved: Boolean = false } case class Pivot( -- cgit v1.2.3