diff options
Diffstat (limited to 'sql')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 31 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 8 |
2 files changed, 23 insertions, 16 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 5deb2f81d1..85faa19bbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1029,24 +1029,27 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val funcArguments = converterTerms.zipWithIndex.map { - case (converter, i) => - val eval = evals(i) - val dt = children(i).dataType - s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})" - }.mkString(",") - val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + - s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + - s".apply($funcTerm.apply($funcArguments));" + val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});" + (convert, argTerm) + }.unzip - evalCode + s""" - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - Boolean ${ev.isNull}; + val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " + + s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + + s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" + s""" + $evalCode + ${converters.mkString("\n")} $callFunc - ${ev.value} = $resultTerm; - ${ev.isNull} = $resultTerm == null; + boolean ${ev.isNull} = $resultTerm == null; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $resultTerm; + } """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8887dc68a5..5353fefaf4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1144,9 +1144,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // passing null into the UDF that could handle it val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { - (i: java.lang.Integer) => if (i == null) -10 else i * 2 + (i: java.lang.Integer) => if (i == null) -10 else null } - checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil) + checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) + + sqlContext.udf.register("boxedUDF", + (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) + checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) val primitiveUDF = udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) |