author     Wenchen Fan <wenchen@databricks.com>        2015-11-18 10:23:12 -0800
committer  Michael Armbrust <michael@databricks.com>   2015-11-18 10:23:12 -0800
commit     33b837333435ceb0c04d1f361a5383c4fe6a5a75 (patch)
tree       04a87253d495d31efbc2565a49cbb48c2ac7053d
parent     cffb899c4397ecccedbcc41e7cf3da91f953435a (diff)
download   spark-33b837333435ceb0c04d1f361a5383c4fe6a5a75.tar.gz
           spark-33b837333435ceb0c04d1f361a5383c4fe6a5a75.tar.bz2
           spark-33b837333435ceb0c04d1f361a5383c4fe6a5a75.zip
[SPARK-11725][SQL] correctly handle null inputs for UDF
If a user uses primitive parameters in a UDF, there is no way to null-check the primitive inputs, so in this case we assume the primitive inputs are null-propagating and return null if any input is null. Author: Wenchen Fan <wenchen@databricks.com> Closes #9770 from cloud-fan/udf.
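A hedged sketch of the contract this commit establishes (before this change, a null fed to a primitive parameter was silently unboxed to the type's default value, e.g. 0 for Int; the names below are illustrative):

    import org.apache.spark.sql.functions.udf

    // Primitive parameter: the function body can never observe null, so the
    // analyzer now short-circuits and returns null for a null input.
    val primitiveUDF = udf((i: Int) => i * 2)

    // Boxed parameter: null reaches the body, so the author handles it.
    val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { i =>
      if (i == null) null else i * 2
    }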
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala          9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala       32
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala     6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala    17
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala  44
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                       14
6 files changed, 121 insertions(+), 1 deletion(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 0b3dd351e3..38828e59a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -719,6 +719,15 @@ trait ScalaReflection {
}
}
+ /**
+ * Returns the classes of the input parameters of the given Scala function object.
+ */
+ def getParameterTypes(func: AnyRef): Seq[Class[_]] = {
+ val methods = func.getClass.getMethods.filter(m => m.getName == "apply" && !m.isBridge)
+ assert(methods.length == 1)
+ methods.head.getParameterTypes
+ }
+
def typeOfObject: PartialFunction[Any, DataType] = {
// The data type can be determined without ambiguity.
case obj: Boolean => BooleanType
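As a hedged illustration of the new helper (mirroring the ScalaReflectionSuite test added below): compiled Scala function objects carry a synthetic bridge apply(Object, ...) next to the user-visible one, which is why the `!m.isBridge` filter above is needed.

    import org.apache.spark.sql.catalyst.ScalaReflection

    val f = (i: Int, j: Long) => "x"
    // Returns the primitive classes, i.e. Seq(classOf[Int], classOf[Long]).
    ScalaReflection.getParameterTypes(f)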
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2f4670b55b..f00c451b59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
+import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.types._
/**
@@ -85,6 +85,8 @@ class Analyzer(
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
PullOutNondeterministic),
+ Batch("UDF", Once,
+ HandleNullInputsForUDF),
Batch("Cleanup", fixedPoint,
CleanupAliases)
)
@@ -1063,6 +1065,34 @@ class Analyzer(
Project(p.output, newPlan.withNewChildren(newChild :: Nil))
}
}
+
+ /**
+ * Correctly handle null primitive inputs for a UDF by adding an extra [[If]] expression to do
+ * the null check. When a user defines a UDF with primitive parameters, there is no way to tell
+ * whether a primitive parameter is null or not, so here we assume the primitive input is
+ * null-propagating and we should return null if the input is null.
+ */
+ object HandleNullInputsForUDF extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.resolved => p // Skip unresolved nodes.
+
+ case plan => plan transformExpressionsUp {
+
+ case udf @ ScalaUDF(func, _, inputs, _) =>
+ val parameterTypes = ScalaReflection.getParameterTypes(func)
+ assert(parameterTypes.length == inputs.length)
+
+ val inputsNullCheck = parameterTypes.zip(inputs)
+ // TODO: skip null handling for non-nullable primitive inputs once we can completely
+ // trust the `nullable` information.
+ // .filter { case (cls, expr) => cls.isPrimitive && expr.nullable }
+ .filter { case (cls, _) => cls.isPrimitive }
+ .map { case (_, expr) => IsNull(expr) }
+ .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
+ inputsNullCheck.map(If(_, Literal.create(null, udf.dataType), udf)).getOrElse(udf)
+ }
+ }
+ }
}
/**
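In effect the rule rewrites each UDF call into the shape the AnalysisSuite below verifies. A hedged sketch, with `string` and `double` standing in for resolved attribute references:

    // Only the primitive Double parameter needs a guard; the String one is
    // left alone, since its body can observe null directly.
    val udf = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
    val rewritten = If(IsNull(double), Literal.create(null, StringType), udf)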
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 3388cc20a9..03b89221ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -24,7 +24,13 @@ import org.apache.spark.sql.types.DataType
/**
* User-defined function.
+ * @param function The user-defined Scala function to run.
+ *                 Note that if you use primitive parameters, you are not able to check whether
+ *                 an input is null or not, and the UDF will return null for you if the
+ *                 primitive input is null. Use a boxed type or [[Option]] if you want to do
+ *                 the null handling yourself.
* @param dataType Return type of function.
+ * @param children The input expressions of this UDF.
+ * @param inputTypes The expected input types of this UDF.
*/
case class ScalaUDF(
function: AnyRef,
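A hedged sketch of constructing the expression with all four documented parameters (`child` is a stand-in for a resolved input expression):

    import org.apache.spark.sql.catalyst.expressions.ScalaUDF
    import org.apache.spark.sql.types.{IntegerType, StringType}

    // function, dataType, children, inputTypes
    val expr = ScalaUDF((i: Int) => i.toString, StringType, child :: Nil, IntegerType :: Nil)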
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 3b848cfdf7..4ea410d492 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -280,4 +280,21 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType))
}
}
+
+ test("get parameter type from a function object") {
+ val primitiveFunc = (i: Int, j: Long) => "x"
+ val primitiveTypes = getParameterTypes(primitiveFunc)
+ assert(primitiveTypes.forall(_.isPrimitive))
+ assert(primitiveTypes === Seq(classOf[Int], classOf[Long]))
+
+ val boxedFunc = (i: java.lang.Integer, j: java.lang.Long) => "x"
+ val boxedTypes = getParameterTypes(boxedFunc)
+ assert(boxedTypes.forall(!_.isPrimitive))
+ assert(boxedTypes === Seq(classOf[java.lang.Integer], classOf[java.lang.Long]))
+
+ val anyFunc = (i: Any, j: AnyRef) => "x"
+ val anyTypes = getParameterTypes(anyFunc)
+ assert(anyTypes.forall(!_.isPrimitive))
+ assert(anyTypes === Seq(classOf[java.lang.Object], classOf[java.lang.Object]))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 65f09b46af..08586a9741 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -174,4 +174,48 @@ class AnalysisSuite extends AnalysisTest {
)
assertAnalysisError(plan, Seq("data type mismatch: Arguments must be same type"))
}
+
+ test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
+ val string = testRelation2.output(0)
+ val double = testRelation2.output(2)
+ val short = testRelation2.output(4)
+ val nullResult = Literal.create(null, StringType)
+
+ def checkUDF(udf: Expression, transformed: Expression): Unit = {
+ checkAnalysis(
+ Project(Alias(udf, "")() :: Nil, testRelation2),
+ Project(Alias(transformed, "")() :: Nil, testRelation2)
+ )
+ }
+
+ // non-primitive parameters do not need special null handling
+ val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil)
+ val expected1 = udf1
+ checkUDF(udf1, expected1)
+
+ // only the primitive parameter needs special null handling
+ val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil)
+ val expected2 = If(IsNull(double), nullResult, udf2)
+ checkUDF(udf2, expected2)
+
+ // special null handling should apply to all primitive parameters
+ val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil)
+ val expected3 = If(
+ IsNull(short) || IsNull(double),
+ nullResult,
+ udf3)
+ checkUDF(udf3, expected3)
+
+ // we can skip special null handling for primitive parameters that are not nullable
+ // TODO: this is disabled for now as we cannot completely trust `nullable`.
+ val udf4 = ScalaUDF(
+ (s: Short, d: Double) => "x",
+ StringType,
+ short :: double.withNullability(false) :: Nil)
+ val expected4 = If(
+ IsNull(short),
+ nullResult,
+ udf4)
+ // checkUDF(udf4, expected4)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 35cdab50bd..5a7f24684d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1115,4 +1115,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.select(df("*")), Row(1, "a"))
checkAnswer(df.withColumnRenamed("d^'a.", "a"), Row(1, "a"))
}
+
+ test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
+ val df = Seq(
+ new java.lang.Integer(22) -> "John",
+ null.asInstanceOf[java.lang.Integer] -> "Lucy").toDF("age", "name")
+
+ val boxedUDF = udf[java.lang.Integer, java.lang.Integer] {
+ (i: java.lang.Integer) => if (i == null) null else i * 2
+ }
+ checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(null) :: Nil)
+
+ val primitiveUDF = udf((i: Int) => i * 2)
+ checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)
+ }
}
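For context, a hedged before/after sketch of the observable change, using the same `df` and `primitiveUDF` as in the test above (output shapes are illustrative):

    df.select(primitiveUDF($"age")).collect()
    // Before this patch: Array(Row(44), Row(0))    -- null unboxed to the Int default
    // After this patch:  Array(Row(44), Row(null)) -- the analyzer-inserted null guard fires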