diff options
author | Vinod K C <vinod.kc@huawei.com> | 2015-07-13 12:51:33 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-07-13 12:51:33 -0700 |
commit | 4c797f2b0989317a2d004e5f72a0e593919737ea (patch) | |
tree | fe3435663a398aca3000a50becdba90cd9876883 | |
parent | 714fc55f4aadd5e7b7fb1e462910bfb6a82d9154 (diff) | |
download | spark-4c797f2b0989317a2d004e5f72a0e593919737ea.tar.gz spark-4c797f2b0989317a2d004e5f72a0e593919737ea.tar.bz2 spark-4c797f2b0989317a2d004e5f72a0e593919737ea.zip |
[SPARK-8636] [SQL] Fix equalNullSafe comparison
Author: Vinod K C <vinod.kc@huawei.com>
Closes #7040 from vinodkc/fix_CaseKeyWhen_equalNullSafe and squashes the following commits:
be5e641 [Vinod K C] Renamed equalNullSafe to threeValueEquals
aac9f67 [Vinod K C] Updated test suite and genCode method
f2d0b53 [Vinod K C] Fix equalNullSafe comparison
2 files changed, 6 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala index e6a705fb80..84c28c27f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala @@ -238,7 +238,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW // If all branches fail and an elseVal is not provided, the whole statement // defaults to null, according to Hive's semantics. while (i < len - 1) { - if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) { + if (threeValueEquals(evaluatedKey, branchesArr(i).eval(input))) { return branchesArr(i + 1).eval(input) } i += 2 @@ -261,8 +261,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW s""" if (!$got) { ${cond.code} - if (${keyEval.isNull} && ${cond.isNull} || - !${keyEval.isNull} && !${cond.isNull} + if (!${keyEval.isNull} && !${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) { $got = true; ${res.code} @@ -296,10 +295,8 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW """ } - private def equalNullSafe(l: Any, r: Any) = { - if (l == null && r == null) { - true - } else if (l == null || r == null) { + private def threeValueEquals(l: Any, r: Any) = { + if (l == null || r == null) { false } else { l == r diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index aaf40cc83e..adadc8c54f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -125,7 +125,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper val literalString = Literal("a") checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row) - checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row) + checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "c", row) checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) @@ -134,7 +134,7 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) checkEvaluation(CaseKeyWhen(c6, Seq(c5, c2, c4, c3)), null, row) - checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row) + checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), null, row) } test("function least") { |