diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-07-14 10:20:15 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-07-14 10:20:15 -0700 |
commit | 59d820aa8dec08b744971237860b4c6bef577ddf (patch) | |
tree | d34e1f77c56827af239a14da4e480deb631f381b /sql | |
parent | 257236c3e17906098f801cbc2059e7a9054e8cab (diff) | |
download | spark-59d820aa8dec08b744971237860b4c6bef577ddf.tar.gz spark-59d820aa8dec08b744971237860b4c6bef577ddf.tar.bz2 spark-59d820aa8dec08b744971237860b4c6bef577ddf.zip |
[SPARK-9029] [SQL] shortcut CaseKeyWhen if key is null
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #7389 from cloud-fan/case-when and squashes the following commits:
ea4b6ba [Wenchen Fan] shortcut for case key when
Diffstat (limited to 'sql')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala | 48 |
1 file changed, 24 insertions, 24 deletions
```diff
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index eea7706b9d..c7f039ede2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -230,24 +230,31 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
     }
   }
 
+  private def evalElse(input: InternalRow): Any = {
+    if (branchesArr.length % 2 == 0) {
+      null
+    } else {
+      branchesArr(branchesArr.length - 1).eval(input)
+    }
+  }
+
   /** Written in imperative fashion for performance considerations. */
   override def eval(input: InternalRow): Any = {
     val evaluatedKey = key.eval(input)
-    val len = branchesArr.length
-    var i = 0
-    // If all branches fail and an elseVal is not provided, the whole statement
-    // defaults to null, according to Hive's semantics.
-    while (i < len - 1) {
-      if (threeValueEquals(evaluatedKey, branchesArr(i).eval(input))) {
-        return branchesArr(i + 1).eval(input)
+    // If key is null, we can just return the else part or null if there is no else.
+    // If key is not null but doesn't match any when part, we need to return
+    // the else part or null if there is no else, according to Hive's semantics.
+    if (evaluatedKey != null) {
+      val len = branchesArr.length
+      var i = 0
+      while (i < len - 1) {
+        if (evaluatedKey == branchesArr(i).eval(input)) {
+          return branchesArr(i + 1).eval(input)
+        }
+        i += 2
       }
-      i += 2
     }
-    var res: Any = null
-    if (i == len - 1) {
-      res = branchesArr(i).eval(input)
-    }
-    return res
+    evalElse(input)
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -261,8 +268,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
         s"""
           if (!$got) {
             ${cond.code}
-            if (!${keyEval.isNull} && !${cond.isNull}
-             && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
+            if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.primitive, cond.primitive)}) {
               $got = true;
               ${res.code}
               ${ev.isNull} = ${res.isNull};
@@ -290,19 +296,13 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
       boolean ${ev.isNull} = true;
       ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
       ${keyEval.code}
-      $cases
+      if (!${keyEval.isNull}) {
+        $cases
+      }
       $other
     """
   }
 
-  private def threeValueEquals(l: Any, r: Any) = {
-    if (l == null || r == null) {
-      false
-    } else {
-      l == r
-    }
-  }
-
   override def toString: String = {
     s"CASE $key" + branches.sliding(2, 2).map {
       case Seq(cond, value) => s" WHEN $cond THEN $value"
```