aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
author: Wenchen Fan <wenchen@databricks.com> 2016-03-08 23:34:42 +0800
committer: Wenchen Fan <wenchen@databricks.com> 2016-03-08 23:34:42 +0800
commit: 7d05d02bffe5f1c4fbf955664bcc87e38ce01f5f (patch)
tree: b7b1d1462abaac12e8d3697e21397dc711ee5d93 /sql/catalyst
parent: 9e86e6efd136182bb00fa925c3818c9baccbd1fc (diff)
download: spark-7d05d02bffe5f1c4fbf955664bcc87e38ce01f5f.tar.gz
spark-7d05d02bffe5f1c4fbf955664bcc87e38ce01f5f.tar.bz2
spark-7d05d02bffe5f1c4fbf955664bcc87e38ce01f5f.zip
[SPARK-13637][SQL] use more information to simplify the code in Expand builder
## What changes were proposed in this pull request? The code in `Expand.apply` can be simplified by using information that is already available: * the elements of the `groupByExprs` parameter are all `Attribute`s * the `child` parameter is a `Project` that appends the aliased group-by expressions to its child's output ## How was this patch tested? By existing tests. Author: Wenchen Fan <wenchen@databricks.com> Closes #11485 from cloud-fan/expand.
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 4
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala 48
2 files changed, 23 insertions, 29 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b5fa372643..268d7f21e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -298,12 +298,10 @@ class Analyzer(
}.asInstanceOf[NamedExpression]
}
- val child = Project(x.child.output ++ groupByAliases, x.child)
-
Aggregate(
groupByAttributes :+ VirtualColumn.groupingIdAttribute,
aggregations,
- Expand(x.bitmasks, groupByAttributes, gid, child))
+ Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 411594c951..3bc246a32d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -449,21 +449,21 @@ private[sql] object Expand {
* Extract attribute set according to the grouping id.
*
* @param bitmask bitmask to represent the selected of the attribute sequence
- * @param exprs the attributes in sequence
+ * @param attrs the attributes in sequence
* @return the attributes of non selected specified via bitmask (with the bit set to 1)
*/
- private def buildNonSelectExprSet(
+ private def buildNonSelectAttrSet(
bitmask: Int,
- exprs: Seq[Expression]): ArrayBuffer[Expression] = {
- val set = new ArrayBuffer[Expression](2)
+ attrs: Seq[Attribute]): AttributeSet = {
+ val nonSelect = new ArrayBuffer[Attribute]()
- var bit = exprs.length - 1
+ var bit = attrs.length - 1
while (bit >= 0) {
- if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1)
+ if (((bitmask >> bit) & 1) == 1) nonSelect += attrs(attrs.length - bit - 1)
bit -= 1
}
- set
+ AttributeSet(nonSelect)
}
/**
@@ -471,13 +471,15 @@ private[sql] object Expand {
* multiple output rows for a input row.
*
* @param bitmasks The bitmask set represents the grouping sets
- * @param groupByExprs The grouping by expressions
+ * @param groupByAliases The aliased original group by expressions
+ * @param groupByAttrs The attributes of aliased group by expressions
* @param gid Attribute of the grouping id
* @param child Child operator
*/
def apply(
bitmasks: Seq[Int],
- groupByExprs: Seq[Expression],
+ groupByAliases: Seq[Alias],
+ groupByAttrs: Seq[Attribute],
gid: Attribute,
child: LogicalPlan): Expand = {
// Create an array of Projections for the child projection, and replace the projections'
@@ -485,27 +487,21 @@ private[sql] object Expand {
// are not set for this grouping set (according to the bit mask).
val projections = bitmasks.map { bitmask =>
// get the non selected grouping attributes according to the bit mask
- val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs)
+ val nonSelectedGroupAttrSet = buildNonSelectAttrSet(bitmask, groupByAttrs)
- (child.output :+ gid).map(expr => expr transformDown {
- // TODO this causes a problem when a column is used both for grouping and aggregation.
- case x: Expression if nonSelectedGroupExprSet.exists(_.semanticEquals(x)) =>
+ child.output ++ groupByAttrs.map { attr =>
+ if (nonSelectedGroupAttrSet.contains(attr)) {
// if the input attribute in the Invalid Grouping Expression set of for this group
// replace it with constant null
- Literal.create(null, expr.dataType)
- case x if x == gid =>
- // replace the groupingId with concrete value (the bit mask)
- Literal.create(bitmask, IntegerType)
- })
- }
- val output = child.output.map { attr =>
- if (groupByExprs.exists(_.semanticEquals(attr))) {
- attr.withNullability(true)
- } else {
- attr
- }
+ Literal.create(null, attr.dataType)
+ } else {
+ attr
+ }
+ // groupingId is the last output, here we use the bit mask as the concrete value for it.
+ } :+ Literal.create(bitmask, IntegerType)
}
- Expand(projections, output :+ gid, child)
+ val output = child.output ++ groupByAttrs :+ gid
+ Expand(projections, output, Project(child.output ++ groupByAliases, child))
}
}