diff options
author | Takuya UESHIN <ueshin@happy-camper.st> | 2014-06-20 00:12:52 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-06-20 00:12:52 -0700 |
commit | 324952892085d1933bcf392ce8f2ced452fe741e (patch) | |
tree | bc879548bd595083ce95b9ca59c7486b6684a08b /sql | |
parent | f46e02fcdbb3f86a8761c078708388d18282ee0c (diff) | |
download | spark-324952892085d1933bcf392ce8f2ced452fe741e.tar.gz spark-324952892085d1933bcf392ce8f2ced452fe741e.tar.bz2 spark-324952892085d1933bcf392ce8f2ced452fe741e.zip |
[SPARK-2196] [SQL] Fix nullability of CaseWhen.
`CaseWhen` should use `branches.length` (rather than `values.length`) to check whether an `elseValue` is provided.
Author: Takuya UESHIN <ueshin@happy-camper.st>
Closes #1133 from ueshin/issues/SPARK-2196 and squashes the following commits:
510f12d [Takuya UESHIN] Add some tests.
dc25e8d [Takuya UESHIN] Fix nullable of CaseWhen to be nullable if the elseValue is nullable.
4f049cc [Takuya UESHIN] Fix nullability of CaseWhen.
Diffstat (limited to 'sql')
2 files changed, 46 insertions, 1 deletion
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 2902906df2..2718d43646 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -233,10 +233,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression { branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq @transient private[this] lazy val values = branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq + @transient private[this] lazy val elseValue = + if (branches.length % 2 == 0) None else Option(branches.last) override def nullable = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - values.exists(_.nullable) || (values.length % 2 == 0) + values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) } override lazy val resolved = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 8c3b062d0f..84d7281477 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -333,6 +333,49 @@ class ExpressionEvaluationSuite extends FunSuite { Literal("^Ba*n", StringType) :: c2 :: Nil), true, row) } + test("case when") { + val row = new GenericRow(Array[Any](null, false, true, "a", "b", "c")) + val c1 = 'a.boolean.at(0) + val c2 = 'a.boolean.at(1) + val c3 = 'a.boolean.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row) + 
checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(Literal(null, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal(false, BooleanType), c4, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(Literal(true, BooleanType), c4, c6)), "a", row) + + checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row) + checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row) + checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row) + + assert(CaseWhen(Seq(c2, c4, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true) + + val c4_notNull = 'a.boolean.notNull.at(3) + val c5_notNull = 'a.boolean.notNull.at(4) + val c6_notNull = 'a.boolean.notNull.at(5) + + assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true) + + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true) + assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) + } + test("complex type") { val row = new GenericRow(Array[Any]( "^Ba*n", // 0 |