aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-07-14 10:20:15 -0700
committerMichael Armbrust <michael@databricks.com>2015-07-14 10:20:15 -0700
commit59d820aa8dec08b744971237860b4c6bef577ddf (patch)
treed34e1f77c56827af239a14da4e480deb631f381b /sql
parent257236c3e17906098f801cbc2059e7a9054e8cab (diff)
downloadspark-59d820aa8dec08b744971237860b4c6bef577ddf.tar.gz
spark-59d820aa8dec08b744971237860b4c6bef577ddf.tar.bz2
spark-59d820aa8dec08b744971237860b4c6bef577ddf.zip
[SPARK-9029] [SQL] shortcut CaseKeyWhen if key is null
Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7389 from cloud-fan/case-when and squashes the following commits: ea4b6ba [Wenchen Fan] shortcut for case key when
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala48
1 files changed, 24 insertions, 24 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index eea7706b9d..c7f039ede2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -230,24 +230,31 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
}
}
+ private def evalElse(input: InternalRow): Any = {
+ if (branchesArr.length % 2 == 0) {
+ null
+ } else {
+ branchesArr(branchesArr.length - 1).eval(input)
+ }
+ }
+
/** Written in imperative fashion for performance considerations. */
override def eval(input: InternalRow): Any = {
val evaluatedKey = key.eval(input)
- val len = branchesArr.length
- var i = 0
- // If all branches fail and an elseVal is not provided, the whole statement
- // defaults to null, according to Hive's semantics.
- while (i < len - 1) {
- if (threeValueEquals(evaluatedKey, branchesArr(i).eval(input))) {
- return branchesArr(i + 1).eval(input)
+ // If key is null, we can just return the else part or null if there is no else.
+ // If key is not null but doesn't match any when part, we need to return
+ // the else part or null if there is no else, according to Hive's semantics.
+ if (evaluatedKey != null) {
+ val len = branchesArr.length
+ var i = 0
+ while (i < len - 1) {
+ if (evaluatedKey == branchesArr(i).eval(input)) {
+ return branchesArr(i + 1).eval(input)
+ }
+ i += 2
}
- i += 2
}
- var res: Any = null
- if (i == len - 1) {
- res = branchesArr(i).eval(input)
- }
- return res
+ evalElse(input)
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -261,8 +268,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
s"""
if (!$got) {
${cond.code}
- if (!${keyEval.isNull} && !${cond.isNull}
- && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
+ if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
$got = true;
${res.code}
${ev.isNull} = ${res.isNull};
@@ -290,19 +296,13 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
${keyEval.code}
- $cases
+ if (!${keyEval.isNull}) {
+ $cases
+ }
$other
"""
}
- private def threeValueEquals(l: Any, r: Any) = {
- if (l == null || r == null) {
- false
- } else {
- l == r
- }
- }
-
override def toString: String = {
s"CASE $key" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"