diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-05-07 16:26:49 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-05-07 16:26:49 -0700 |
commit | 35f0173b8f67e2e506fc4575be6430cfb66e2238 (patch) | |
tree | 94c74039c393804752b762de0351c5dba1c9a4d7 /sql/catalyst | |
parent | 937ba798c56770ec54276b9259e47ae65ee93967 (diff) | |
download | spark-35f0173b8f67e2e506fc4575be6430cfb66e2238.tar.gz spark-35f0173b8f67e2e506fc4575be6430cfb66e2238.tar.bz2 spark-35f0173b8f67e2e506fc4575be6430cfb66e2238.zip |
[SPARK-2155] [SQL] [WHEN D THEN E] [ELSE F] add CaseKeyWhen for "CASE a WHEN b THEN c * END"
Avoid translating to CaseWhen and evaluate the key expression many times.
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #5979 from cloud-fan/condition and squashes the following commits:
3ce54e1 [Wenchen Fan] add CaseKeyWhen
Diffstat (limited to 'sql/catalyst')
5 files changed, 145 insertions, 71 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 1d3a2dc0d9..b06bfb2ce8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -296,13 +296,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { | LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) } | IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ { case c ~ t ~ f => If(c, t, f) } - | CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~ + | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ { case casePart ~ altPart ~ elsePart => - val altExprs = altPart.flatMap { case whenExpr ~ thenExpr => - Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr) - } - CaseWhen(altExprs ++ elsePart.toList) + val branches = altPart.flatMap { case whenExpr ~ thenExpr => + Seq(whenExpr, thenExpr) + } ++ elsePart + casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches)) } | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^ { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 873c75c525..168a4e30ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -631,31 +631,24 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) => - val valueTypes = branches.sliding(2, 2).map { - case Seq(_, value) => value.dataType - case Seq(elseVal) => elseVal.dataType - }.toSeq - - logDebug(s"Input values for null casting ${valueTypes.mkString(",")}") - - if (valueTypes.distinct.size > 1) { - val commonType = valueTypes.reduce { (v1, v2) => - findTightestCommonType(v1, v2) - .getOrElse(sys.error( - s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) - } - val transformedBranches = branches.sliding(2, 2).map { - case Seq(cond, value) if value.dataType != commonType => - Seq(cond, Cast(value, commonType)) - case Seq(elseVal) if elseVal.dataType != commonType => - Seq(Cast(elseVal, commonType)) - case s => s - }.reduce(_ ++ _) - CaseWhen(transformedBranches) - } else { - // Types match up. Hopefully some other rule fixes whatever is wrong with resolution. - cw + case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual => + logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}") + val commonType = cw.valueTypes.reduce { (v1, v2) => + findTightestCommonType(v1, v2).getOrElse(sys.error( + s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2")) + } + val transformedBranches = cw.branches.sliding(2, 2).map { + case Seq(when, value) if value.dataType != commonType => + Seq(when, Cast(value, commonType)) + case Seq(elseVal) if elseVal.dataType != commonType => + Seq(Cast(elseVal, commonType)) + case s => s + }.reduce(_ ++ _) + cw match { + case _: CaseWhen => + CaseWhen(transformedBranches) + case CaseKeyWhen(key, _) => + CaseKeyWhen(key, transformedBranches) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 4fd1bc4dd6..0837a3179d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -64,7 +64,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns true if all the children of this expression have been resolved to a specific schema * and false if any still contains any unresolved placeholders. */ - def childrenResolved: Boolean = !children.exists(!_.resolved) + def childrenResolved: Boolean = children.forall(_.resolved) /** * Returns a string representation of this expression that does not have developer centric diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 26c38c56c0..50b0f3ee5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -353,79 +353,134 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def toString: String = s"if ($predicate) $trueValue else $falseValue" } +trait CaseWhenLike extends Expression { + self: Product => + + type EvaluatedType = Any + + // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last + // element is the value for the default catch-all case (if provided). + // Hence, `branches` consists of at least two elements, and can have an odd or even length. + def branches: Seq[Expression] + + @transient lazy val whenList = + branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq + @transient lazy val thenList = + branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq + val elseValue = if (branches.length % 2 == 0) None else Option(branches.last) + + // both then and else val should be considered. + def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) + def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1 + + override def dataType: DataType = { + if (!resolved) { + throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") + } + valueTypes.head + } + + override def nullable: Boolean = { + // If no value is nullable and no elseValue is provided, the whole statement defaults to null. + thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + } +} + // scalastyle:off /** * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". * Refer to this link for the corresponding semantics: * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions - * - * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets - * translated to this form at parsing time. Namely, such a statement gets translated to - * "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END". - * - * Note that `branches` are considered in consecutive pairs (cond, val), and the optional last - * element is the value for the default catch-all case (if provided). Hence, `branches` consists of - * at least two elements, and can have an odd or even length. */ // scalastyle:on -case class CaseWhen(branches: Seq[Expression]) extends Expression { - type EvaluatedType = Any +case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { + + // Use private[this] Array to speed up evaluation. + @transient private[this] lazy val branchesArr = branches.toArray override def children: Seq[Expression] = branches - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") + override lazy val resolved: Boolean = + childrenResolved && + whenList.forall(_.dataType == BooleanType) && + valueTypesEqual + + /** Written in imperative fashion for performance considerations. */ + override def eval(input: Row): Any = { + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. + while (i < len - 1) { + if (branchesArr(i).eval(input) == true) { + return branchesArr(i + 1).eval(input) + } + i += 2 + } + var res: Any = null + if (i == len - 1) { + res = branchesArr(i).eval(input) } - branches(1).dataType + return res } + override def toString: String = { + "CASE" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} + +// scalastyle:off +/** + * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + */ +// scalastyle:on +case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike { + + // Use private[this] Array to speed up evaluation. @transient private[this] lazy val branchesArr = branches.toArray - @transient private[this] lazy val predicates = - branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq - @transient private[this] lazy val values = - branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq - @transient private[this] lazy val elseValue = - if (branches.length % 2 == 0) None else Option(branches.last) - override def nullable: Boolean = { - // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) - } + override def children: Seq[Expression] = key +: branches - override lazy val resolved: Boolean = { - if (!childrenResolved) { - false - } else { - val allCondBooleans = predicates.forall(_.dataType == BooleanType) - // both then and else val should be considered. - val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1 - allCondBooleans && dataTypesEqual - } - } + override lazy val resolved: Boolean = + childrenResolved && valueTypesEqual /** Written in imperative fashion for performance considerations. */ override def eval(input: Row): Any = { + val evaluatedKey = key.eval(input) val len = branchesArr.length var i = 0 // If all branches fail and an elseVal is not provided, the whole statement // defaults to null, according to Hive's semantics. - var res: Any = null while (i < len - 1) { - if (branchesArr(i).eval(input) == true) { - res = branchesArr(i + 1).eval(input) - return res + if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) { + return branchesArr(i + 1).eval(input) } i += 2 } + var res: Any = null if (i == len - 1) { res = branchesArr(i).eval(input) } - res + return res + } + + private def equalNullSafe(l: Any, r: Any) = { + if (l == null && r == null) { + true + } else if (l == null || r == null) { + false + } else { + l == r + } } override def toString: String = { - "CASE" + branches.sliding(2, 2).map { + s"CASE $key" + branches.sliding(2, 2).map { case Seq(cond, value) => s" WHEN $cond THEN $value" case Seq(elseValue) => s" ELSE $elseValue" }.mkString diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index faaa55aa5e..88d36d153c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -850,6 +850,32 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite { assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true) } + test("case key when") { + val row = create_row(null, 1, 2, "a", "b", "c") + val c1 = 'a.int.at(0) + val c2 = 'a.int.at(1) + val c3 = 'a.int.at(2) + val c4 = 'a.string.at(3) + val c5 = 'a.string.at(4) + val c6 = 'a.string.at(5) + + val literalNull = Literal.create(null, BooleanType) + val literalInt = Literal(1) + val literalString = Literal("a") + + checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row) + checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row) + checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row) + checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row) + checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row) + checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row) + + checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row) + checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row) + checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row) + checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row) + } + test("complex type") { val row = create_row( "^Ba*n", // 0 |