aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <simonh@tw.ibm.com>2016-07-23 10:27:16 +0800
committerWenchen Fan <wenchen@databricks.com>2016-07-23 10:27:16 +0800
commite10b8741d86a0a625d28bcb1c654736a260be85e (patch)
tree61c33b519258245ac269c10157497415eef7859b /sql/catalyst/src/main/scala
parent47f5b88db4d65f1870b16745d3c93d01051ba20b (diff)
downloadspark-e10b8741d86a0a625d28bcb1c654736a260be85e.tar.gz
spark-e10b8741d86a0a625d28bcb1c654736a260be85e.tar.bz2
spark-e10b8741d86a0a625d28bcb1c654736a260be85e.zip
[SPARK-16622][SQL] Fix NullPointerException when the returned value of the called method in Invoke is null
## What changes were proposed in this pull request? Currently we don't check the value returned by called method in `Invoke`. When the returned value is null and is assigned to a variable of primitive type, `NullPointerException` will be thrown. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh <simonh@tw.ibm.com> Closes #14259 from viirya/agg-empty-ds.
Diffstat (limited to 'sql/catalyst/src/main/scala')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala47
1 files changed, 30 insertions, 17 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index ea4dee174e..d6863ed2fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -134,31 +134,42 @@ case class Invoke(
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")
- val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) {
- s"${obj.value}.$functionName($argString)"
- } else {
- s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)"
- }
+ val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive
+ val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty
- val setIsNull = if (propagateNull && arguments.nonEmpty) {
- s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
+ def getFuncResult(resultVal: String, funcCall: String): String = if (needTryCatch) {
+ s"""
+ try {
+ $resultVal = $funcCall;
+ } catch (Exception e) {
+ org.apache.spark.unsafe.Platform.throwException(e);
+ }
+ """
} else {
- s"boolean ${ev.isNull} = ${obj.isNull};"
+ s"$resultVal = $funcCall;"
}
- val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
- s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;"
+ val evaluate = if (returnPrimitive) {
+ getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
} else {
+ val funcResult = ctx.freshName("funcResult")
s"""
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
- try {
- ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc;
- } catch (Exception e) {
- org.apache.spark.unsafe.Platform.throwException(e);
+ Object $funcResult = null;
+ ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
+ if ($funcResult == null) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
}
"""
}
+ val setIsNull = if (propagateNull && arguments.nonEmpty) {
+ s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
+ } else {
+ s"boolean ${ev.isNull} = ${obj.isNull};"
+ }
+
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
@@ -166,12 +177,14 @@ case class Invoke(
} else {
""
}
-
val code = s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
$setIsNull
- $evaluate
+ $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ $evaluate
+ }
$postNullCheck
"""
ev.copy(code = code)