diff options
Diffstat (limited to 'sql')
15 files changed, 254 insertions, 52 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala index a42360d562..8099751900 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -186,8 +186,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C * * The bitmask denotes the grouping expressions validity for a grouping set, * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of - * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. + * e.g. In superset (k1, k2, k3), (bit 2: k1, bit 1: k2, and bit 0: k3), the grouping id of + * GROUPING SETS (k1, k2) and (k2) should be 1 and 5 respectively. */ protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { val (keyASTs, setASTs) = children.partition { @@ -198,12 +198,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val keys = keyASTs.map(nodeToExpr) val keyMap = keyASTs.zipWithIndex.toMap + val mask = (1 << keys.length) - 1 val bitmasks: Seq[Int] = setASTs.map { case Token("TOK_GROUPING_SETS_EXPRESSION", columns) => - columns.foldLeft(0)((bitmap, col) => { - val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2) - bitmap | 1 << keyIndex.getOrElse( + columns.foldLeft(mask)((bitmap, col) => { + val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2).getOrElse( throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list")) + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. 
+ bitmap & ~(1 << (keys.length - 1 - keyIndex)) }) case _ => sys.error("Expect GROUPING SETS clause") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 62b241f052..c0fa79612a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -238,14 +238,39 @@ class Analyzer( } }.toMap - val aggregations: Seq[NamedExpression] = x.aggregations.map { - // If an expression is an aggregate (contains a AggregateExpression) then we dont change - // it so that the aggregation is computed on the unmodified value of its argument - // expressions. - case expr if expr.find(_.isInstanceOf[AggregateExpression]).nonEmpty => expr - // If not then its a grouping expression and we need to use the modified (with nulls from - // Expand) value of the expression. - case expr => expr.transformDown { + val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => + // collect all the found AggregateExpression, so we can check an expression is part of + // any AggregateExpression or not. + val aggsBuffer = ArrayBuffer[Expression]() + // Returns whether the expression belongs to any expressions in `aggsBuffer` or not. + def isPartOfAggregation(e: Expression): Boolean = { + aggsBuffer.exists(a => a.find(_ eq e).isDefined) + } + expr.transformDown { + // AggregateExpression should be computed on the unmodified value of its argument + // expressions, so we should not replace any references to grouping expression + // inside it. 
+ case e: AggregateExpression => + aggsBuffer += e + e + case e if isPartOfAggregation(e) => e + case e: GroupingID => + if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) { + gid + } else { + throw new AnalysisException( + s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + + s"grouping columns (${x.groupByExprs.mkString(",")})") + } + case Grouping(col: Expression) => + val idx = x.groupByExprs.indexOf(col) + if (idx >= 0) { + Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)), + Literal(1)), ByteType) + } else { + throw new AnalysisException(s"Column of grouping ($col) can't be found " + + s"in grouping columns ${x.groupByExprs.mkString(",")}") + } case e => groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e) }.asInstanceOf[NamedExpression] @@ -819,8 +844,11 @@ class Analyzer( } } + private def isAggregateExpression(e: Expression): Boolean = { + e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID] + } def containsAggregate(condition: Expression): Boolean = { - condition.find(_.isInstanceOf[AggregateExpression]).isDefined + condition.find(isAggregateExpression).isDefined } } @@ -1002,7 +1030,7 @@ class Analyzer( _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). 
- case wf : WindowFunction => + case wf: WindowFunction => val newChildren = wf.children.map(extractExpr) wf.withNewChildren(newChildren) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4a2f2b8bc6..fe053b9a0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -70,6 +70,11 @@ trait CheckAnalysis { failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + case g: Grouping => + failAnalysis(s"grouping() can only be used with GroupingSets/Cube/Rollup") + case g: GroupingID => + failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup") + case w @ WindowExpression(AggregateExpression(_, _, true), _) => failAnalysis(s"Distinct window functions are not supported: $w") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d9009e3848..1be97c7b81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -291,6 +291,8 @@ object FunctionRegistry { // grouping sets expression[Cube]("cube"), expression[Rollup]("rollup"), + expression[Grouping]("grouping"), + expression[GroupingID]("grouping_id"), // window functions expression[Lead]("lead"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 561fa3321d..f88a57a254 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -344,4 +344,3 @@ abstract class DeclarativeAggregate def right: AttributeReference = inputAggBufferAttributes(aggBufferAttributes.indexOf(a)) } } - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index 2997ee879d..a204060630 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -41,3 +41,26 @@ trait GroupingSet extends Expression with CodegenFallback { case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {} case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {} + +/** + * Indicates whether a specified column expression in a GROUP BY list is aggregated or not. + * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set. + */ +case class Grouping(child: Expression) extends Expression with Unevaluable { + override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) + override def children: Seq[Expression] = child :: Nil + override def dataType: DataType = ByteType + override def nullable: Boolean = false +} + +/** + * GroupingID is a function that computes the level of grouping. + * + * If groupByExprs is empty, it means all grouping expressions in GroupingSets. 
+ */ +case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Unevaluable { + override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) + override def children: Seq[Expression] = groupByExprs + override def dataType: DataType = IntegerType + override def nullable: Boolean = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 57575f9ee0..e8e0a78904 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -412,7 +412,7 @@ private[sql] object Expand { var bit = exprs.length - 1 while (bit >= 0) { - if (((bitmask >> bit) & 1) == 0) set += exprs(bit) + if (((bitmask >> bit) & 1) == 1) set += exprs(exprs.length - bit - 1) bit -= 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b970eee4e3..d34d377ab6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -396,6 +396,52 @@ object functions extends LegacyFunctions { */ def first(columnName: String): Column = first(Column(columnName)) + + /** + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping(e: Column): Column = Column(Grouping(e.expr)) + + /** + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + * or not, returns 1 for aggregated or 0 for not aggregated in the result set. 
+ * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping(columnName: String): Column = grouping(Column(columnName)) + + /** + * Aggregate function: returns the level of grouping, which equals + * + * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + * + * Note: the list of columns must match the grouping columns exactly, or be empty (meaning + * all the grouping columns). + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr))) + + /** + * Aggregate function: returns the level of grouping, which equals + * + * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + * + * Note: the list of columns must match the grouping columns exactly. + * + * @group agg_funcs + * @since 2.0.0 + */ + def grouping_id(colName: String, colNames: String*): Column = { + grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*) + } + /** * Aggregate function: returns the kurtosis of the values in a group.
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 08fb7c9d84..78bf6c1bce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType @@ -98,6 +99,49 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(cube0.where("date IS NULL").count > 0) } + test("grouping and grouping_id") { + checkAnswer( + courseSales.cube("course", "year") + .agg(grouping("course"), grouping("year"), grouping_id("course", "year")), + Row("Java", 2012, 0, 0, 0) :: + Row("Java", 2013, 0, 0, 0) :: + Row("Java", null, 0, 1, 1) :: + Row("dotNET", 2012, 0, 0, 0) :: + Row("dotNET", 2013, 0, 0, 0) :: + Row("dotNET", null, 0, 1, 1) :: + Row(null, 2012, 1, 0, 2) :: + Row(null, 2013, 1, 0, 2) :: + Row(null, null, 1, 1, 3) :: Nil + ) + + intercept[AnalysisException] { + courseSales.groupBy().agg(grouping("course")).explain() + } + intercept[AnalysisException] { + courseSales.groupBy().agg(grouping_id("course")).explain() + } + } + + test("grouping/grouping_id inside window function") { + + val w = Window.orderBy(sum("earnings")) + checkAnswer( + courseSales.cube("course", "year") + .agg(sum("earnings"), + grouping_id("course", "year"), + rank().over(Window.partitionBy(grouping_id("course", "year")).orderBy(sum("earnings")))), + Row("Java", 2012, 20000.0, 0, 2) :: + Row("Java", 2013, 30000.0, 0, 3) :: + Row("Java", null, 50000.0, 1, 1) :: + Row("dotNET", 2012, 15000.0, 0, 1) :: + Row("dotNET", 2013, 48000.0, 0, 4) :: + Row("dotNET", null, 63000.0, 1, 2) :: + Row(null, 2012, 35000.0, 2, 1) :: + Row(null, 2013, 78000.0, 2, 
2) :: + Row(null, null, 113000.0, 3, 1) :: Nil + ) + } + test("rollup overlapping columns") { checkAnswer( testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8ef7b61314..f665a1c87b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2055,6 +2055,56 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("grouping sets") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by course, year " + + "grouping sets(course, year)"), + Row("Java", null, 50000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: Nil + ) + + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by course, year " + + "grouping sets(course)"), + Row("Java", null, 50000.0) :: + Row("dotNET", null, 63000.0) :: Nil + ) + + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by course, year " + + "grouping sets(year)"), + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: Nil + ) + } + + test("grouping and grouping_id") { + checkAnswer( + sql("select course, year, grouping(course), grouping(year), grouping_id(course, year)" + + " from courseSales group by cube(course, year)"), + Row("Java", 2012, 0, 0, 0) :: + Row("Java", 2013, 0, 0, 0) :: + Row("Java", null, 0, 1, 1) :: + Row("dotNET", 2012, 0, 0, 0) :: + Row("dotNET", 2013, 0, 0, 0) :: + Row("dotNET", null, 0, 1, 1) :: + Row(null, 2012, 1, 0, 2) :: + Row(null, 2013, 1, 0, 2) :: + Row(null, null, 1, 1, 3) :: Nil + ) + + var error = intercept[AnalysisException] { + sql("select course, year, grouping(course) from courseSales group by course, year") + } + assert(error.getMessage contains "grouping() can only be used 
with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year, grouping_id(course, year) from courseSales group by course, year") + } + assert(error.getMessage contains "grouping_id() can only be used with GroupingSets/Cube/Rollup") + } + test("SPARK-13056: Null in map value causes NPE") { val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value") withTempTable("maptest") { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 61b73fa557..9097c1a1d3 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -328,6 +328,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Hive returns null rather than NaN when n = 1 "udaf_covar_samp", + // The implementation of GROUPING__ID in Hive is wrong (it does not match the documentation). + "groupby_grouping_id1", + "groupby_grouping_id2", + "groupby_grouping_sets1", + // Spark parser treats numerical literals differently: it creates decimals instead of doubles.
"udf_abs", "udf_format_number", @@ -503,9 +508,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby11", "groupby12", "groupby1_limit", - "groupby_grouping_id1", - "groupby_grouping_id2", - "groupby_grouping_sets1", "groupby_grouping_sets2", "groupby_grouping_sets3", "groupby_grouping_sets4", diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 index dac1b84b91..c066aeead8 100644 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for CUBE #1-0-63b61fb3f0e74226001ad279be440864 @@ -1,6 +1,6 @@ -500 NULL 0 -91 0 1 -84 1 1 -105 2 1 -113 3 1 -107 4 1 +500 NULL 1 +91 0 0 +84 1 0 +105 2 0 +113 3 0 +107 4 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a index dac1b84b91..c066aeead8 100644 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #1-0-a78e3dbf242f240249e36b3d3fd0926a @@ -1,6 +1,6 @@ -500 NULL 0 -91 0 1 -84 1 1 -105 2 1 -113 3 1 -107 4 1 +500 NULL 1 +91 0 0 +84 1 0 +105 2 0 +113 3 0 +107 4 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 index 1eea4a9b23..fcacbe3f69 100644 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 +++ b/sql/hive/src/test/resources/golden/SPARK-8976 
Wrong Result for Rollup #2-0-bf180c9d1a18f61b9d9f31bb0115cf89 @@ -1,10 +1,10 @@ -1 0 5 3 -1 0 15 3 -1 0 25 3 -1 0 60 3 -1 0 75 3 -1 0 80 3 -1 0 100 3 -1 0 140 3 -1 0 145 3 -1 0 150 3 +1 0 5 0 +1 0 15 0 +1 0 25 0 +1 0 60 0 +1 0 75 0 +1 0 80 0 +1 0 100 0 +1 0 140 0 +1 0 145 0 +1 0 150 0 diff --git a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce index 1eea4a9b23..fcacbe3f69 100644 --- a/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce +++ b/sql/hive/src/test/resources/golden/SPARK-8976 Wrong Result for Rollup #3-0-9257085d123728730be96b6d9fbb84ce @@ -1,10 +1,10 @@ -1 0 5 3 -1 0 15 3 -1 0 25 3 -1 0 60 3 -1 0 75 3 -1 0 80 3 -1 0 100 3 -1 0 140 3 -1 0 145 3 -1 0 150 3 +1 0 5 0 +1 0 15 0 +1 0 25 0 +1 0 60 0 +1 0 75 0 +1 0 80 0 +1 0 100 0 +1 0 140 0 +1 0 145 0 +1 0 150 0 |