diff options
author | Sean Zhong <seanzhong@databricks.com> | 2016-08-04 13:43:25 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-08-04 13:43:25 +0800 |
commit | 27e815c31de26636df089b0b8d9bd678b92d3588 (patch) | |
tree | 01e2e747fec4ea4f21bb8141b16305be2fc93f03 /sql/catalyst/src | |
parent | 780c7224a5b8dd3bf7838c6f280c61daeef1dcbc (diff) | |
download | spark-27e815c31de26636df089b0b8d9bd678b92d3588.tar.gz spark-27e815c31de26636df089b0b8d9bd678b92d3588.tar.bz2 spark-27e815c31de26636df089b0b8d9bd678b92d3588.zip |
[SPARK-16888][SQL] Implements eval method for expression AssertNotNull
## What changes were proposed in this pull request?
Implements `eval()` method for expression `AssertNotNull` so that we can convert local projection on LocalRelation to another LocalRelation.
### Before change:
```
scala> import org.apache.spark.sql.catalyst.dsl.expressions._
scala> import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
scala> import org.apache.spark.sql.Column
scala> case class A(a: Int)
scala> Seq((A(1),2)).toDS().select(new Column(AssertNotNull("_1".attr, Nil))).explain
java.lang.UnsupportedOperationException: Only code-generated evaluation is supported.
at org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull.eval(objects.scala:850)
...
```
### After the change:
```
scala> Seq((A(1),2)).toDS().select(new Column(AssertNotNull("_1".attr, Nil))).explain(true)
== Parsed Logical Plan ==
'Project [assertnotnull('_1) AS assertnotnull(_1)#5]
+- LocalRelation [_1#2, _2#3]
== Analyzed Logical Plan ==
assertnotnull(_1): struct<a:int>
Project [assertnotnull(_1#2) AS assertnotnull(_1)#5]
+- LocalRelation [_1#2, _2#3]
== Optimized Logical Plan ==
LocalRelation [assertnotnull(_1)#5]
== Physical Plan ==
LocalTableScan [assertnotnull(_1)#5]
```
## How was this patch tested?
Unit test.
Author: Sean Zhong <seanzhong@databricks.com>
Closes #14486 from clockfly/assertnotnull_eval.
Diffstat (limited to 'sql/catalyst/src')
2 files changed, 21 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 952a5f3b04..7cb94a7942 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -859,17 +859,23 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override def foldable: Boolean = false override def nullable: Boolean = false - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + private val errMsg = "Null value appeared in non-nullable field:" + + walkedTypePath.mkString("\n", "\n", "\n") + + "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + + "please try to use scala.Option[_] or other nullable types " + + "(e.g. java.lang.Integer instead of int/scala.Int)." + + override def eval(input: InternalRow): Any = { + val result = child.eval(input) + if (result == null) { + throw new RuntimeException(errMsg); + } + result + } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - val errMsg = "Null value appeared in non-nullable field:" + - walkedTypePath.mkString("\n", "\n", "\n") + - "If the schema is inferred from a Scala tuple/case class, or a Java bean, " + - "please try to use scala.Option[_] or other nullable types " + - "(e.g. java.lang.Integer instead of int/scala.Int)." val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val code = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index 712fe35f47..e736379930 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.types._ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -45,6 +46,13 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("AssertNotNUll") { + val ex = intercept[RuntimeException] { + evaluate(AssertNotNull(Literal(null), Seq.empty[String])) + }.getMessage + assert(ex.contains("Null value appeared in non-nullable field")) + } + test("IsNaN") { checkEvaluation(IsNaN(Literal(Double.NaN)), true) checkEvaluation(IsNaN(Literal(Float.NaN)), true) |