diff options
author | Andrew Ray <ray.andrew@gmail.com> | 2017-03-17 16:43:42 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2017-03-17 16:43:42 +0800 |
commit | 13538cf3dd089222c7e12a3cd6e72ac836fa51ac (patch) | |
tree | 0e3b4053cf75d2feeb820c0660793dd5aeb74325 | |
parent | 8537c00e0a17eff2a8c6745fbdd1d08873c0434d (diff) | |
download | spark-13538cf3dd089222c7e12a3cd6e72ac836fa51ac.tar.gz spark-13538cf3dd089222c7e12a3cd6e72ac836fa51ac.tar.bz2 spark-13538cf3dd089222c7e12a3cd6e72ac836fa51ac.zip |
[SPARK-19882][SQL] Pivot with null as a distinct pivot value throws NPE
## What changes were proposed in this pull request?
Allows null values of the pivot column to be included in the pivot values list without throwing an NPE.
Note: this PR was made as an alternative to #17224, but it preserves the two-phase aggregate operation that is needed for good performance.
## How was this patch tested?
Additional unit tests in DataFramePivotSuite.
Author: Andrew Ray <ray.andrew@gmail.com>
Closes #17226 from aray/pivot-null.
3 files changed, 24 insertions, 10 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 68a4746a54..8cf4073826 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -524,7 +524,7 @@ class Analyzer( } else { val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => def ifExpr(expr: Expression) = { - If(EqualTo(pivotColumn, value), expr, Literal(null)) + If(EqualNullSafe(pivotColumn, value), expr, Literal(null)) } aggregates.map { aggregate => val filteredAggregate = aggregate.transformDown { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 9ad31243e4..5237148692 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -91,14 +91,12 @@ case class PivotFirst( override def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit = { val pivotColValue = pivotColumn.eval(inputRow) - if (pivotColValue != null) { - // We ignore rows whose pivot column value is not in the list of pivot column values. - val index = pivotIndex.getOrElse(pivotColValue, -1) - if (index >= 0) { - val value = valueColumn.eval(inputRow) - if (value != null) { - updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value) - } + // We ignore rows whose pivot column value is not in the list of pivot column values. 
+ val index = pivotIndex.getOrElse(pivotColValue, -1) + if (index >= 0) { + val value = valueColumn.eval(inputRow) + if (value != null) { + updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value) } } } @@ -140,7 +138,9 @@ case class PivotFirst( override val aggBufferAttributes: Seq[AttributeReference] = - pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)()) + pivotIndex.toList.sortBy(_._2).map { kv => + AttributeReference(Option(kv._1).getOrElse("null").toString, valueDataType)() + } override val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index 51ffe34172..ca3cb56767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -216,4 +216,18 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{ Row("d", 15000.0, 48000.0) :: Row("J", 20000.0, 30000.0) :: Nil ) } + + test("pivot with null should not throw NPE") { + checkAnswer( + Seq(Tuple1(None), Tuple1(Some(1))).toDF("a").groupBy($"a").pivot("a").count(), + Row(null, 1, null) :: Row(1, null, 1) :: Nil) + } + + test("pivot with null and aggregate type not supported by PivotFirst returns correct result") { + checkAnswer( + Seq(Tuple1(None), Tuple1(Some(1))).toDF("a") + .withColumn("b", expr("array(a, 7)")) + .groupBy($"a").pivot("a").agg(min($"b")), + Row(null, Seq(null, 7), null) :: Row(1, null, Seq(1, 7)) :: Nil) + } } |