aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Ray <ray.andrew@gmail.com>2017-03-17 16:43:42 +0800
committerWenchen Fan <wenchen@databricks.com>2017-03-17 16:43:42 +0800
commit13538cf3dd089222c7e12a3cd6e72ac836fa51ac (patch)
tree0e3b4053cf75d2feeb820c0660793dd5aeb74325
parent8537c00e0a17eff2a8c6745fbdd1d08873c0434d (diff)
downloadspark-13538cf3dd089222c7e12a3cd6e72ac836fa51ac.tar.gz
spark-13538cf3dd089222c7e12a3cd6e72ac836fa51ac.tar.bz2
spark-13538cf3dd089222c7e12a3cd6e72ac836fa51ac.zip
[SPARK-19882][SQL] Pivot with null as a distinct pivot value throws NPE
## What changes were proposed in this pull request? Allows null values of the pivot column to be included in the pivot values list without throwing NPE Note this PR was made as an alternative to #17224 but preserves the two phase aggregate operation that is needed for good performance. ## How was this patch tested? Additional unit test Author: Andrew Ray <ray.andrew@gmail.com> Closes #17226 from aray/pivot-null.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala18
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala14
3 files changed, 24 insertions, 10 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 68a4746a54..8cf4073826 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -524,7 +524,7 @@ class Analyzer(
} else {
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
def ifExpr(expr: Expression) = {
- If(EqualTo(pivotColumn, value), expr, Literal(null))
+ If(EqualNullSafe(pivotColumn, value), expr, Literal(null))
}
aggregates.map { aggregate =>
val filteredAggregate = aggregate.transformDown {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
index 9ad31243e4..5237148692 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
@@ -91,14 +91,12 @@ case class PivotFirst(
override def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit = {
val pivotColValue = pivotColumn.eval(inputRow)
- if (pivotColValue != null) {
- // We ignore rows whose pivot column value is not in the list of pivot column values.
- val index = pivotIndex.getOrElse(pivotColValue, -1)
- if (index >= 0) {
- val value = valueColumn.eval(inputRow)
- if (value != null) {
- updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value)
- }
+ // We ignore rows whose pivot column value is not in the list of pivot column values.
+ val index = pivotIndex.getOrElse(pivotColValue, -1)
+ if (index >= 0) {
+ val value = valueColumn.eval(inputRow)
+ if (value != null) {
+ updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value)
}
}
}
@@ -140,7 +138,9 @@ case class PivotFirst(
override val aggBufferAttributes: Seq[AttributeReference] =
- pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)())
+ pivotIndex.toList.sortBy(_._2).map { kv =>
+ AttributeReference(Option(kv._1).getOrElse("null").toString, valueDataType)()
+ }
override val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 51ffe34172..ca3cb56767 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -216,4 +216,18 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
Row("d", 15000.0, 48000.0) :: Row("J", 20000.0, 30000.0) :: Nil
)
}
+
+ test("pivot with null should not throw NPE") {
+ checkAnswer(
+ Seq(Tuple1(None), Tuple1(Some(1))).toDF("a").groupBy($"a").pivot("a").count(),
+ Row(null, 1, null) :: Row(1, null, 1) :: Nil)
+ }
+
+ test("pivot with null and aggregate type not supported by PivotFirst returns correct result") {
+ checkAnswer(
+ Seq(Tuple1(None), Tuple1(Some(1))).toDF("a")
+ .withColumn("b", expr("array(a, 7)"))
+ .groupBy($"a").pivot("a").agg(min($"b")),
+ Row(null, Seq(null, 7), null) :: Row(1, null, Seq(1, 7)) :: Nil)
+ }
}