aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Reynold Xin <rxin@databricks.com> 2016-01-12 11:13:08 -0800
committer Reynold Xin <rxin@databricks.com> 2016-01-12 11:13:08 -0800
commit 0ed430e315b9a409490a3604a619321b476cb520 (patch)
tree c76c7c28acb95160f04d7a1e2ab05cfd36f4004a
parent 508592b1bae3b2c88350ddfc1d909892f236ce5f (diff)
download spark-0ed430e315b9a409490a3604a619321b476cb520.tar.gz
spark-0ed430e315b9a409490a3604a619321b476cb520.tar.bz2
spark-0ed430e315b9a409490a3604a619321b476cb520.zip
[SPARK-12768][SQL] Remove CaseKeyWhen expression
This patch removes CaseKeyWhen expression and replaces it with a factory method that generates the equivalent CaseWhen. This reduces the amount of code we'd need to maintain in the future for both code generation and optimizer.

Note that we introduced CaseKeyWhen to avoid duplicate evaluations of the key. This is no longer a problem because we now have common subexpression elimination.

Author: Reynold Xin <rxin@databricks.com>

Closes #10722 from rxin/SPARK-12768.
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 20
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala | 187
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 2
3 files changed, 38 insertions, 171 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index e326ea7827..75c36d9310 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -638,8 +638,7 @@ object HiveTypeCoercion {
*/
object CaseWhenCoercion extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
- case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual =>
- logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}")
+ case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
val maybeCommonType = findWiderCommonType(c.valueTypes)
maybeCommonType.map { commonType =>
val castedBranches = c.branches.grouped(2).map {
@@ -649,22 +648,7 @@ object HiveTypeCoercion {
Seq(Cast(elseVal, commonType))
case other => other
}.reduce(_ ++ _)
- c match {
- case _: CaseWhen => CaseWhen(castedBranches)
- case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches)
- }
- }.getOrElse(c)
-
- case c: CaseKeyWhen if c.childrenResolved && !c.resolved =>
- val maybeCommonType =
- findWiderCommonType((c.key +: c.whenList).map(_.dataType))
- maybeCommonType.map { commonType =>
- val castedBranches = c.branches.grouped(2).map {
- case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType =>
- Seq(Cast(whenExpr, commonType), thenExpr)
- case other => other
- }.reduce(_ ++ _)
- CaseKeyWhen(Cast(c.key, commonType), castedBranches)
+ CaseWhen(castedBranches)
}.getOrElse(c)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 379e62a26e..5a1462433d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils}
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -78,17 +78,23 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))"
}
-trait CaseWhenLike extends Expression {
+/**
+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
+ * When a = true, returns b; when c = true, returns d; else returns e.
+ */
+case class CaseWhen(branches: Seq[Expression]) extends Expression {
+
+ // Use private[this] Array to speed up evaluation.
+ @transient private[this] lazy val branchesArr = branches.toArray
- // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
- // element is the value for the default catch-all case (if provided).
- // Hence, `branches` consists of at least two elements, and can have an odd or even length.
- def branches: Seq[Expression]
+ override def children: Seq[Expression] = branches
@transient lazy val whenList =
branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
+
@transient lazy val thenList =
branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
+
val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
// both then and else expressions should be considered.
@@ -97,47 +103,26 @@ trait CaseWhenLike extends Expression {
case Seq(dt1, dt2) => dt1.sameType(dt2)
}
- override def checkInputDataTypes(): TypeCheckResult = {
- if (valueTypesEqual) {
- checkTypesInternal()
- } else {
- TypeCheckResult.TypeCheckFailure(
- "THEN and ELSE expressions should all be same type or coercible to a common type")
- }
- }
-
- protected def checkTypesInternal(): TypeCheckResult
-
override def dataType: DataType = thenList.head.dataType
override def nullable: Boolean = {
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
}
-}
-
-// scalastyle:off
-/**
- * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
- * Refer to this link for the corresponding semantics:
- * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
- */
-// scalastyle:on
-case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
-
- // Use private[this] Array to speed up evaluation.
- @transient private[this] lazy val branchesArr = branches.toArray
- override def children: Seq[Expression] = branches
-
- override protected def checkTypesInternal(): TypeCheckResult = {
- if (whenList.forall(_.dataType == BooleanType)) {
- TypeCheckResult.TypeCheckSuccess
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (valueTypesEqual) {
+ if (whenList.forall(_.dataType == BooleanType)) {
+ TypeCheckResult.TypeCheckSuccess
+ } else {
+ val index = whenList.indexWhere(_.dataType != BooleanType)
+ TypeCheckResult.TypeCheckFailure(
+ s"WHEN expressions in CaseWhen should all be boolean type, " +
+ s"but the ${index + 1}th when expression's type is ${whenList(index)}")
+ }
} else {
- val index = whenList.indexWhere(_.dataType != BooleanType)
TypeCheckResult.TypeCheckFailure(
- s"WHEN expressions in CaseWhen should all be boolean type, " +
- s"but the ${index + 1}th when expression's type is ${whenList(index)}")
+ "THEN and ELSE expressions should all be same type or coercible to a common type")
}
}
@@ -227,125 +212,23 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
}
}
-// scalastyle:off
/**
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
- * Refer to this link for the corresponding semantics:
- * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
+ * When a = b, returns c; when a = d, returns e; else returns f.
*/
-// scalastyle:on
-case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike {
-
- // Use private[this] Array to speed up evaluation.
- @transient private[this] lazy val branchesArr = branches.toArray
-
- override def children: Seq[Expression] = key +: branches
-
- override protected def checkTypesInternal(): TypeCheckResult = {
- if ((key +: whenList).map(_.dataType).distinct.size > 1) {
- TypeCheckResult.TypeCheckFailure(
- "key and WHEN expressions should all be same type or coercible to a common type")
- } else {
- TypeCheckResult.TypeCheckSuccess
- }
- }
-
- private def evalElse(input: InternalRow): Any = {
- if (branchesArr.length % 2 == 0) {
- null
- } else {
- branchesArr(branchesArr.length - 1).eval(input)
- }
- }
-
- /** Written in imperative fashion for performance considerations. */
- override def eval(input: InternalRow): Any = {
- val evaluatedKey = key.eval(input)
- // If key is null, we can just return the else part or null if there is no else.
- // If key is not null but doesn't match any when part, we need to return
- // the else part or null if there is no else, according to Hive's semantics.
- if (evaluatedKey != null) {
- val len = branchesArr.length
- var i = 0
- while (i < len - 1) {
- if (evaluatedKey == branchesArr(i).eval(input)) {
- return branchesArr(i + 1).eval(input)
- }
- i += 2
- }
- }
- evalElse(input)
- }
-
- override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val keyEval = key.gen(ctx)
- val len = branchesArr.length
- val got = ctx.freshName("got")
-
- val cases = (0 until len/2).map { i =>
- val cond = branchesArr(i * 2).gen(ctx)
- val res = branchesArr(i * 2 + 1).gen(ctx)
- s"""
- if (!$got) {
- ${cond.code}
- if (!${cond.isNull} && ${ctx.genEqual(key.dataType, keyEval.value, cond.value)}) {
- $got = true;
- ${res.code}
- ${ev.isNull} = ${res.isNull};
- ${ev.value} = ${res.value};
- }
- }
- """
- }.mkString("\n")
-
- val other = if (len % 2 == 1) {
- val res = branchesArr(len - 1).gen(ctx)
- s"""
- if (!$got) {
- ${res.code}
- ${ev.isNull} = ${res.isNull};
- ${ev.value} = ${res.value};
- }
- """
- } else {
- ""
- }
-
- s"""
- boolean $got = false;
- boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- ${keyEval.code}
- if (!${keyEval.isNull}) {
- $cases
+object CaseKeyWhen {
+ def apply(key: Expression, branches: Seq[Expression]): CaseWhen = {
+ val newBranches = branches.zipWithIndex.map { case (expr, i) =>
+ if (i % 2 == 0 && i != branches.size - 1) {
+ // If this expression is at even position, then it is either a branch condition, or
+ // the very last value that is the "else value". The "i != branches.size - 1" makes
+ // sure we are not adding an EqualTo to the "else value".
+ EqualTo(key, expr)
+ } else {
+ expr
}
- $other
- """
- }
-
- override def toString: String = {
- s"CASE $key" + branches.sliding(2, 2).map {
- case Seq(cond, value) => s" WHEN $cond THEN $value"
- case Seq(elseValue) => s" ELSE $elseValue"
- }.mkString
- }
-
- override def sql: String = {
- val keySQL = key.sql
- val branchesSQL = branches.map(_.sql)
- val (cases, maybeElse) = if (branches.length % 2 == 0) {
- (branchesSQL, None)
- } else {
- (branchesSQL.init, Some(branchesSQL.last))
}
-
- val head = s"CASE $keySQL "
- val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
- val body = cases.grouped(2).map {
- case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
- }.mkString(" ")
-
- head + body + tail
+ CaseWhen(newBranches)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 58d808c558..23b11af9ac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -299,7 +299,7 @@ class HiveTypeCoercionSuite extends PlanTest {
}
test("type coercion for CaseKeyWhen") {
- ruleTest(HiveTypeCoercion.CaseWhenCoercion,
+ ruleTest(HiveTypeCoercion.ImplicitTypeCasts,
CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))),
CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a")))
)