From c87531b765f8934a9a6c0f673617e0abfa5e5f0e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 1 Dec 2015 07:42:37 -0800 Subject: [SPARK-11949][SQL] Set field nullable property for GroupingSets to get correct results for null values JIRA: https://issues.apache.org/jira/browse/SPARK-11949 The result of cube plan uses incorrect schema. The schema of cube result should set nullable property to true because the grouping expressions will have null values. Author: Liang-Chi Hsieh Closes #10038 from viirya/fix-cube. --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 10 ++++++++-- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 94ffbbb2e5..b8f212fca7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -223,6 +223,11 @@ class Analyzer( case other => Alias(other, other.toString)() } + // TODO: We need to use bitmasks to determine which grouping expressions need to be + // set as nullable. For example, if we have GROUPING SETS ((a,b), a), we do not need + // to change the nullability of a. + val attributeMap = groupByAliases.map(a => (a -> a.toAttribute.withNullability(true))).toMap + val aggregations: Seq[NamedExpression] = x.aggregations.map { // If an expression is an aggregate (contains a AggregateExpression) then we dont change // it so that the aggregation is computed on the unmodified value of its argument @@ -231,12 +236,13 @@ class Analyzer( // If not then its a grouping expression and we need to use the modified (with nulls from // Expand) value of the expression. case expr => expr.transformDown { - case e => groupByAliases.find(_.child.semanticEquals(e)).map(_.toAttribute).getOrElse(e) + case e => + groupByAliases.find(_.child.semanticEquals(e)).map(attributeMap(_)).getOrElse(e) }.asInstanceOf[NamedExpression] } val child = Project(x.child.output ++ groupByAliases, x.child) - val groupByAttributes = groupByAliases.map(_.toAttribute) + val groupByAttributes = groupByAliases.map(attributeMap(_)) Aggregate( groupByAttributes :+ VirtualColumn.groupingIdAttribute, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b5c636d0de..b1004bc5bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.DecimalType +case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -86,6 +87,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, 2013, 78000.0) :: Row(null, null, 113000.0) :: Nil ) + + val df0 = sqlContext.sparkContext.parallelize(Seq( + Fact(20151123, 18, 35, "room1", 18.6), + Fact(20151123, 18, 35, "room2", 22.4), + Fact(20151123, 18, 36, "room1", 17.4), + Fact(20151123, 18, 36, "room2", 25.6))).toDF() + + val cube0 = df0.cube("date", "hour", "minute", "room_name").agg(Map("temp" -> "avg")) + assert(cube0.where("date IS NULL").count > 0) } test("rollup overlapping columns") { -- cgit v1.2.3