author    Wenchen Fan <wenchen@databricks.com>    2016-03-02 20:18:57 -0800
committer Reynold Xin <rxin@databricks.com>       2016-03-02 20:18:57 -0800
commit    b60b8137992641b9193e57061aa405f908b0f267 (patch)
tree      e4021ed472a04f08dce12e5f1ac0b87932b9072a
parent    6250cf1e00f6b0bacca73ad785fa402f59bd6232 (diff)
[SPARK-13617][SQL] remove unnecessary GroupingAnalytics trait
## What changes were proposed in this pull request?

The trait `GroupingAnalytics` has only one implementation, so it is an unnecessary abstraction. This PR removes it and simplifies the code that resolves `GroupingSets`.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11469 from cloud-fan/groupingset.
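As a rough illustration of the refactoring pattern (a sketch with simplified stand-in types, not the actual Catalyst classes): a trait with a single implementation is folded into that implementation, and the hand-written `withNewAggs` method is replaced by the case class's generated `copy`.

```scala
// Sketch only: GroupingSetsLike and its fields are simplified stand-ins for
// the real Catalyst plan node, to show the shape of the change.
object TraitRemovalSketch {
  // Before the change, callers went through an abstract method:
  //   trait GroupingAnalytics { def withNewAggs(aggs: Seq[String]): GroupingAnalytics }
  // After the change, the single implementation stands alone and callers use
  // the case-class copy method directly.
  final case class GroupingSetsLike(
      bitmasks: Seq[Int],
      groupByExprs: Seq[String],
      aggregations: Seq[String])

  def main(args: Array[String]): Unit = {
    val g = GroupingSetsLike(Seq(3, 1, 0), Seq("a", "b"), Seq("sum(c)"))
    // Mirrors the new analyzer rule:
    //   case g: GroupingSets if ... => g.copy(aggregations = assignAliases(g.aggregations))
    val aliased = g.copy(aggregations = Seq("sum(c) AS sum_c"))
    println(aliased)
  }
}
```

Matching on the concrete case class also lets the compiler keep `copy` in step with the constructor, which a hand-maintained `withNewAggs` cannot guarantee.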
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala            | 22
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 23
2 files changed, 17 insertions(+), 28 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 876aa0eae0..36eb59ef5e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -181,8 +181,8 @@ class Analyzer(
case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) =>
Aggregate(groups, assignAliases(aggs), child)
- case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
- g.withNewAggs(assignAliases(g.aggregations))
+ case g: GroupingSets if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
+ g.copy(aggregations = assignAliases(g.aggregations))
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
if child.resolved && hasUnresolvedAlias(groupByExprs) =>
@@ -250,13 +250,9 @@ class Analyzer(
val nonNullBitmask = x.bitmasks.reduce(_ & _)
- val attributeMap = groupByAliases.zipWithIndex.map { case (a, idx) =>
- if ((nonNullBitmask & 1 << idx) == 0) {
- (a -> a.toAttribute.withNullability(true))
- } else {
- (a -> a.toAttribute)
- }
- }.toMap
+ val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) =>
+ a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0)
+ }
val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr =>
// collect all the found AggregateExpression, so we can check an expression is part of
@@ -292,12 +288,16 @@ class Analyzer(
s"in grouping columns ${x.groupByExprs.mkString(",")}")
}
case e =>
- groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e)
+ val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
+ if (index == -1) {
+ e
+ } else {
+ groupByAttributes(index)
+ }
}.asInstanceOf[NamedExpression]
}
val child = Project(x.child.output ++ groupByAliases, x.child)
- val groupByAttributes = groupByAliases.map(attributeMap(_))
Aggregate(
groupByAttributes :+ VirtualColumn.groupingIdAttribute,
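The two edits in this hunk amount to one simplification: nullability is computed per position directly from the combined bitmask, and aggregation expressions are rewritten via an index lookup instead of an intermediate `Map`. Below is a standalone sketch of that logic using plain strings and a made-up `Attr` class in place of Catalyst expressions and attributes; the reading that a set bit at position idx means "column idx is part of this grouping set" is inferred from the code above.

```scala
// Illustrative sketch only; Attr is a stand-in for Catalyst's Attribute.
object GroupingSetsResolutionSketch {
  final case class Attr(name: String, nullable: Boolean)

  def main(args: Array[String]): Unit = {
    // GROUPING SETS ((a, b), (a)) over group-by columns a, b:
    // (a, b) -> 0b11 and (a) -> 0b01 under the assumed bit ordering.
    val bitmasks     = Seq(3, 1)
    val groupByNames = Seq("a", "b")

    // A column stays non-nullable only if its bit is set in every bitmask;
    // otherwise some grouping set nulls it out and it must be nullable.
    val nonNullBitmask = bitmasks.reduce(_ & _)
    val groupByAttributes = groupByNames.zipWithIndex.map { case (n, idx) =>
      Attr(n, nullable = (nonNullBitmask & (1 << idx)) == 0)
    }

    // Rewrite an expression by positional lookup instead of a Map, mirroring
    // the indexWhere-based branch in the new code.
    def rewrite(e: String): Any = {
      val index = groupByNames.indexWhere(_ == e)
      if (index == -1) e else groupByAttributes(index)
    }

    println(groupByAttributes) // List(Attr(a,false), Attr(b,true))
    println(rewrite("b"))      // Attr(b,true)
    println(rewrite("c"))      // c -- not a group-by column, left untouched
  }
}
```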
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e81a0f9487..522348735a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -533,20 +533,6 @@ case class Expand(
}
}
-trait GroupingAnalytics extends UnaryNode {
-
- def groupByExprs: Seq[Expression]
- def aggregations: Seq[NamedExpression]
-
- override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
-
- // Needs to be unresolved before its translated to Aggregate + Expand because output attributes
- // will change in analysis.
- override lazy val resolved: Boolean = false
-
- def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
-}
-
/**
* A GROUP BY clause with GROUPING SETS can generate a result set equivalent
* to that generated by a UNION ALL of multiple simple GROUP BY clauses.
@@ -565,10 +551,13 @@ case class GroupingSets(
bitmasks: Seq[Int],
groupByExprs: Seq[Expression],
child: LogicalPlan,
- aggregations: Seq[NamedExpression]) extends GroupingAnalytics {
+ aggregations: Seq[NamedExpression]) extends UnaryNode {
+
+ override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
- def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
- this.copy(aggregations = aggs)
+ // Needs to be unresolved before its translated to Aggregate + Expand because output attributes
+ // will change in analysis.
+ override lazy val resolved: Boolean = false
}
case class Pivot(
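For context on why the group-by output attributes can become nullable at all, here is a rough sketch of what the bitmasks encode, expanding GROUPING SETS into the UNION ALL form mentioned in the doc comment above. Column names, the query shape, and the bit ordering are illustrative assumptions, not taken from the patch.

```scala
// Sketch only: prints the UNION ALL branches that a GROUPING SETS query is
// equivalent to, one branch per bitmask.
object GroupingSetsExpansionSketch {
  def main(args: Array[String]): Unit = {
    // GROUP BY a, b GROUPING SETS ((a, b), (a), ())
    val groupByExprs = Seq("a", "b")
    val bitmasks     = Seq(3, 1, 0) // (a, b), (a), ()

    bitmasks.foreach { mask =>
      // Columns kept by this grouping set; the rest come back as NULL,
      // which is why those attributes must be nullable in the output.
      val kept = groupByExprs.zipWithIndex.collect {
        case (col, idx) if (mask & (1 << idx)) != 0 => col
      }
      val select = groupByExprs.zipWithIndex.map { case (col, idx) =>
        if ((mask & (1 << idx)) != 0) col else s"NULL AS $col"
      }
      val groupBy = if (kept.isEmpty) "" else s" GROUP BY ${kept.mkString(", ")}"
      println(s"SELECT ${select.mkString(", ")}, sum(c) FROM t$groupBy")
    }
  }
}
```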