aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Ray <ray.andrew@gmail.com>2015-11-19 15:11:30 -0800
committerYin Huai <yhuai@databricks.com>2015-11-19 15:11:30 -0800
commit37cff1b1a79cad11277612cb9bc8bc2365cf5ff2 (patch)
treecdc83803a933b06ff3ec41c958825152f0854c51
parent01403aa97b6aaab9b86ae806b5ea9e82690a741f (diff)
downloadspark-37cff1b1a79cad11277612cb9bc8bc2365cf5ff2.tar.gz
spark-37cff1b1a79cad11277612cb9bc8bc2365cf5ff2.tar.bz2
spark-37cff1b1a79cad11277612cb9bc8bc2365cf5ff2.zip
[SPARK-11275][SQL] Incorrect results when using rollup/cube
Fixes bug with grouping sets (including cube/rollup) where aggregates that included grouping expressions would return the wrong (null) result. Also simplifies the analyzer rule a bit and leaves column pruning to the optimizer. Added multiple unit tests to DataFrameAggregateSuite and verified it passes hive compatibility suite: ``` build/sbt -Phive -Dspark.hive.whitelist='groupby.*_grouping.*' 'test-only org.apache.spark.sql.hive.execution.HiveCompatibilitySuite' ``` This is an alternative to pr https://github.com/apache/spark/pull/9419 but I think its better as it simplifies the analyzer rule instead of adding another special case to it. Author: Andrew Ray <ray.andrew@gmail.com> Closes #9815 from aray/groupingset-agg-fix.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala58
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala62
3 files changed, 90 insertions, 34 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 84781cd57f..47962ebe6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -213,45 +213,35 @@ class Analyzer(
GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations)
case x: GroupingSets =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
- // We will insert another Projection if the GROUP BY keys contains the
- // non-attribute expressions. And the top operators can references those
- // expressions by its alias.
- // e.g. SELECT key%5 as c1 FROM src GROUP BY key%5 ==>
- // SELECT a as c1 FROM (SELECT key%5 AS a FROM src) GROUP BY a
-
- // find all of the non-attribute expressions in the GROUP BY keys
- val nonAttributeGroupByExpressions = new ArrayBuffer[Alias]()
-
- // The pair of (the original GROUP BY key, associated attribute)
- val groupByExprPairs = x.groupByExprs.map(_ match {
- case e: NamedExpression => (e, e.toAttribute)
- case other => {
- val alias = Alias(other, other.toString)()
- nonAttributeGroupByExpressions += alias // add the non-attributes expression alias
- (other, alias.toAttribute)
- }
- })
-
- // substitute the non-attribute expressions for aggregations.
- val aggregation = x.aggregations.map(expr => expr.transformDown {
- case e => groupByExprPairs.find(_._1.semanticEquals(e)).map(_._2).getOrElse(e)
- }.asInstanceOf[NamedExpression])
- // substitute the group by expressions.
- val newGroupByExprs = groupByExprPairs.map(_._2)
+ // Expand works by setting grouping expressions to null as determined by the bitmasks. To
+ // prevent these null values from being used in an aggregate instead of the original value
+ // we need to create new aliases for all group by expressions that will only be used for
+ // the intended purpose.
+ val groupByAliases: Seq[Alias] = x.groupByExprs.map {
+ case e: NamedExpression => Alias(e, e.name)()
+ case other => Alias(other, other.toString)()
+ }
- val child = if (nonAttributeGroupByExpressions.length > 0) {
- // insert additional projection if contains the
- // non-attribute expressions in the GROUP BY keys
- Project(x.child.output ++ nonAttributeGroupByExpressions, x.child)
- } else {
- x.child
+ val aggregations: Seq[NamedExpression] = x.aggregations.map {
+ // If an expression is an aggregate (contains a AggregateExpression) then we dont change
+ // it so that the aggregation is computed on the unmodified value of its argument
+ // expressions.
+ case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr
+ // If not then its a grouping expression and we need to use the modified (with nulls from
+ // Expand) value of the expression.
+ case expr => expr.transformDown {
+ case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e)
+ }.asInstanceOf[NamedExpression]
}
+ val child = Project(x.child.output ++ groupByAliases, x.child)
+ val groupByAttributes = groupByAliases.map(_.toAttribute)
+
Aggregate(
- newGroupByExprs :+ VirtualColumn.groupingIdAttribute,
- aggregation,
- Expand(x.bitmasks, newGroupByExprs, gid, child))
+ groupByAttributes :+ VirtualColumn.groupingIdAttribute,
+ aggregations,
+ Expand(x.bitmasks, groupByAttributes, gid, child))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 45630a591d..0c444482c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -323,6 +323,10 @@ trait GroupingAnalytics extends UnaryNode {
override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
+ // Needs to be unresolved before its translated to Aggregate + Expand because output attributes
+ // will change in analysis.
+ override lazy val resolved: Boolean = false
+
def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 71adf2148a..9c42f65bb6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -60,6 +60,68 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("rollup") {
+ checkAnswer(
+ courseSales.rollup("course", "year").sum("earnings"),
+ Row("Java", 2012, 20000.0) ::
+ Row("Java", 2013, 30000.0) ::
+ Row("Java", null, 50000.0) ::
+ Row("dotNET", 2012, 15000.0) ::
+ Row("dotNET", 2013, 48000.0) ::
+ Row("dotNET", null, 63000.0) ::
+ Row(null, null, 113000.0) :: Nil
+ )
+ }
+
+ test("cube") {
+ checkAnswer(
+ courseSales.cube("course", "year").sum("earnings"),
+ Row("Java", 2012, 20000.0) ::
+ Row("Java", 2013, 30000.0) ::
+ Row("Java", null, 50000.0) ::
+ Row("dotNET", 2012, 15000.0) ::
+ Row("dotNET", 2013, 48000.0) ::
+ Row("dotNET", null, 63000.0) ::
+ Row(null, 2012, 35000.0) ::
+ Row(null, 2013, 78000.0) ::
+ Row(null, null, 113000.0) :: Nil
+ )
+ }
+
+ test("rollup overlapping columns") {
+ checkAnswer(
+ testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),
+ Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
+ :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
+ :: Row(null, null, 3) :: Nil
+ )
+
+ checkAnswer(
+ testData2.rollup("a", "b").agg(sum("b")),
+ Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
+ :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
+ :: Row(null, null, 9) :: Nil
+ )
+ }
+
+ test("cube overlapping columns") {
+ checkAnswer(
+ testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+ Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1)
+ :: Row(2, null, 0) :: Row(3, null, 0) :: Row(4, null, 2) :: Row(5, null, 1)
+ :: Row(null, 1, 3) :: Row(null, 2, 0)
+ :: Row(null, null, 3) :: Nil
+ )
+
+ checkAnswer(
+ testData2.cube("a", "b").agg(sum("b")),
+ Row(1, 1, 1) :: Row(1, 2, 2) :: Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 1) :: Row(3, 2, 2)
+ :: Row(1, null, 3) :: Row(2, null, 3) :: Row(3, null, 3)
+ :: Row(null, 1, 3) :: Row(null, 2, 6)
+ :: Row(null, null, 9) :: Nil
+ )
+ }
+
test("spark.sql.retainGroupColumns config") {
checkAnswer(
testData2.groupBy("a").agg(sum($"b")),