From e10b8741d86a0a625d28bcb1c654736a260be85e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 23 Jul 2016 10:27:16 +0800 Subject: [SPARK-16622][SQL] Fix NullPointerException when the returned value of the called method in Invoke is null ## What changes were proposed in this pull request? Currently we don't check the value returned by called method in `Invoke`. When the returned value is null and is assigned to a variable of primitive type, `NullPointerException` will be thrown. ## How was this patch tested? Jenkins tests. Author: Liang-Chi Hsieh Closes #14259 from viirya/agg-empty-ds. --- .../sql/catalyst/expressions/objects/objects.scala | 47 ++++++++++++++-------- .../expressions/ObjectExpressionsSuite.scala | 35 ++++++++++++++++ 2 files changed, 65 insertions(+), 17 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index ea4dee174e..d6863ed2fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -134,31 +134,42 @@ case class Invoke( val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") - val callFunc = if (method.isDefined && method.get.getReturnType.isPrimitive) { - s"${obj.value}.$functionName($argString)" - } else { - s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)" - } + val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive + val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};" + def getFuncResult(resultVal: String, funcCall: String): String = if (needTryCatch) { + s""" + try { + $resultVal = $funcCall; + } catch (Exception e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ } else { - s"boolean ${ev.isNull} = ${obj.isNull};" + s"$resultVal = $funcCall;" } - val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) { - s"final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;" + val evaluate = if (returnPrimitive) { + getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") } else { + val funcResult = ctx.freshName("funcResult") s""" - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; - try { - ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $callFunc; - } catch (Exception e) { - org.apache.spark.unsafe.Platform.throwException(e); + Object $funcResult = null; + ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} + if ($funcResult == null) { + ${ev.isNull} = true; + } else { + ${ev.value} = (${ctx.boxedType(javaType)}) $funcResult; } """ } + val setIsNull = if (propagateNull && arguments.nonEmpty) { + s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};" + } else { + s"boolean ${ev.isNull} = ${obj.isNull};" + } + // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. val postNullCheck = if (ctx.defaultValue(dataType) == "null") { @@ -166,12 +177,14 @@ case class Invoke( } else { "" } - val code = s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} $setIsNull - $evaluate + $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + $evaluate + } $postNullCheck """ ev.copy(code = code) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala new file mode 100644 index 0000000000..ee65826cd5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.types.{IntegerType, ObjectType} + + +class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("SPARK-16622: The returned value of the called method in Invoke can be null") { + val inputRow = InternalRow.fromSeq(Seq((false, null))) + val cls = classOf[Tuple2[Boolean, java.lang.Integer]] + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val invoke = Invoke(inputObject, "_2", IntegerType) + checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow) + } +} -- cgit v1.2.3