aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala31
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala8
2 files changed, 23 insertions, 16 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 5deb2f81d1..85faa19bbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -1029,24 +1029,27 @@ case class ScalaUDF(
// such as IntegerType, its javaType is `int` and the returned type of user-defined
// function is Object. Trying to convert an Object to `int` will cause casting exception.
val evalCode = evals.map(_.code).mkString
- val funcArguments = converterTerms.zipWithIndex.map {
- case (converter, i) =>
- val eval = evals(i)
- val dt = children(i).dataType
- s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})"
- }.mkString(",")
- val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " +
- s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" +
- s".apply($funcTerm.apply($funcArguments));"
+ val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) =>
+ val eval = evals(i)
+ val argTerm = ctx.freshName("arg")
+ val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});"
+ (convert, argTerm)
+ }.unzip
- evalCode + s"""
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- Boolean ${ev.isNull};
+ val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " +
+ s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" +
+ s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));"
+ s"""
+ $evalCode
+ ${converters.mkString("\n")}
$callFunc
- ${ev.value} = $resultTerm;
- ${ev.isNull} = $resultTerm == null;
+ boolean ${ev.isNull} = $resultTerm == null;
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${ev.value} = $resultTerm;
+ }
"""
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 8887dc68a5..5353fefaf4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1144,9 +1144,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// passing null into the UDF that could handle it
val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
- (i: java.lang.Integer) => if (i == null) -10 else i * 2
+ (i: java.lang.Integer) => if (i == null) -10 else null
}
- checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil)
+ checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil)
+
+ sqlContext.udf.register("boxedUDF",
+ (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer)
+ checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil)
val primitiveUDF = udf((i: Int) => i * 2)
checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)