diff options
author | Liang-Chi Hsieh <simonh@tw.ibm.com> | 2016-07-23 10:27:16 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-07-23 10:27:16 +0800 |
commit | e10b8741d86a0a625d28bcb1c654736a260be85e (patch) | |
tree | 61c33b519258245ac269c10157497415eef7859b /sql/catalyst/src/main/scala | |
parent | 47f5b88db4d65f1870b16745d3c93d01051ba20b (diff) | |
download | spark-e10b8741d86a0a625d28bcb1c654736a260be85e.tar.gz spark-e10b8741d86a0a625d28bcb1c654736a260be85e.tar.bz2 spark-e10b8741d86a0a625d28bcb1c654736a260be85e.zip |
[SPARK-16622][SQL] Fix NullPointerException when the returned value of the called method in Invoke is null
## What changes were proposed in this pull request?
Currently we don't check the value returned by called method in `Invoke`. When the returned value is null and is assigned to a variable of primitive type, `NullPointerException` will be thrown.
## How was this patch tested?
Jenkins tests.
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Closes #14259 from viirya/agg-empty-ds.
Diffstat (limited to 'sql/catalyst/src/main/scala')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala | 47 |
1 files changed, 30 insertions, 17 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index ea4dee174e..d6863ed2fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -134,31 +134,42 @@ case class Invoke( val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") - val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) { - s"${obj.value}.$functionName($argString)" - } else { - s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)" - } + val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive + val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};" + def getFuncResult(resultVal: String, funcCall: String): String = if (needTryCatch) { + s""" + try { + $resultVal = $funcCall; + } catch (Exception e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ } else { - s"boolean ${ev.isNull} = ${obj.isNull};" + s"$resultVal = $funcCall;" } - val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) { - s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;" + val evaluate = if (returnPrimitive) { + getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") } else { + val funcResult = ctx.freshName("funcResult") s""" - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; - try { - ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc; - } catch (Exception e) { - org.apache.spark.unsafe.Platform.throwException(e); + Object $funcResult = null; + ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} + if ($funcResult == null) { + ${ev.isNull} = true; + } else { + ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; } """ } + val setIsNull = if (propagateNull && arguments.nonEmpty) { + s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};" + } else { + s"boolean ${ev.isNull} = ${obj.isNull};" + } + // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. val postNullCheck = if (ctx.defaultValue(dataType) == "null") { @@ -166,12 +177,14 @@ case class Invoke( } else { "" } - val code = s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} $setIsNull - $evaluate + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + $evaluate + } $postNullCheck """ ev.copy(code = code) |