aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <simonh@tw.ibm.com>2016-07-23 10:27:16 +0800
committerWenchen Fan <wenchen@databricks.com>2016-07-23 10:27:16 +0800
commite10b8741d86a0a625d28bcb1c654736a260be85e (patch)
tree61c33b519258245ac269c10157497415eef7859b /sql
parent47f5b88db4d65f1870b16745d3c93d01051ba20b (diff)
downloadspark-e10b8741d86a0a625d28bcb1c654736a260be85e.tar.gz
spark-e10b8741d86a0a625d28bcb1c654736a260be85e.tar.bz2
spark-e10b8741d86a0a625d28bcb1c654736a260be85e.zip
[SPARK-16622][SQL] Fix NullPointerException when the returned value of the called method in Invoke is null
## What changes were proposed in this pull request? Currently we don't check the value returned by called method in `Invoke`. When the returned value is null and is assigned to a variable of primitive type, `NullPointerException` will be thrown. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh <simonh@tw.ibm.com> Closes #14259 from viirya/agg-empty-ds.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala47
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala35
2 files changed, 65 insertions, 17 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index ea4dee174e..d6863ed2fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -134,31 +134,42 @@ case class Invoke(
val argGen = arguments.map(_.genCode(ctx))
val argString = argGen.map(_.value).mkString(", ")
- val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) {
- s"${obj.value}.$functionName($argString)"
- } else {
- s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)"
- }
+ val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive
+ val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty
- val setIsNull = if (propagateNull && arguments.nonEmpty) {
- s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
+ def getFuncResult(resultVal: String, funcCall: String): String = if (needTryCatch) {
+ s"""
+ try {
+ $resultVal = $funcCall;
+ } catch (Exception e) {
+ org.apache.spark.unsafe.Platform.throwException(e);
+ }
+ """
} else {
- s"boolean ${ev.isNull} = ${obj.isNull};"
+ s"$resultVal = $funcCall;"
}
- val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
- s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;"
+ val evaluate = if (returnPrimitive) {
+ getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
} else {
+ val funcResult = ctx.freshName("funcResult")
s"""
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
- try {
- ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc;
- } catch (Exception e) {
- org.apache.spark.unsafe.Platform.throwException(e);
+ Object $funcResult = null;
+ ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
+ if ($funcResult == null) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult;
}
"""
}
+ val setIsNull = if (propagateNull && arguments.nonEmpty) {
+ s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
+ } else {
+ s"boolean ${ev.isNull} = ${obj.isNull};"
+ }
+
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
@@ -166,12 +177,14 @@ case class Invoke(
} else {
""
}
-
val code = s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
$setIsNull
- $evaluate
+ $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ $evaluate
+ }
$postNullCheck
"""
ev.copy(code = code)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
new file mode 100644
index 0000000000..ee65826cd5
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.types.{IntegerType, ObjectType}
+
+
+class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ test("SPARK-16622: The returned value of the called method in Invoke can be null") {
+ val inputRow = InternalRow.fromSeq(Seq((false, null)))
+ val cls = classOf[Tuple2[Boolean, java.lang.Integer]]
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+ val invoke = Invoke(inputObject, "_2", IntegerType)
+ checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)
+ }
+}