author    Herman van Hovell <hvanhovell@databricks.com>  2016-09-28 16:25:10 -0700
committer Reynold Xin <rxin@databricks.com>  2016-09-28 16:25:10 -0700
commit    7d09232028967978d9db314ec041a762599f636b (patch)
tree      464184d18818f790ee486b69cf580461dfe97dc8
parent    557d6e32272dee4eaa0f426cc3e2f82ea361c3da (diff)
[SPARK-17641][SQL] Collect_list/Collect_set should not collect null values.
## What changes were proposed in this pull request?

We added native versions of `collect_set` and `collect_list` in Spark 2.0. These currently also (try to) collect null values, which differs from the original Hive implementation. This PR fixes this by adding a null check to the `Collect.update` method.

## How was this patch tested?

Added a regression test to `DataFrameAggregateSuite`.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #15208 from hvanhovell/SPARK-17641.
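For illustration, a minimal sketch of the patched behavior (assuming a local Spark 2.x build that includes this change; the session bootstrap and inline result comments are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, collect_set}

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("collect-null-demo")
  .getOrCreate()
import spark.implicits._

// Column "a" contains a null; with this patch it is skipped rather than collected.
val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")

df.select(collect_list($"a")).show()  // [1, 1]    -- the null row is dropped
df.select(collect_set($"a")).show()   // [1]
df.select(collect_list($"b")).show()  // [2, 2, 4] -- non-null column unaffected
```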
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala |  7 ++++++-
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala                    | 12 ++++++++++++
2 files changed, 18 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index 896ff61b23..78a388d206 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -65,7 +65,12 @@ abstract class Collect extends ImperativeAggregate {
   }
 
   override def update(b: MutableRow, input: InternalRow): Unit = {
-    buffer += child.eval(input)
+    // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
+    // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
+    val value = child.eval(input)
+    if (value != null) {
+      buffer += value
+    }
   }
 
   override def merge(buffer: MutableRow, input: InternalRow): Unit = {
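Outside of Spark, the control flow this hunk adds boils down to a null guard before buffering. The snippet below is a self-contained sketch; `NullSkippingCollector` is a hypothetical stand-in for `Collect`, not an actual Spark class:

```scala
import scala.collection.mutable

// Hypothetical stand-in for Collect: buffer per-row values, dropping nulls,
// mirroring Hive's GenericUDAFMkCollectionEvaluator semantics.
class NullSkippingCollector[T >: Null] {
  private val buffer = mutable.ArrayBuffer.empty[T]

  // Mirrors the patched Collect.update: only non-null values are collected.
  def update(value: T): Unit = {
    if (value != null) {
      buffer += value
    }
  }

  def result: Seq[T] = buffer.toSeq
}

// Usage: the null never makes it into the collected sequence.
val collector = new NullSkippingCollector[String]
Seq("1", null, "1").foreach(collector.update)
assert(collector.result == Seq("1", "1"))
```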
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 0e172bee4f..7aa4f0026f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -477,6 +477,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     assert(error.message.contains("collect_set() cannot have map type data"))
   }
 
+  test("SPARK-17641: collect functions should not collect null values") {
+    val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b")
+    checkAnswer(
+      df.select(collect_list($"a"), collect_list($"b")),
+      Seq(Row(Seq("1", "1"), Seq(2, 2, 4)))
+    )
+    checkAnswer(
+      df.select(collect_set($"a"), collect_set($"b")),
+      Seq(Row(Seq("1"), Seq(2, 4)))
+    )
+  }
+
   test("SPARK-14664: Decimal sum/avg over window should work.") {
     checkAnswer(
       spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),