From e243c5ffacd70ecadaf5c91668955dcc8141e060 Mon Sep 17 00:00:00 2001 From: Zongheng Yang Date: Tue, 17 Jun 2014 13:30:17 +0200 Subject: [SPARK-2053][SQL] Add Catalyst expressions for CASE WHEN. JIRA ticket: https://issues.apache.org/jira/browse/SPARK-2053 This PR adds support for two types of CASE statements present in Hive. The first type is of the form `CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END`, with the semantics like a chain of if statements. The second type is of the form `CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END`, with the semantics like a switch statement on key `a`. Both forms are implemented in `CaseWhen`. [This link](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions) contains more detailed descriptions on their semantics. Notes / Open issues: * Please check if any implicit contracts / invariants are broken in the implementations (especially for the operators). I am not very familiar with them and I currently find them tricky to spot. * We should decide whether or not a non-boolean condition is allowed in a branch of `CaseWhen`. Hive throws a `SemanticException` for this situation and I think it'd be good to mimic it -- the question is where in the whole Spark SQL pipeline should we signal an exception for such a query. Author: Zongheng Yang Closes #1055 from concretevitamin/caseWhen and squashes the following commits: 4226eb9 [Zongheng Yang] Comment. 79d26fc [Zongheng Yang] Merge branch 'master' into caseWhen caf9383 [Zongheng Yang] Update a FIXME. 9d26ab8 [Zongheng Yang] Add @transient marker. 788a0d9 [Zongheng Yang] Implement CastNulls, which fixes udf_case and udf_when. 7ef284f [Zongheng Yang] Refactors: remove redundant passes, improve toString, mark transient. f47ae7b [Zongheng Yang] Modify queries in tests to have shorter golden files. 1c1fbfc [Zongheng Yang] Cleanups per review comments. 7d2b7e2 [Zongheng Yang] Translate CaseKeyWhen to CaseWhen at parsing time. 
47d406a [Zongheng Yang] Do toArray once and lazily outside of eval(). bb3d109 [Zongheng Yang] Update scaladoc of a method. aea3195 [Zongheng Yang] Fix bug that branchesArr is not used; remove unused import. 96870a8 [Zongheng Yang] Turn off scalastyle for some comments. 7392f3a [Zongheng Yang] Minor cleanup. 2cf08bb [Zongheng Yang] Merge branch 'master' into caseWhen 9f84b40 [Zongheng Yang] Add golden outputs from Hive. db51a85 [Zongheng Yang] Add allCondBooleans check; uncomment tests. 3f9ef0a [Zongheng Yang] Cleanups and bug fixes (mainly in eval() and resolved). be54bc8 [Zongheng Yang] Rewrite eval() to a low-level implementation. Separate two CASE stmts. f2bcb9d [Zongheng Yang] WIP 5906f75 [Zongheng Yang] WIP efd019b [Zongheng Yang] eval() and toString() bug fixes. 7d81e95 [Zongheng Yang] Clean up resolved. a31d782 [Zongheng Yang] Finish up Case. --- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 41 +++++++++++- .../sql/catalyst/expressions/Expression.scala | 10 ++- .../sql/catalyst/expressions/predicates.scala | 76 +++++++++++++++++++++- .../apache/spark/sql/catalyst/util/package.scala | 2 +- .../expressions/ExpressionEvaluationSuite.scala | 2 +- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 17 +++++ ...THOUT key #1-0-36750f0f6727c287c471309689ff7563 | 14 ++++ ...THOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 | 14 ++++ ...THOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 | 14 ++++ ...THOUT key #4-0-631f824a91b7230657bea7a05e393a1e | 14 ++++ ... with key #1-0-616830b2011da0990e87a188fb609299 | 14 ++++ ... with key #2-0-6c5b5a997949f9e5ab9676b60e95657b | 14 ++++ ... with key #3-0-a241862582c47d9e98be95339d35c7c4 | 14 ++++ ... 
with key #4-0-ea87ca38ead8858d2337792dcd430226 | 14 ++++ .../spark/sql/hive/execution/HiveQuerySuite.scala | 38 +++++++++++ 15 files changed, 290 insertions(+), 8 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 create mode 100644 sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 create mode 100644 sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 create mode 100644 sql/hive/src/test/resources/golden/case statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e create mode 100644 sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 create mode 100644 sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b create mode 100644 sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 create mode 100644 sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 326feea6fe..d291814c8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -31,8 +31,16 @@ import org.apache.spark.sql.catalyst.types._ trait HiveTypeCoercion { val typeCoercionRules = - List(PropagateTypes, ConvertNaNs, WidenTypes, PromoteStrings, BooleanComparisons, BooleanCasts, - StringToIntegralCasts, FunctionArgumentConversion) + PropagateTypes :: + ConvertNaNs :: + WidenTypes :: + PromoteStrings :: + BooleanComparisons :: + BooleanCasts :: + StringToIntegralCasts :: + 
FunctionArgumentConversion :: + CastNulls :: + Nil /** * Applies any changes to [[catalyst.expressions.AttributeReference AttributeReference]] data @@ -282,4 +290,33 @@ trait HiveTypeCoercion { Average(Cast(e, DoubleType)) } } + + /** + * Casts NullType CaseWhen values to the one other value type; skipped until children resolve. + */ + object CastNulls extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case cw @ CaseWhen(branches) if cw.childrenResolved => + val valueTypes = branches.sliding(2, 2).map { + case Seq(_, value) if value.resolved => Some(value.dataType) + case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType) + case _ => None + }.toSeq + if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) { + val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get + val transformedBranches = branches.sliding(2, 2).map { + case Seq(cond, value) if value.resolved && value.dataType == NullType => + Seq(cond, Cast(value, otherType)) + case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType => + Seq(Cast(elseVal, otherType)) + case s => s + }.reduce(_ ++ _) + CaseWhen(transformedBranches) + } else { + // It is possible to have more types due to the possibility of short-circuiting. + cw + } + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 41398ff956..3912f5f437 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -28,8 +28,6 @@ abstract class Expression extends TreeNode[Expression] { /** The narrowest possible type that is produced when this expression is evaluated. 
*/ type EvaluatedType <: Any - def dataType: DataType - /** * Returns true when an expression is a candidate for static evaluation before the query is * executed. @@ -53,12 +51,18 @@ abstract class Expression extends TreeNode[Expression] { /** * Returns `true` if this expression and all its children have been resolved to a specific schema - * and `false` if it is still contains any unresolved placeholders. Implementations of expressions + * and `false` if it still contains any unresolved placeholders. Implementations of expressions * should override this if the resolution of this type of expression involves more than just * the resolution of its children. */ lazy val resolved: Boolean = childrenResolved + /** + * Returns the [[types.DataType DataType]] of the result of evaluating this expression. It is + * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false). + */ + def dataType: DataType + /** * Returns true if all the children of this expression have been resolved to a specific schema * and false if any still contains any unresolved placeholders. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index d111578530..2902906df2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.types.BooleanType @@ -202,3 +201,78 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi override def toString = s"if ($predicate) $trueValue else $falseValue" } + +// scalastyle:off +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * Refer to this link for the corresponding semantics: + * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions + * + * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets + * translated to this form at parsing time. Namely, such a statement gets translated to + * "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END". + * + * Note that `branches` are considered in consecutive pairs (cond, val), and the optional last + * element is the value for the default catch-all case (if provided). Hence, `branches` consists of + * at least two elements, and can have an odd or even length. 
+ */ +// scalastyle:on +case class CaseWhen(branches: Seq[Expression]) extends Expression { + type EvaluatedType = Any + def children = branches + def references = children.flatMap(_.references).toSet + def dataType = { + if (!resolved) { + throw new UnresolvedException(this, "cannot resolve due to differing types in some branches") + } + branches(1).dataType + } + + @transient private[this] lazy val branchesArr = branches.toArray + @transient private[this] lazy val predicates = + branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq + @transient private[this] lazy val values = + branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq + + override def nullable = { + // Nullable if any THEN/ELSE value is nullable, or if there is no ELSE (result may be null). + values.exists(_.nullable) || branches.length % 2 == 0 || branches.last.nullable + } + + override lazy val resolved = { + if (!childrenResolved) { + false + } else { + val allCondBooleans = predicates.forall(_.dataType == BooleanType) + val dataTypesEqual = values.map(_.dataType).distinct.size <= 1 + allCondBooleans && dataTypesEqual + } + } + + /** Written in imperative fashion for performance considerations. */ + override def eval(input: Row): Any = { + val len = branchesArr.length + var i = 0 + // If all branches fail and an elseVal is not provided, the whole statement + // defaults to null, according to Hive's semantics. 
+ var res: Any = null + while (i < len - 1) { + if (branchesArr(i).eval(input) == true) { + res = branchesArr(i + 1).eval(input) + return res + } + i += 2 + } + if (i == len - 1) { + res = branchesArr(i).eval(input) + } + res + } + + override def toString = { + "CASE" + branches.sliding(2, 2).map { + case Seq(cond, value) => s" WHEN $cond THEN $value" + case Seq(elseValue) => s" ELSE $elseValue" + }.mkString + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 49fc4f70fd..d8da45ae70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -115,7 +115,7 @@ package object util { } /* FIX ME - implicit class debugLogging(a: AnyRef) { + implicit class debugLogging(a: Any) { def debugLogging() { org.apache.log4j.Logger.getLogger(a.getClass.getName).setLevel(org.apache.log4j.Level.DEBUG) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 1132a30b42..8c3b062d0f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -35,7 +35,7 @@ class ExpressionEvaluationSuite extends FunSuite { /** * Checks for three-valued-logic. Based on: * http://en.wikipedia.org/wiki/Null_(SQL)#Comparisons_with_NULL_and_the_three-valued_logic_.283VL.29 - * + * I.e. in flat cpo "False -> Unknown -> True", OR is lowest upper bound, AND is greatest lower bound. 
* p q p OR q p AND q p = q * True True True True True * True False True False False diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index b745d8ffd8..844673f66d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -811,6 +811,8 @@ private[hive] object HiveQl { val IN = "(?i)IN".r val DIV = "(?i)DIV".r val BETWEEN = "(?i)BETWEEN".r + val WHEN = "(?i)WHEN".r + val CASE = "(?i)CASE".r protected def nodeToExpr(node: Node): Expression = node match { /* Attribute References */ @@ -917,6 +919,21 @@ private[hive] object HiveQl { case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) + /* Case statements */ + case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => + CaseWhen(branches.map(nodeToExpr)) + case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => + val transformed = branches.drop(1).sliding(2, 2).map { + case Seq(condVal, value) => + // FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval(). + // Hence effectful / non-deterministic key expressions are *not* supported at the moment. + // We should consider adding new Expressions to get around this. 
+ Seq(Equals(nodeToExpr(branches(0)), nodeToExpr(condVal)), + nodeToExpr(value)) + case Seq(elseVal) => Seq(nodeToExpr(elseVal)) + }.toSeq.reduce(_ ++ _) + CaseWhen(transformed) + /* Complex datatype manipulation */ case Token("[", child :: ordinal :: Nil) => GetItem(nodeToExpr(child), nodeToExpr(ordinal)) diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 b/sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 new file mode 100644 index 0000000000..816fe57d16 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements WITHOUT key #1-0-36750f0f6727c287c471309689ff7563 @@ -0,0 +1,14 @@ +NULL +3 +3 +3 +NULL +NULL +3 +3 +3 +3 +NULL +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 b/sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 new file mode 100644 index 0000000000..4cca081e6e --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements WITHOUT key #2-0-e3a2b981ebff7e273537dd6c43ece0c0 @@ -0,0 +1,14 @@ +4 +3 +3 +3 +4 +4 +3 +3 +3 +3 +4 +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 b/sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 new file mode 100644 index 0000000000..8d0416a8f8 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements WITHOUT key #3-0-be5efc0574a97ec465e2686f4a724bd5 @@ -0,0 +1,14 @@ +2 +3 +3 +3 +2 +2 +3 +3 +3 +3 +NULL +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e b/sql/hive/src/test/resources/golden/case statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e new file mode 100644 index 0000000000..6ed452bcd8 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case 
statements WITHOUT key #4-0-631f824a91b7230657bea7a05e393a1e @@ -0,0 +1,14 @@ +2 +3 +3 +3 +2 +2 +3 +3 +3 +3 +0 +3 +3 +3 diff --git a/sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 b/sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 new file mode 100644 index 0000000000..3f5a2fbbe9 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #1-0-616830b2011da0990e87a188fb609299 @@ -0,0 +1,14 @@ +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL diff --git a/sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b b/sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b new file mode 100644 index 0000000000..e1ca6e76d1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #2-0-6c5b5a997949f9e5ab9676b60e95657b @@ -0,0 +1,14 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +3 +0 +0 +0 diff --git a/sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 b/sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 new file mode 100644 index 0000000000..896207fdbc --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #3-0-a241862582c47d9e98be95339d35c7c4 @@ -0,0 +1,14 @@ +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +3 +NULL +NULL +NULL diff --git a/sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 b/sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 new file mode 100644 index 0000000000..e1ca6e76d1 --- /dev/null +++ b/sql/hive/src/test/resources/golden/case statements with key #4-0-ea87ca38ead8858d2337792dcd430226 @@ -0,0 +1,14 @@ +0 +0 +0 +0 +0 +0 +0 +0 +0 +0 +3 +0 +0 +0 diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 04652587f9..fe698f0fc5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -164,6 +164,44 @@ class HiveQuerySuite extends HiveComparisonTest { hql("SELECT * FROM src").toString } + createQueryTest("case statements with key #1", + "SELECT (CASE 1 WHEN 2 THEN 3 END) FROM src where key < 15") + + createQueryTest("case statements with key #2", + "SELECT (CASE key WHEN 2 THEN 3 ELSE 0 END) FROM src WHERE key < 15") + + createQueryTest("case statements with key #3", + "SELECT (CASE key WHEN 2 THEN 3 WHEN NULL THEN 4 END) FROM src WHERE key < 15") + + createQueryTest("case statements with key #4", + "SELECT (CASE key WHEN 2 THEN 3 WHEN NULL THEN 4 ELSE 0 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #1", + "SELECT (CASE WHEN key > 2 THEN 3 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #2", + "SELECT (CASE WHEN key > 2 THEN 3 ELSE 4 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #3", + "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 END) FROM src WHERE key < 15") + + createQueryTest("case statements WITHOUT key #4", + "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") + + test("implement identity function using case statement") { + val actual = hql("SELECT (CASE key WHEN key THEN key END) FROM src").collect().toSet + val expected = hql("SELECT key FROM src").collect().toSet + assert(actual === expected) + } + + // TODO: adopt this test when Spark SQL has the functionality / framework to report errors. + // See https://github.com/apache/spark/pull/1055#issuecomment-45820167 for a discussion. 
+ ignore("non-boolean conditions in a CaseWhen are illegal") { + intercept[Exception] { + hql("SELECT (CASE WHEN key > 2 THEN 3 WHEN 1 THEN 2 ELSE 0 END) FROM src").collect() + } + } + private val explainCommandClassName = classOf[execution.ExplainCommand].getSimpleName.stripSuffix("$") -- cgit v1.2.3