diff options
author | Dongjoon Hyun <dongjoon@apache.org> | 2016-08-21 22:07:47 +0200 |
---|---|---|
committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-08-21 22:07:47 +0200 |
commit | 91c2397684ab791572ac57ffb2a924ff058bb64f (patch) | |
tree | 83830eab6905a4ae73361fc1d5054dea25bd7964 | |
parent | ab7143463daf2056736c85e3a943c826b5992623 (diff) | |
download | spark-91c2397684ab791572ac57ffb2a924ff058bb64f.tar.gz spark-91c2397684ab791572ac57ffb2a924ff058bb64f.tar.bz2 spark-91c2397684ab791572ac57ffb2a924ff058bb64f.zip |
[SPARK-17098][SQL] Fix `NullPropagation` optimizer to handle `COUNT(NULL) OVER` correctly
## What changes were proposed in this pull request?
Currently, `NullPropagation` optimizer replaces `COUNT` on null literals in a bottom-up fashion. During that, `WindowExpression` is not covered properly. This PR adds the missing propagation logic.
**Before**
```scala
scala> sql("SELECT COUNT(1 + NULL) OVER ()").show
java.lang.UnsupportedOperationException: Cannot evaluate expression: cast(0 as bigint) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
```
**After**
```scala
scala> sql("SELECT COUNT(1 + NULL) OVER ()").show
+----------------------------------------------------------------------------------------------+
|count((1 + CAST(NULL AS INT))) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)|
+----------------------------------------------------------------------------------------------+
| 0|
+----------------------------------------------------------------------------------------------+
```
## How was this patch tested?
Pass the Jenkins test with a new test case.
Author: Dongjoon Hyun <dongjoon@apache.org>
Closes #14689 from dongjoon-hyun/SPARK-17098.
3 files changed, 49 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ce57f05868..9a0ff8a9b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -619,6 +619,8 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { + case e @ WindowExpression(Cast(Literal(0L, _), _), _) => + Cast(Literal(0L), e.dataType) case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) diff --git a/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql b/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql new file mode 100644 index 0000000000..66549da797 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/null-propagation.sql @@ -0,0 +1,9 @@ + +-- count(null) should be 0 +SELECT COUNT(NULL) FROM VALUES 1, 2, 3; +SELECT COUNT(1 + NULL) FROM VALUES 1, 2, 3; + +-- count(null) on window should be 0 +SELECT COUNT(NULL) OVER () FROM VALUES 1, 2, 3; +SELECT COUNT(1 + NULL) OVER () FROM VALUES 1, 2, 3; + diff --git a/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out b/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out new file mode 100644 index 0000000000..ed3a651aa6 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/null-propagation.sql.out @@ -0,0 +1,38 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 4 + + +-- !query 0 +SELECT COUNT(NULL) FROM VALUES 1, 2, 3 +-- !query 0 schema +struct<count(NULL):bigint> +-- !query 0 output +0 + + +-- !query 1 +SELECT COUNT(1 + NULL) FROM VALUES 1, 2, 3 +-- !query 1 schema +struct<count((1 + CAST(NULL AS INT))):bigint> +-- !query 1 output +0 + + +-- !query 2 +SELECT COUNT(NULL) OVER () FROM VALUES 1, 2, 3 +-- !query 2 schema +struct<count(NULL) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):bigint> +-- !query 2 output +0 +0 +0 + + +-- !query 3 +SELECT COUNT(1 + NULL) OVER () FROM VALUES 1, 2, 3 +-- !query 3 schema +struct<count((1 + CAST(NULL AS INT))) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):bigint> +-- !query 3 output +0 +0 +0 |