diff options
author | Daoyuan Wang <daoyuan.wang@intel.com> | 2016-02-02 11:09:40 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2016-02-02 11:09:40 -0800 |
commit | 358300c795025735c3b2f96c5447b1b227d4abc1 (patch) | |
tree | 834cca25a3ffb37cdc0a64ee8292c4e52d1f2fe2 /sql | |
parent | cba1d6b659288bfcd8db83a6d778155bab2bbecf (diff) | |
download | spark-358300c795025735c3b2f96c5447b1b227d4abc1.tar.gz spark-358300c795025735c3b2f96c5447b1b227d4abc1.tar.bz2 spark-358300c795025735c3b2f96c5447b1b227d4abc1.zip |
[SPARK-13056][SQL] map column would throw NPE if value is null
Jira:
https://issues.apache.org/jira/browse/SPARK-13056
Create a map like
{ "a": "somestring", "b": null}
Query like
SELECT col["b"] FROM t1;
NPE would be thrown.
Author: Daoyuan Wang <daoyuan.wang@intel.com>
Closes #10964 from adrian-wang/npewriter.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala | 15 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 10 |
2 files changed, 19 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 5256baaf43..9f2f82d68c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -218,7 +218,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) protected override def nullSafeEval(value: Any, ordinal: Any): Any = { val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() - if (index >= baseValue.numElements() || index < 0) { + if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) { null } else { baseValue.get(index, dataType) @@ -267,6 +267,7 @@ case class GetMapValue(child: Expression, key: Expression) val map = value.asInstanceOf[MapData] val length = map.numElements() val keys = map.keyArray() + val values = map.valueArray() var i = 0 var found = false @@ -278,10 +279,10 @@ case class GetMapValue(child: Expression, key: Expression) } } - if (!found) { + if (!found || values.isNullAt(i)) { null } else { - map.valueArray().get(i, dataType) + values.get(i, dataType) } } @@ -291,10 +292,12 @@ case class GetMapValue(child: Expression, key: Expression) val keys = ctx.freshName("keys") val found = ctx.freshName("found") val key = ctx.freshName("key") + val values = ctx.freshName("values") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int $length = $eval1.numElements(); final ArrayData $keys = $eval1.keyArray(); + final ArrayData $values = $eval1.valueArray(); int $index = 0; boolean $found = false; @@ -307,10 +310,10 @@ case class GetMapValue(child: Expression, key: Expression) } } - if ($found) { - ${ev.value} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)}; - } else { + if (!$found || $values.isNullAt($index)) { ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(values, dataType, index)}; } """ }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 2b821c1056..79bfd4b44b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2055,6 +2055,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-13056: Null in map value causes NPE") { + val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value") + withTempTable("maptest") { + df.registerTempTable("maptest") + // local optimization will by pass codegen code, so we should keep the filter `key=1` + checkAnswer(sql("SELECT value['abc'] FROM maptest where key = 1"), Row("somestring")) + checkAnswer(sql("SELECT value['cba'] FROM maptest where key = 1"), Row(null)) + } + } + test("hash function") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") withTempTable("tbl") { |