author Wenchen Fan <wenchen@databricks.com> 2016-09-06 10:36:00 +0800
committer Wenchen Fan <wenchen@databricks.com> 2016-09-06 10:36:00 +0800
commit 8d08f43d09157b98e559c0be6ce6fd571a35e0d1 (patch)
tree bdd145e566a7ca014fad3376d799ec2e5e74f3ff /sql
parent 6d86403d8b252776effcddd71338b4d21a224f9b (diff)
[SPARK-17279][SQL] better error message for exceptions during ScalaUDF execution
## What changes were proposed in this pull request?

If `ScalaUDF` throws an exception while executing user code, it is sometimes hard for users to figure out what went wrong, especially when they use the Spark shell. An example:

```
org.apache.spark.SparkException: Job aborted due to stage failure: Task 12 in stage 325.0 failed 4 times, most recent failure: Lost task 12.3 in stage 325.0 (TID 35622, 10.0.207.202): java.lang.NullPointerException
	at line8414e872fb8b42aba390efc153d1611a12.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$2.apply(<console>:40)
	at line8414e872fb8b42aba390efc153d1611a12.$read$$iwC$$iwC$$iwC$$iwC$$anonfun$2.apply(<console>:40)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
	...
```

We should catch these exceptions and rethrow them with a better error message, stating that the exception happened during Scala UDF execution. This PR also does some cleanup in `ScalaUDF` and adds a unit test suite for it.

## How was this patch tested?

The new test suite.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #14850 from cloud-fan/npe.
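For illustration, the fix boils down to wrapping the user function call in a try/catch that rethrows with a descriptive message; the actual change is in the diff below. Here is a minimal standalone sketch of that pattern, where the `wrapUdf` helper and the plain `RuntimeException` are illustrative assumptions; the patch itself throws `org.apache.spark.SparkException` from inside `ScalaUDF`:

```scala
// Minimal sketch of the error-wrapping pattern this patch applies.
// `wrapUdf` is a hypothetical helper, not part of the patch.
object UdfErrorWrapping {
  def wrapUdf[A, B](name: String, f: A => B): A => B = { a =>
    try {
      f(a)
    } catch {
      case e: Exception =>
        // Surface which UDF failed instead of a bare NullPointerException.
        throw new RuntimeException(s"Failed to execute user defined function: $name", e)
    }
  }

  def main(args: Array[String]): Unit = {
    val lower = wrapUdf("toLowerCase", (s: String) => s.toLowerCase)
    println(lower("ABC"))      // prints "abc"
    try lower(null)            // triggers the wrapped NPE
    catch { case e: RuntimeException => println(s"${e.getMessage}, caused by ${e.getCause}") }
  }
}
```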
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala  | 44
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala | 48
2 files changed, 78 insertions(+), 14 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 21390644bc..6cfdea9fdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.DataType
@@ -994,20 +995,15 @@ case class ScalaUDF(
ctx: CodegenContext,
ev: ExprCode): ExprCode = {
- ctx.references += this
-
- val scalaUDFClassName = classOf[ScalaUDF].getName
+ val scalaUDF = ctx.addReferenceObj("scalaUDF", this)
val converterClassName = classOf[Any => Any].getName
val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
- val expressionClassName = classOf[Expression].getName
// Generate codes used to convert the returned value of user-defined functions to Catalyst type
val catalystConverterTerm = ctx.freshName("catalystConverter")
- val catalystConverterTermIdx = ctx.references.size - 1
ctx.addMutableState(converterClassName, catalystConverterTerm,
s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
- s".createToCatalystConverter((($scalaUDFClassName)references" +
- s"[$catalystConverterTermIdx]).dataType());")
+ s".createToCatalystConverter($scalaUDF.dataType());")
val resultTerm = ctx.freshName("result")
@@ -1019,10 +1015,8 @@ case class ScalaUDF(
val funcClassName = s"scala.Function${children.size}"
val funcTerm = ctx.freshName("udf")
- val funcExpressionIdx = ctx.references.size - 1
ctx.addMutableState(funcClassName, funcTerm,
- s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)references" +
- s"[$funcExpressionIdx]).userDefinedFunc());")
+ s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")
// codegen for children expressions
val evals = children.map(_.genCode(ctx))
@@ -1039,9 +1033,16 @@ case class ScalaUDF(
(convert, argTerm)
}.unzip
- val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " +
- s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" +
- s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));"
+ val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})"
+ val callFunc =
+ s"""
+ ${ctx.boxedType(dataType)} $resultTerm = null;
+ try {
+ $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
+ } catch (Exception e) {
+ throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
+ }
+ """
ev.copy(code = s"""
$evalCode
@@ -1057,5 +1058,20 @@ case class ScalaUDF(
private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
- override def eval(input: InternalRow): Any = converter(f(input))
+ lazy val udfErrorMessage = {
+ val funcCls = function.getClass.getSimpleName
+ val inputTypes = children.map(_.dataType.simpleString).mkString(", ")
+ s"Failed to execute user defined function($funcCls: ($inputTypes) => ${dataType.simpleString})"
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val result = try {
+ f(input)
+ } catch {
+ case e: Exception =>
+ throw new SparkException(udfErrorMessage, e)
+ }
+
+ converter(result)
+ }
}
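An aside on the codegen cleanup above: `ctx.addReferenceObj("scalaUDF", this)` registers the object in the codegen `references` array and returns the Java source expression for reading it back, which is why the manual `ctx.references += this` and index bookkeeping could be deleted. A rough standalone sketch of the bookkeeping it hides; `MiniCodegenContext` is a hypothetical stand-in, not Spark's real `CodegenContext`:

```scala
import scala.collection.mutable.ArrayBuffer

// Sketch: append the object to a references array and return the Java
// source expression that reads it back, so callers never track indices.
class MiniCodegenContext {
  val references = ArrayBuffer.empty[Any]

  def addReferenceObj(name: String, obj: Any): String = {
    references += obj
    val idx = references.size - 1
    // The generated code casts references[idx] back to the object's class.
    s"((${obj.getClass.getName}) references[$idx])"
  }
}

object MiniCodegenDemo extends App {
  val ctx = new MiniCodegenContext
  val expr = ctx.addReferenceObj("scalaUDF", "some-udf-object")
  println(expr) // e.g. ((java.lang.String) references[0])
}
```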
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
new file mode 100644
index 0000000000..7e45028653
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.types.{IntegerType, StringType}
+
+class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ test("basic") {
+ val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
+ checkEvaluation(intUdf, 2)
+
+ val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
+ checkEvaluation(stringUdf, "ax")
+ }
+
+ test("better error message for NPE") {
+ val udf = ScalaUDF(
+ (s: String) => s.toLowerCase,
+ StringType,
+ Literal.create(null, StringType) :: Nil)
+
+ val e1 = intercept[SparkException](udf.eval())
+ assert(e1.getMessage.contains("Failed to execute user defined function"))
+
+ val e2 = intercept[SparkException] {
+ checkEvalutionWithUnsafeProjection(udf, null)
+ }
+ assert(e2.getMessage.contains("Failed to execute user defined function"))
+ }
+
+}
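For an end-to-end feel of what the suite asserts, here is a sketch against the public DataFrame API; the session setup and local master are assumptions, and the message text comes from `udfErrorMessage` above:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

// Sketch: a UDF that NPEs on null input, observed through the wrapped
// SparkException message introduced by this patch.
object UdfErrorDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("udf-error").getOrCreate()
  import spark.implicits._

  val lower = udf((s: String) => s.toLowerCase)
  val df = Seq[String](null).toDF("s")

  try {
    df.select(lower($"s")).collect()
  } catch {
    case e: Exception =>
      // Expect "Failed to execute user defined function(...)" in the cause chain.
      println(e.getMessage)
  }
  spark.stop()
}
```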