aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-01-14 10:09:03 -0800
committerReynold Xin <rxin@databricks.com>2016-01-14 10:09:03 -0800
commit902667fd2766f0472a15851b1ed8fb5859593f97 (patch)
tree8807f6ff521c385c1b3d105b8af7080ad86216b8
parent501e99ef0fbd2f2165095548fe67a3447ccbfc91 (diff)
downloadspark-902667fd2766f0472a15851b1ed8fb5859593f97.tar.gz
spark-902667fd2766f0472a15851b1ed8fb5859593f97.tar.bz2
spark-902667fd2766f0472a15851b1ed8fb5859593f97.zip
[SPARK-12771][SQL] Simplify CaseWhen code generation
The generated code for CaseWhen uses a control variable "got" to make sure we do not evaluate more branches once a branch is true. Changing that to generate just simple "if / else" would be slightly more efficient. This closes #10737. Author: Reynold Xin <rxin@databricks.com> Closes #10755 from rxin/SPARK-12771.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala60
1 files changed, 35 insertions, 25 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 8cc7bc1da2..83abbcdc61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -137,45 +137,55 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val got = ctx.freshName("got")
-
- val cases = branches.map { case (condition, value) =>
- val cond = condition.gen(ctx)
- val res = value.gen(ctx)
+ // Generate code that looks like:
+ //
+ // condA = ...
+ // if (condA) {
+ // valueA
+ // } else {
+ // condB = ...
+ // if (condB) {
+ // valueB
+ // } else {
+ // condC = ...
+ // if (condC) {
+ // valueC
+ // } else {
+ // elseValue
+ // }
+ // }
+ // }
+ val cases = branches.map { case (condExpr, valueExpr) =>
+ val cond = condExpr.gen(ctx)
+ val res = valueExpr.gen(ctx)
s"""
- if (!$got) {
- ${cond.code}
- if (!${cond.isNull} && ${cond.value}) {
- $got = true;
- ${res.code}
- ${ev.isNull} = ${res.isNull};
- ${ev.value} = ${res.value};
- }
+ ${cond.code}
+ if (!${cond.isNull} && ${cond.value}) {
+ ${res.code}
+ ${ev.isNull} = ${res.isNull};
+ ${ev.value} = ${res.value};
}
"""
- }.mkString("\n")
+ }
- val elseCase = {
- if (elseValue.isDefined) {
- val res = elseValue.get.gen(ctx)
+ var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")
+
+ elseValue.foreach { elseExpr =>
+ val res = elseExpr.gen(ctx)
+ generatedCode +=
s"""
- if (!$got) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
- }
"""
- } else {
- ""
- }
}
+ generatedCode += "}\n" * cases.size
+
s"""
- boolean $got = false;
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- $cases
- $elseCase
+ $generatedCode
"""
}