author | Herman van Hovell <hvanhovell@databricks.com> | 2016-09-28 16:25:10 -0700
---|---|---
committer | Reynold Xin <rxin@databricks.com> | 2016-09-28 16:25:10 -0700
commit | 7d09232028967978d9db314ec041a762599f636b |
tree | 464184d18818f790ee486b69cf580461dfe97dc8 /sql/core/src/test |
parent | 557d6e32272dee4eaa0f426cc3e2f82ea361c3da |
# [SPARK-17641][SQL] Collect_list/Collect_set should not collect null values
## What changes were proposed in this pull request?
We added native versions of `collect_set` and `collect_list` in Spark 2.0. These currently also (try to) collect null values, which differs from the original Hive implementation. This PR fixes that by adding a null check to the `Collect.update` method.
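The fix can be illustrated with a minimal sketch in plain Scala. This is not the actual Spark source: `CollectSketch`, `collectList`, and `collectSet` are hypothetical stand-ins that model the aggregation buffer behavior, with the null check that `Collect.update` gains.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in modeling the SPARK-17641 fix: null inputs are
// skipped before they reach the aggregation buffer.
object CollectSketch {
  def collectList(values: Seq[String]): Seq[String] = {
    val buffer = ArrayBuffer.empty[String]
    values.foreach { v =>
      if (v != null) {   // the added null check in Collect.update
        buffer += v
      }
    }
    buffer.toSeq
  }

  // collect_set keeps the same null handling, plus deduplication.
  def collectSet(values: Seq[String]): Seq[String] =
    collectList(values).distinct
}
```

With this check, `collectList(Seq("1", null, "1"))` yields `Seq("1", "1")` rather than including the null, matching the Hive semantics the PR restores.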
## How was this patch tested?
Added a regression test to `DataFrameAggregateSuite`.
Author: Herman van Hovell <hvanhovell@databricks.com>
Closes #15208 from hvanhovell/SPARK-17641.
Diffstat (limited to 'sql/core/src/test')
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 12 |
1 file changed, 12 insertions, 0 deletions
```diff
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 0e172bee4f..7aa4f0026f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -477,6 +477,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     assert(error.message.contains("collect_set() cannot have map type data"))
   }
 
+  test("SPARK-17641: collect functions should not collect null values") {
+    val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq("1", "1"), Seq(2, 2, 4)))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq("1"), Seq(2, 4)))
+    )
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
```