author     Reynold Xin <rxin@databricks.com>  2015-07-18 11:06:46 -0700
committer  Reynold Xin <rxin@databricks.com>  2015-07-18 11:06:46 -0700
commit     fba3f5ba85673336c0556ef8731dcbcd175c7418 (patch)
tree       af971e0781afdc90d9e1d9696c4e543c2938249b
parent     b9ef7ac98c3dee3256c4a393e563b42b4612a4bf (diff)
[SPARK-9169][SQL] Improve unit test coverage for null expressions.
Author: Reynold Xin <rxin@databricks.com>

Closes #7490 from rxin/unit-test-null-funcs and squashes the following commits:

7b276f0 [Reynold Xin] Move isNaN.
8307287 [Reynold Xin] [SPARK-9169][SQL] Improve unit test coverage for null expressions.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala      |  81
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala         |  51
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala |  78
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala     |  12
4 files changed, 119 insertions(+), 103 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
index 1522bcae08..98c6708464 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala
@@ -21,8 +21,19 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
+
+/**
+ * An expression that is evaluated to the first non-null input.
+ *
+ * {{{
+ * coalesce(1, 2) => 1
+ * coalesce(null, 1, 2) => 1
+ * coalesce(null, null, 2) => 2
+ * coalesce(null, null, null) => null
+ * }}}
+ */
case class Coalesce(children: Seq[Expression]) extends Expression {
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
@@ -70,6 +81,62 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
}
+
+/**
+ * Evaluates to `true` if it's NaN or null
+ */
+case class IsNaN(child: Expression) extends UnaryExpression
+ with Predicate with ImplicitCastInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType))
+
+ override def nullable: Boolean = false
+
+ override def eval(input: InternalRow): Any = {
+ val value = child.eval(input)
+ if (value == null) {
+ true
+ } else {
+ child.dataType match {
+ case DoubleType => value.asInstanceOf[Double].isNaN
+ case FloatType => value.asInstanceOf[Float].isNaN
+ }
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val eval = child.gen(ctx)
+ child.dataType match {
+ case FloatType =>
+ s"""
+ ${eval.code}
+ boolean ${ev.isNull} = false;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (${eval.isNull}) {
+ ${ev.primitive} = true;
+ } else {
+ ${ev.primitive} = Float.isNaN(${eval.primitive});
+ }
+ """
+ case DoubleType =>
+ s"""
+ ${eval.code}
+ boolean ${ev.isNull} = false;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (${eval.isNull}) {
+ ${ev.primitive} = true;
+ } else {
+ ${ev.primitive} = Double.isNaN(${eval.primitive});
+ }
+ """
+ }
+ }
+}
+
+
+/**
+ * An expression that is evaluated to true if the input is null.
+ */
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
@@ -83,13 +150,14 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
ev.primitive = eval.isNull
eval.code
}
-
- override def toString: String = s"IS NULL $child"
}
+
+/**
+ * An expression that is evaluated to true if the input is not null.
+ */
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
- override def toString: String = s"IS NOT NULL $child"
override def eval(input: InternalRow): Any = {
child.eval(input) != null
@@ -103,12 +171,13 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
}
}
+
/**
- * A predicate that is evaluated to be true if there are at least `n` non-null values.
+ * A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values.
*/
case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
override def nullable: Boolean = false
- override def foldable: Boolean = false
+ override def foldable: Boolean = children.forall(_.foldable)
override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})"
private[this] val childrenArray = children.toArray
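For readers skimming the hunks above, the contract they document can be summarized in plain Scala. The sketch below is not the Spark implementation: it mirrors the eval logic shown above (Coalesce returns its first non-null input; the relocated IsNaN treats a null input as true and otherwise uses the type-specific isNaN check), with ordinary Scala values standing in for Catalyst rows. The object and method names are illustrative only.

// Standalone sketch (not Spark code) of the evaluation contract documented above.
// A Scala null stands in for a SQL NULL.
object NullSemanticsSketch {

  // Coalesce: the first non-null input, or null if every input is null.
  def coalesce(inputs: Seq[Any]): Any =
    inputs.find(_ != null).orNull

  // IsNaN as defined in this patch: null evaluates to true; otherwise the
  // value is checked with the type-specific isNaN (Double or Float only,
  // matching the expression's inputTypes).
  def isNaN(value: Any): Boolean = value match {
    case null      => true
    case d: Double => d.isNaN
    case f: Float  => f.isNaN
  }

  def main(args: Array[String]): Unit = {
    assert(coalesce(Seq(null, 1, 2)) == 1)
    assert(coalesce(Seq(null, null)) == null)
    assert(isNaN(null))
    assert(isNaN(Double.NaN))
    assert(!isNaN(Double.PositiveInfinity))
  }
}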
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2751c8e75f..bddd2a9ecc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -120,56 +119,6 @@ case class InSet(child: Expression, hset: Set[Any])
}
}
-/**
- * Evaluates to `true` if it's NaN or null
- */
-case class IsNaN(child: Expression) extends UnaryExpression
- with Predicate with ImplicitCastInputTypes {
-
- override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, FloatType))
-
- override def nullable: Boolean = false
-
- override def eval(input: InternalRow): Any = {
- val value = child.eval(input)
- if (value == null) {
- true
- } else {
- child.dataType match {
- case DoubleType => value.asInstanceOf[Double].isNaN
- case FloatType => value.asInstanceOf[Float].isNaN
- }
- }
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val eval = child.gen(ctx)
- child.dataType match {
- case FloatType =>
- s"""
- ${eval.code}
- boolean ${ev.isNull} = false;
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (${eval.isNull}) {
- ${ev.primitive} = true;
- } else {
- ${ev.primitive} = Float.isNaN(${eval.primitive});
- }
- """
- case DoubleType =>
- s"""
- ${eval.code}
- boolean ${ev.isNull} = false;
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (${eval.isNull}) {
- ${ev.primitive} = true;
- } else {
- ${ev.primitive} = Double.isNaN(${eval.primitive});
- }
- """
- }
- }
-}
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
index ccdada8b56..765cc7a969 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala
@@ -18,48 +18,52 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{BooleanType, StringType, ShortType}
+import org.apache.spark.sql.types._
class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
- test("null checking") {
- val row = create_row("^Ba*n", null, true, null)
- val c1 = 'a.string.at(0)
- val c2 = 'a.string.at(1)
- val c3 = 'a.boolean.at(2)
- val c4 = 'a.boolean.at(3)
-
- checkEvaluation(c1.isNull, false, row)
- checkEvaluation(c1.isNotNull, true, row)
-
- checkEvaluation(c2.isNull, true, row)
- checkEvaluation(c2.isNotNull, false, row)
-
- checkEvaluation(Literal.create(1, ShortType).isNull, false)
- checkEvaluation(Literal.create(1, ShortType).isNotNull, true)
-
- checkEvaluation(Literal.create(null, ShortType).isNull, true)
- checkEvaluation(Literal.create(null, ShortType).isNotNull, false)
+ def testAllTypes(testFunc: (Any, DataType) => Unit): Unit = {
+ testFunc(false, BooleanType)
+ testFunc(1.toByte, ByteType)
+ testFunc(1.toShort, ShortType)
+ testFunc(1, IntegerType)
+ testFunc(1L, LongType)
+ testFunc(1.0F, FloatType)
+ testFunc(1.0, DoubleType)
+ testFunc(Decimal(1.5), DecimalType.Unlimited)
+ testFunc(new java.sql.Date(10), DateType)
+ testFunc(new java.sql.Timestamp(10), TimestampType)
+ testFunc("abcd", StringType)
+ }
- checkEvaluation(Coalesce(c1 :: c2 :: Nil), "^Ba*n", row)
- checkEvaluation(Coalesce(Literal.create(null, StringType) :: Nil), null, row)
- checkEvaluation(Coalesce(Literal.create(null, StringType) :: c1 :: c2 :: Nil), "^Ba*n", row)
+ test("isnull and isnotnull") {
+ testAllTypes { (value: Any, tpe: DataType) =>
+ checkEvaluation(IsNull(Literal.create(value, tpe)), false)
+ checkEvaluation(IsNotNull(Literal.create(value, tpe)), true)
+ checkEvaluation(IsNull(Literal.create(null, tpe)), true)
+ checkEvaluation(IsNotNull(Literal.create(null, tpe)), false)
+ }
+ }
- checkEvaluation(
- If(c3, Literal.create("a", StringType), Literal.create("b", StringType)), "a", row)
- checkEvaluation(If(c3, c1, c2), "^Ba*n", row)
- checkEvaluation(If(c4, c2, c1), "^Ba*n", row)
- checkEvaluation(If(Literal.create(null, BooleanType), c2, c1), "^Ba*n", row)
- checkEvaluation(If(Literal.create(true, BooleanType), c1, c2), "^Ba*n", row)
- checkEvaluation(If(Literal.create(false, BooleanType), c2, c1), "^Ba*n", row)
- checkEvaluation(If(Literal.create(false, BooleanType),
- Literal.create("a", StringType), Literal.create("b", StringType)), "b", row)
+ test("IsNaN") {
+ checkEvaluation(IsNaN(Literal(Double.NaN)), true)
+ checkEvaluation(IsNaN(Literal(Float.NaN)), true)
+ checkEvaluation(IsNaN(Literal(math.log(-3))), true)
+ checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
+ checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
+ checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
+ checkEvaluation(IsNaN(Literal(5.5f)), false)
+ }
- checkEvaluation(c1 in (c1, c2), true, row)
- checkEvaluation(
- Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType)), true, row)
- checkEvaluation(
- Literal.create("^Ba*n", StringType) in (Literal.create("^Ba*n", StringType), c2), true, row)
+ test("coalesce") {
+ testAllTypes { (value: Any, tpe: DataType) =>
+ val lit = Literal.create(value, tpe)
+ val nullLit = Literal.create(null, tpe)
+ checkEvaluation(Coalesce(Seq(nullLit)), null)
+ checkEvaluation(Coalesce(Seq(lit)), value)
+ checkEvaluation(Coalesce(Seq(nullLit, lit)), value)
+ checkEvaluation(Coalesce(Seq(nullLit, lit, lit)), value)
+ checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value)
+ }
}
}
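The testAllTypes helper introduced above runs a single assertion block against a representative literal of each Catalyst type. Distilled outside Spark, the pattern is just a higher-order function over (sample value, type label) pairs; the sketch below is illustrative only and uses plain Scala tuples rather than Catalyst's Literal and DataType.

// Illustrative only: the testAllTypes pattern without Spark dependencies.
// Each check receives a sample value and a label for the type it represents.
object TestAllTypesSketch {
  val samples: Seq[(Any, String)] = Seq(
    (false, "Boolean"), (1.toByte, "Byte"), (1.toShort, "Short"),
    (1, "Int"), (1L, "Long"), (1.0f, "Float"), (1.0, "Double"), ("abcd", "String"))

  def testAllTypes(check: (Any, String) => Unit): Unit =
    samples.foreach { case (value, label) => check(value, label) }

  def main(args: Array[String]): Unit = {
    // Example: every sample is non-null, so a null-check contract holds for each type.
    testAllTypes { (value, label) =>
      assert(value != null, s"non-null sample expected for $label")
    }
  }
}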
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 052abc51af..2173a0c25c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -114,16 +114,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
true)
- }
- test("IsNaN") {
- checkEvaluation(IsNaN(Literal(Double.NaN)), true)
- checkEvaluation(IsNaN(Literal(Float.NaN)), true)
- checkEvaluation(IsNaN(Literal(math.log(-3))), true)
- checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true)
- checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false)
- checkEvaluation(IsNaN(Literal(Float.MaxValue)), false)
- checkEvaluation(IsNaN(Literal(5.5f)), false)
+ checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
+ checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
+ checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
}
test("INSET") {