aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala11
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala4
2 files changed, 6 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index e6a705fb80..84c28c27f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -238,7 +238,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
// If all branches fail and an elseVal is not provided, the whole statement
// defaults to null, according to Hive's semantics.
while (i < len - 1) {
- if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) {
+ if (threeValueEquals(evaluatedKey, branchesArr(i).eval(input))) {
return branchesArr(i + 1).eval(input)
}
i += 2
@@ -261,8 +261,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
s"""
if (!$got) {
${cond.code}
- if (${keyEval.isNull} && ${cond.isNull} ||
- !${keyEval.isNull} && !${cond.isNull}
+ if (!${keyEval.isNull} && !${cond.isNull}
&& ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
$got = true;
${res.code}
@@ -296,10 +295,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
"""
}
- private def equalNullSafe(l: Any, r: Any) = {
- if (l == null && r == null) {
- true
- } else if (l == null || r == null) {
+ private def threeValueEquals(l: Any, r: Any) = {
+ if (l == null || r == null) {
false
} else {
l == r
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index aaf40cc83e..adadc8c54f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -125,7 +125,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
val literalString = Literal("a")
checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row)
- checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row)
+ checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "c", row)
checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
@@ -134,7 +134,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row)
- checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
+ checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), null, row)
}
test("function least") {