aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHerman van Hovell <hvanhovell@databricks.com>2016-09-15 20:24:15 +0200
committerHerman van Hovell <hvanhovell@databricks.com>2016-09-15 20:24:15 +0200
commitd403562eb4b5b1d804909861d3e8b75d8f6323b9 (patch)
treed83e7ca7f9f3508a44c8fa0f613e9a33d156c5d0
parent5b8f7377d54f83b93ef2bfc2a01ca65fae6d3032 (diff)
downloadspark-d403562eb4b5b1d804909861d3e8b75d8f6323b9.tar.gz
spark-d403562eb4b5b1d804909861d3e8b75d8f6323b9.tar.bz2
spark-d403562eb4b5b1d804909861d3e8b75d8f6323b9.zip
[SPARK-17114][SQL] Fix aggregates grouped by literals with empty input
## What changes were proposed in this pull request? This PR fixes an issue with aggregates that have an empty input, and use a literals as their grouping keys. These aggregates are currently interpreted as aggregates **without** grouping keys, this triggers the ungrouped code path (which aways returns a single row). This PR fixes the `RemoveLiteralFromGroupExpressions` optimizer rule, which changes the semantics of the Aggregate by eliminating all literal grouping keys. ## How was this patch tested? Added tests to `SQLQueryTestSuite`. Author: Herman van Hovell <hvanhovell@databricks.com> Closes #15101 from hvanhovell/SPARK-17114-3.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala11
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala10
-rw-r--r--sql/core/src/test/resources/sql-tests/inputs/group-by.sql17
-rw-r--r--sql/core/src/test/resources/sql-tests/results/group-by.sql.out51
4 files changed, 86 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index d2f0c97989..0df16b7a56 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1098,9 +1098,16 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
*/
object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case a @ Aggregate(grouping, _, _) =>
+ case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
val newGrouping = grouping.filter(!_.foldable)
- a.copy(groupingExpressions = newGrouping)
+ if (newGrouping.nonEmpty) {
+ a.copy(groupingExpressions = newGrouping)
+ } else {
+ // All grouping expressions are literals. We should not drop them all, because this can
+ // change the return semantics when the input of the Aggregate is empty (SPARK-17114). We
+ // instead replace this by single, easy to hash/sort, literal expression.
+ a.copy(groupingExpressions = Seq(Literal(0, IntegerType)))
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
index 4c26c184b7..aecf59aee6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
class AggregateOptimizeSuite extends PlanTest {
- val conf = new SimpleCatalystConf(caseSensitiveAnalysis = false)
+ val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
val analyzer = new Analyzer(catalog, conf)
@@ -49,6 +49,14 @@ class AggregateOptimizeSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("do not remove all grouping expressions if they are all literals") {
+ val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
+ val optimized = Optimize.execute(analyzer.execute(query))
+ val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))
+
+ comparePlans(optimized, correctAnswer)
+ }
+
test("Remove aliased literals") {
val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
val optimized = Optimize.execute(analyzer.execute(query))
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
new file mode 100644
index 0000000000..6741703d9d
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -0,0 +1,17 @@
+-- Temporary data.
+create temporary view myview as values 128, 256 as v(int_col);
+
+-- group by should produce all input rows,
+select int_col, count(*) from myview group by int_col;
+
+-- group by should produce a single row.
+select 'foo', count(*) from myview group by 1;
+
+-- group-by should not produce any rows (whole stage code generation).
+select 'foo' from myview where int_col == 0 group by 1;
+
+-- group-by should not produce any rows (hash aggregate).
+select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1;
+
+-- group-by should not produce any rows (sort aggregate).
+select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
new file mode 100644
index 0000000000..9127bd4dd4
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -0,0 +1,51 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 6
+
+
+-- !query 0
+create temporary view myview as values 128, 256 as v(int_col)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+select int_col, count(*) from myview group by int_col
+-- !query 1 schema
+struct<int_col:int,count(1):bigint>
+-- !query 1 output
+128 1
+256 1
+
+
+-- !query 2
+select 'foo', count(*) from myview group by 1
+-- !query 2 schema
+struct<foo:string,count(1):bigint>
+-- !query 2 output
+foo 2
+
+
+-- !query 3
+select 'foo' from myview where int_col == 0 group by 1
+-- !query 3 schema
+struct<foo:string>
+-- !query 3 output
+
+
+
+-- !query 4
+select 'foo', approx_count_distinct(int_col) from myview where int_col == 0 group by 1
+-- !query 4 schema
+struct<foo:string,approx_count_distinct(int_col):bigint>
+-- !query 4 output
+
+
+
+-- !query 5
+select 'foo', max(struct(int_col)) from myview where int_col == 0 group by 1
+-- !query 5 schema
+struct<foo:string,max(struct(int_col)):struct<int_col:int>>
+-- !query 5 output
+