aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala43
2 files changed, 46 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 2902906df2..2718d43646 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -233,10 +233,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
@transient private[this] lazy val values =
branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
+ @transient private[this] lazy val elseValue =
+ if (branches.length % 2 == 0) None else Option(branches.last)
override def nullable = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
- values.exists(_.nullable) || (values.length % 2 == 0)
+ values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
}
override lazy val resolved = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 8c3b062d0f..84d7281477 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -333,6 +333,49 @@ class ExpressionEvaluationSuite extends FunSuite {
Literal("^Ba*n", StringType) :: c2 :: Nil), true, row)
}
+ test("case when") {
+ val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c"))
+ val c1 = 'a.boolean.at(0)
+ val c2 = 'a.boolean.at(1)
+ val c3 = 'a.boolean.at(2)
+ val c4 = 'a.string.at(3)
+ val c5 = 'a.string.at(4)
+ val c6 = 'a.string.at(5)
+
+ checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row)
+ checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row)
+ checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row)
+ checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row)
+ checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row)
+ checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row)
+
+ checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row)
+ checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row)
+ checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row)
+ checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row)
+
+ assert(CaseWhen(Seq(c2, c4, c6)).nullable === true)
+ assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true)
+ assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true)
+
+ val c4_notNull = 'a.boolean.notNull.at(3)
+ val c5_notNull = 'a.boolean.notNull.at(4)
+ val c6_notNull = 'a.boolean.notNull.at(5)
+
+ assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false)
+ assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true)
+ assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true)
+
+ assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false)
+ assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true)
+ assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true)
+ assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true)
+
+ assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true)
+ assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true)
+ assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
+ }
+
test("complex type") {
val row = new GenericRow(Array[Any](
"^Ba*n", // 0