aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala20
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala8
2 files changed, 21 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 952a5f3b04..7cb94a7942 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -859,17 +859,23 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
override def foldable: Boolean = false
override def nullable: Boolean = false
- override def eval(input: InternalRow): Any =
- throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+ private val errMsg = "Null value appeared in non-nullable field:" +
+ walkedTypePath.mkString("\n", "\n", "\n") +
+ "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+ "please try to use scala.Option[_] or other nullable types " +
+ "(e.g. java.lang.Integer instead of int/scala.Int)."
+
+ override def eval(input: InternalRow): Any = {
+ val result = child.eval(input)
+ if (result == null) {
+ throw new RuntimeException(errMsg);
+ }
+ result
+ }
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childGen = child.genCode(ctx)
- val errMsg = "Null value appeared in non-nullable field:" +
- walkedTypePath.mkString("\n", "\n", "\n") +
- "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
- "please try to use scala.Option[_] or other nullable types " +
- "(e.g. java.lang.Integer instead of int/scala.Int)."
val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val code = s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
index 712fe35f47..e736379930 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
import org.apache.spark.sql.types._
class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -45,6 +46,13 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
+ test("AssertNotNUll") {
+ val ex = intercept[RuntimeException] {
+ evaluate(AssertNotNull(Literal(null), Seq.empty[String]))
+ }.getMessage
+ assert(ex.contains("Null value appeared in non-nullable field"))
+ }
+
test("IsNaN") {
checkEvaluation(IsNaN(Literal(Double.NaN)), true)
checkEvaluation(IsNaN(Literal(Float.NaN)), true)