author     Sameer Agarwal <sameer@databricks.com>   2016-04-24 22:52:50 -0700
committer  Davies Liu <davies.liu@gmail.com>        2016-04-24 22:52:50 -0700
commit     cbdcd4edab48593f6331bc267eb94e40908733e5
tree       400cba8511ec4273b73247d2ce9a9fc433e88cfa
parent     c752b6c5ec488b87c3aaaa86902dd4da9b4b406f
[SPARK-14870] [SQL] Fix NPE in TPCDS q14a
## What changes were proposed in this pull request?

This PR fixes a bug in `TungstenAggregate` that manifests while aggregating by keys over nullable `BigDecimal` columns. The bug causes a null pointer exception while executing TPCDS q14a.

## How was this patch tested?

1. Added a regression test in `DataFrameAggregateSuite`.
2. Verified that TPCDS q14a works.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #12651 from sameeragarwal/tpcds-fix.
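For context, a minimal sketch of the failing pattern, modeled on the regression data added below; the session setup and names here are illustrative, not part of the patch:

```scala
// Hypothetical spark-shell session from this era (a `sqlContext` is in scope).
// Grouping by a nullable BigDecimal key and summing a nullable BigDecimal
// column drives the aggregation through the code path fixed below and, before
// this patch, could throw a NullPointerException.
import org.apache.spark.sql.functions.sum
import sqlContext.implicits._

case class Rec(a: BigDecimal, b: BigDecimal)
val df = sqlContext.sparkContext.parallelize(
  Rec(1, 1) :: Rec(1, null) :: Rec(null, 2) :: Nil).toDF()
df.groupBy("a").agg(sum("b")).collect()  // NPE before this fix
```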
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala          |  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala                        | 15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala                               | 15
4 files changed, 37 insertions(+), 3 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index fa09f821fc..e4fa429b37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -239,16 +239,19 @@ class CodegenContext {
   /**
    * Update a column in MutableRow from ExprCode.
+   *
+   * @param isVectorized True if the underlying row is of type `ColumnarBatch.Row`, false otherwise
    */
   def updateColumn(
       row: String,
       dataType: DataType,
       ordinal: Int,
       ev: ExprCode,
-      nullable: Boolean): String = {
+      nullable: Boolean,
+      isVectorized: Boolean = false): String = {
     if (nullable) {
       // Can't call setNullAt on DecimalType, because we need to keep the offset
-      if (dataType.isInstanceOf[DecimalType]) {
+      if (!isVectorized && dataType.isInstanceOf[DecimalType]) {
         s"""
           if (!${ev.isNull}) {
             ${setColumn(row, dataType, ordinal, ev.value)};
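To make the branch above easier to read outside the diff, here is a simplified, standalone sketch of the decision `updateColumn` makes after this patch. `ExprCode` and `setColumn` below are stand-ins for the real `CodegenContext` members; only the branch structure is the point:

```scala
// Simplified sketch of updateColumn's logic after this patch (illustrative,
// not the exact source).
case class ExprCode(isNull: String, value: String)

def setColumn(row: String, ordinal: Int, value: String): String =
  s"$row.update($ordinal, $value)"  // placeholder for the type-specific setter

def updateColumn(
    row: String,
    isDecimalType: Boolean,
    ordinal: Int,
    ev: ExprCode,
    nullable: Boolean,
    isVectorized: Boolean = false): String = {
  if (nullable) {
    if (!isVectorized && isDecimalType) {
      // UnsafeRow path: never call setNullAt on a decimal slot, because the
      // variable-length value's offset must be kept.
      s"if (!${ev.isNull}) { ${setColumn(row, ordinal, ev.value)}; }"
    } else {
      // ColumnarBatch.Row (vectorized) and all non-decimal types: a null
      // input must actually mark the slot null; skipping that step is what
      // led to the NPE seen in TPCDS q14a.
      s"""if (!${ev.isNull}) { ${setColumn(row, ordinal, ev.value)}; }
         |else { $row.setNullAt($ordinal); }""".stripMargin
    }
  } else {
    s"${setColumn(row, ordinal, ev.value)};"
  }
}
```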
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 782da0ea60..49db75e141 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -633,7 +633,8 @@ case class TungstenAggregate(
       updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx))
     val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
       val dt = updateExpr(i).dataType
-      ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable)
+      ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable,
+        isVectorized = true)
     }
     Option(
       s"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 7d96ef6fe0..0fcfb97d2b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -61,6 +61,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       df1.groupBy("key").min("value2"),
       Seq(Row("a", 0), Row("b", 4))
     )
+
+    checkAnswer(
+      decimalData.groupBy("a").agg(sum("b")),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(3.0)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(3.0)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)))
+    )
+
+    checkAnswer(
+      decimalDataWithNulls.groupBy("a").agg(sum("b")),
+      Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.0)),
+        Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.0)),
+        Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(3.0)),
+        Row(null, new java.math.BigDecimal(2.0)))
+    )
   }
 
   test("rollup") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 7fa6760b71..c5f25fa1df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -103,6 +103,19 @@ private[sql] trait SQLTestData { self =>
     df
   }
 
+  protected lazy val decimalDataWithNulls: DataFrame = {
+    val df = sqlContext.sparkContext.parallelize(
+      DecimalDataWithNulls(1, 1) ::
+      DecimalDataWithNulls(1, null) ::
+      DecimalDataWithNulls(2, 1) ::
+      DecimalDataWithNulls(2, null) ::
+      DecimalDataWithNulls(3, 1) ::
+      DecimalDataWithNulls(3, 2) ::
+      DecimalDataWithNulls(null, 2) :: Nil).toDF()
+    df.registerTempTable("decimalDataWithNulls")
+    df
+  }
+
   protected lazy val binaryData: DataFrame = {
     val df = sqlContext.sparkContext.parallelize(
       BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) ::
@@ -267,6 +280,7 @@ private[sql] trait SQLTestData { self =>
     negativeData
     largeAndSmallInts
     decimalData
+    decimalDataWithNulls
     binaryData
     upperCaseData
     lowerCaseData
@@ -296,6 +310,7 @@ private[sql] object SQLTestData {
   case class TestData3(a: Int, b: Option[Int])
   case class LargeAndSmallInts(a: Int, b: Int)
   case class DecimalData(a: BigDecimal, b: BigDecimal)
+  case class DecimalDataWithNulls(a: BigDecimal, b: BigDecimal)
   case class BinaryData(a: Array[Byte], b: Int)
   case class UpperCaseData(N: Int, L: String)
   case class LowerCaseData(n: Int, l: String)
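A side note on the new fixture's types: `DecimalDataWithNulls(1, null)` type-checks because `scala.math.BigDecimal` is a reference type, so `null` is a legal argument, while the plain `Int` literals are lifted by the standard implicit conversion. A minimal standalone illustration:

```scala
// Standalone sketch (no Spark needed): Int arguments are converted via the
// implicit BigDecimal.int2bigDecimal, and null is allowed for any reference
// type such as scala.math.BigDecimal.
case class DecimalDataWithNulls(a: BigDecimal, b: BigDecimal)

val rows = DecimalDataWithNulls(1, 1) :: DecimalDataWithNulls(1, null) :: Nil
assert(rows(1).b == null)  // the null flows into the case class field unchanged
```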