aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDaoyuan Wang <daoyuan.wang@intel.com>2016-02-02 11:09:40 -0800
committerMichael Armbrust <michael@databricks.com>2016-02-02 11:09:40 -0800
commit358300c795025735c3b2f96c5447b1b227d4abc1 (patch)
tree834cca25a3ffb37cdc0a64ee8292c4e52d1f2fe2 /sql
parentcba1d6b659288bfcd8db83a6d778155bab2bbecf (diff)
downloadspark-358300c795025735c3b2f96c5447b1b227d4abc1.tar.gz
spark-358300c795025735c3b2f96c5447b1b227d4abc1.tar.bz2
spark-358300c795025735c3b2f96c5447b1b227d4abc1.zip
[SPARK-13056][SQL] map column would throw NPE if value is null
Jira: https://issues.apache.org/jira/browse/SPARK-13056 Create a map such as { "a": "somestring", "b": null }. A query like SELECT col["b"] FROM t1; would then throw an NPE. Author: Daoyuan Wang <daoyuan.wang@intel.com> Closes #10964 from adrian-wang/npewriter.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala15
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala10
2 files changed, 19 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 5256baaf43..9f2f82d68c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -218,7 +218,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
val baseValue = value.asInstanceOf[ArrayData]
val index = ordinal.asInstanceOf[Number].intValue()
- if (index >= baseValue.numElements() || index < 0) {
+ if (index >= baseValue.numElements() || index < 0 || baseValue.isNullAt(index)) {
null
} else {
baseValue.get(index, dataType)
@@ -267,6 +267,7 @@ case class GetMapValue(child: Expression, key: Expression)
val map = value.asInstanceOf[MapData]
val length = map.numElements()
val keys = map.keyArray()
+ val values = map.valueArray()
var i = 0
var found = false
@@ -278,10 +279,10 @@ case class GetMapValue(child: Expression, key: Expression)
}
}
- if (!found) {
+ if (!found || values.isNullAt(i)) {
null
} else {
- map.valueArray().get(i, dataType)
+ values.get(i, dataType)
}
}
@@ -291,10 +292,12 @@ case class GetMapValue(child: Expression, key: Expression)
val keys = ctx.freshName("keys")
val found = ctx.freshName("found")
val key = ctx.freshName("key")
+ val values = ctx.freshName("values")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
s"""
final int $length = $eval1.numElements();
final ArrayData $keys = $eval1.keyArray();
+ final ArrayData $values = $eval1.valueArray();
int $index = 0;
boolean $found = false;
@@ -307,10 +310,10 @@ case class GetMapValue(child: Expression, key: Expression)
}
}
- if ($found) {
- ${ev.value} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)};
- } else {
+ if (!$found || $values.isNullAt($index)) {
${ev.isNull} = true;
+ } else {
+ ${ev.value} = ${ctx.getValue(values, dataType, index)};
}
"""
})
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 2b821c1056..79bfd4b44b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2055,6 +2055,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
+ test("SPARK-13056: Null in map value causes NPE") {
+ val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value")
+ withTempTable("maptest") {
+ df.registerTempTable("maptest")
+ // local optimization will bypass the codegen path, so we should keep the filter `key=1`
+ checkAnswer(sql("SELECT value['abc'] FROM maptest where key = 1"), Row("somestring"))
+ checkAnswer(sql("SELECT value['cba'] FROM maptest where key = 1"), Row(null))
+ }
+ }
+
test("hash function") {
val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
withTempTable("tbl") {