From cbdcd4edab48593f6331bc267eb94e40908733e5 Mon Sep 17 00:00:00 2001
From: Sameer Agarwal
Date: Sun, 24 Apr 2016 22:52:50 -0700
Subject: [SPARK-14870] [SQL] Fix NPE in TPCDS q14a

## What changes were proposed in this pull request?

This PR fixes a bug in `TungstenAggregate` that manifests when aggregating by keys over nullable `BigDecimal` columns. The bug causes a null pointer exception while executing TPCDS q14a.
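To make the failure mode concrete, here is a minimal repro sketch (not taken from q14a; the data and column names are hypothetical, it assumes a `SQLContext` in scope as `sqlContext`, and it assumes the vectorized aggregate hash map path is enabled). Any aggregation over nullable decimal columns exercises the affected code path:

```scala
import org.apache.spark.sql.functions.sum

// Hypothetical repro: a DataFrame with nullable DecimalType columns.
val df = sqlContext.range(0, 4).selectExpr(
  "cast(if(id % 2 = 0, null, id) as decimal(10, 2)) as key",
  "cast(if(id % 3 = 0, null, id) as decimal(10, 2)) as value")

// Before this patch, the code generated for the vectorized aggregate
// buffer never recorded null decimals as null, which could surface as a
// NullPointerException during the aggregation.
df.groupBy("key").agg(sum("value")).collect()
```

Concretely, `updateColumn` special-cases `DecimalType` because calling `setNullAt` on an `UnsafeRow` would lose the decimal's pre-allocated offset; the vectorized `ColumnarBatch.Row` buffer has no such constraint, so it should take the generic branch instead. A sketch of the two Java snippets the method emits for a `decimal(10, 2)` column, with simplified, hypothetical variable names:

```scala
// UnsafeRow decimal path (isVectorized = false): skip the write when the
// value is null and never call setNullAt, preserving the decimal's offset.
val unsafeRowDecimalUpdate =
  """
    |if (!isNull1) {
    |  unsafeRowBuffer.setDecimal(0, value1, 10);
    |}
  """.stripMargin

// Generic path, now also taken for ColumnarBatch.Row (isVectorized = true):
// nulls are recorded explicitly via setNullAt, which fixes the NPE.
val vectorizedDecimalUpdate =
  """
    |if (!isNull1) {
    |  vectorizedRowBuffer.setDecimal(0, value1, 10);
    |} else {
    |  vectorizedRowBuffer.setNullAt(0);
    |}
  """.stripMargin
```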
## How was this patch tested?

1. Added regression test in `DataFrameAggregateSuite`.
2. Verified that TPCDS q14a works.

Author: Sameer Agarwal

Closes #12651 from sameeragarwal/tpcds-fix.
---
 .../sql/catalyst/expressions/codegen/CodeGenerator.scala   |  7 +++++--
 .../spark/sql/execution/aggregate/TungstenAggregate.scala  |  3 ++-
 .../org/apache/spark/sql/DataFrameAggregateSuite.scala     | 15 +++++++++++++++
 .../scala/org/apache/spark/sql/test/SQLTestData.scala      | 15 +++++++++++++++
 4 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index fa09f821fc..e4fa429b37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -239,16 +239,19 @@ class CodegenContext {

   /**
    * Update a column in MutableRow from ExprCode.
+   *
+   * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
    */
   def updateColumn(
       row: String,
       dataType: DataType,
       ordinal: Int,
       ev: ExprCode,
-      nullable: Boolean): String = {
+      nullable: Boolean,
+      isVectorized: Boolean = false): String = {
     if (nullable) {
       // Can't call setNullAt on DecimalType, because we need to keep the offset
-      if (dataType.isInstanceOf[DecimalType]) {
+      if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
         s"""
           if (!${ev.isNull}) {
             ${setColumn(row, dataType, ordinal, ev.value)};
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 782da0ea60..49db75e141 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -633,7 +633,8 @@ case class TungstenAggregate(
       updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
     val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
       val dt = updateExpr(i).dataType
-      ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable)
+      ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable,
+        isVectorized = true)
     }
     Option(
       s"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 7d96ef6fe0..0fcfb97d2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -61,6 +61,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       df1.groupBy("key").min("value2"),
       Seq(Row("a", 0), Row("b", 4))
     )
+
+    checkAnswer(
+      decimalData.groupBy("a").agg(sum("b")),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(3.0)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(3.0)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)))
+    )
+
+    checkAnswer(
+      decimalDataWithNulls.groupBy("a").agg(sum("b")),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.0)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.0)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)),
+        Row(null, new java.math.BigDecimal(2.0)))
+    )
   }

   test("rollup") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 7fa6760b71..c5f25fa1df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -103,6 +103,19 @@ private[sql] trait SQLTestData { self =>
     df
   }

+  protected lazy val decimalDataWithNulls: DataFrame = {
+    val df = sqlContext.sparkContext.parallelize(
+      DecimalDataWithNulls(1, 1) ::
+      DecimalDataWithNulls(1, null) ::
+      DecimalDataWithNulls(2, 1) ::
+      DecimalDataWithNulls(2, null) ::
+      DecimalDataWithNulls(3, 1) ::
+      DecimalDataWithNulls(3, 2) ::
+      DecimalDataWithNulls(null, 2) :: Nil).toDF()
+    df.registerTempTable("decimalDataWithNulls")
+    df
+  }
+
   protected lazy val binaryData: DataFrame = {
     val df = sqlContext.sparkContext.parallelize(
       BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) ::
@@ -267,6 +280,7 @@ private[sql] trait SQLTestData { self =>
     negativeData
     largeAndSmallInts
     decimalData
+    decimalDataWithNulls
     binaryData
     upperCaseData
     lowerCaseData
@@ -296,6 +310,7 @@ private[sql] object SQLTestData {
   case class TestData3(a: Int, b: Option[Int])
   case class LargeAndSmallInts(a: Int, b: Int)
   case class DecimalData(a: BigDecimal, b: BigDecimal)
+  case class DecimalDataWithNulls(a: BigDecimal, b: BigDecimal)
   case class BinaryData(a: Array[Byte], b: Int)
   case class UpperCaseData(N: Int, L: String)
   case class LowerCaseData(n: Int, l: String)
--
cgit v1.2.3