author    Wenchen Fan <cloud0fan@outlook.com>    2015-05-07 16:26:49 -0700
committer Michael Armbrust <michael@databricks.com>    2015-05-07 16:26:49 -0700
commit    35f0173b8f67e2e506fc4575be6430cfb66e2238 (patch)
tree      94c74039c393804752b762de0351c5dba1c9a4d7 /sql
parent    937ba798c56770ec54276b9259e47ae65ee93967 (diff)
[SPARK-2155] [SQL] [WHEN D THEN E] [ELSE F] add CaseKeyWhen for "CASE a WHEN b THEN c * END"
This avoids translating to CaseWhen, which would evaluate the key expression once per branch.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #5979 from cloud-fan/condition and squashes the following commits:

3ce54e1 [Wenchen Fan] add CaseKeyWhen
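To see the cost the commit removes: under the old translation, "CASE a WHEN b THEN c END" became "CASE WHEN a=b THEN c END", so the key expression was re-evaluated once per WHEN arm. A minimal, self-contained Scala sketch (toy code, not catalyst) that counts key evaluations under each strategy:

```scala
object KeyEvalCountSketch {
  var keyEvals = 0
  // Stand-in for an expensive or non-deterministic key expression.
  def key(): Int = { keyEvals += 1; 42 }

  def main(args: Array[String]): Unit = {
    val whenValues = Seq(1, 2, 3, 42)

    // Old CaseWhen-style translation: every arm re-runs "key() == w".
    keyEvals = 0
    whenValues.collectFirst { case w if key() == w => w }
    println(s"CaseWhen-style:    key evaluated $keyEvals times") // 4

    // New CaseKeyWhen: the key is evaluated once, up front.
    keyEvals = 0
    val k = key()
    whenValues.collectFirst { case w if k == w => w }
    println(s"CaseKeyWhen-style: key evaluated $keyEvals times") // 1
  }
}
```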
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala                              10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala              43
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala                  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala                135
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala  26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala                                 9
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                                         12
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala                         7
8 files changed, 159 insertions, 85 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 1d3a2dc0d9..b06bfb2ce8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -296,13 +296,13 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
| LOWER ~ "(" ~> expression <~ ")" ^^ { case exp => Lower(exp) }
| IF ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
{ case c ~ t ~ f => If(c, t, f) }
- | CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~
+ | CASE ~> expression.? ~ rep1(WHEN ~> expression ~ (THEN ~> expression)) ~
(ELSE ~> expression).? <~ END ^^ {
case casePart ~ altPart ~ elsePart =>
- val altExprs = altPart.flatMap { case whenExpr ~ thenExpr =>
- Seq(casePart.fold(whenExpr)(EqualTo(_, whenExpr)), thenExpr)
- }
- CaseWhen(altExprs ++ elsePart.toList)
+ val branches = altPart.flatMap { case whenExpr ~ thenExpr =>
+ Seq(whenExpr, thenExpr)
+ } ++ elsePart
+ casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches))
}
| (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) <~ ")" ^^
{ case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }
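Two things change in this grammar rule: `rep1` now requires at least one WHEN arm, and the action keeps the optional key separate instead of rewriting each arm into `EqualTo(key, when)`. A runnable sketch of the new construction, using a toy `Expr` ADT in place of catalyst's expression classes (names illustrative):

```scala
object ParseActionSketch {
  sealed trait Expr
  case class Lit(v: Any) extends Expr
  case class CaseWhen(branches: Seq[Expr]) extends Expr
  case class CaseKeyWhen(key: Expr, branches: Seq[Expr]) extends Expr

  // Mirrors the new parser action: flatten (when, then) pairs, append the
  // optional ELSE, and only pick CaseKeyWhen when a key was parsed.
  def build(casePart: Option[Expr],
            altPart: Seq[(Expr, Expr)],
            elsePart: Option[Expr]): Expr = {
    val branches = altPart.flatMap { case (w, t) => Seq(w, t) } ++ elsePart
    casePart.map(CaseKeyWhen(_, branches)).getOrElse(CaseWhen(branches))
  }

  def main(args: Array[String]): Unit = {
    // CASE k WHEN 1 THEN 'a' ELSE 'z' END
    println(build(Some(Lit("k")), Seq(Lit(1) -> Lit("a")), Some(Lit("z"))))
    // => CaseKeyWhen(Lit(k),List(Lit(1), Lit(a), Lit(z)))

    // CASE WHEN cond THEN 'a' END
    println(build(None, Seq(Lit(true) -> Lit("a")), None))
    // => CaseWhen(List(Lit(true), Lit(a)))
  }
}
```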
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 873c75c525..168a4e30ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -631,31 +631,24 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
- val valueTypes = branches.sliding(2, 2).map {
- case Seq(_, value) => value.dataType
- case Seq(elseVal) => elseVal.dataType
- }.toSeq
-
- logDebug(s"Input values for null casting ${valueTypes.mkString(",")}")
-
- if (valueTypes.distinct.size > 1) {
- val commonType = valueTypes.reduce { (v1, v2) =>
- findTightestCommonType(v1, v2)
- .getOrElse(sys.error(
- s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
- }
- val transformedBranches = branches.sliding(2, 2).map {
- case Seq(cond, value) if value.dataType != commonType =>
- Seq(cond, Cast(value, commonType))
- case Seq(elseVal) if elseVal.dataType != commonType =>
- Seq(Cast(elseVal, commonType))
- case s => s
- }.reduce(_ ++ _)
- CaseWhen(transformedBranches)
- } else {
- // Types match up. Hopefully some other rule fixes whatever is wrong with resolution.
- cw
+ case cw: CaseWhenLike if !cw.resolved && cw.childrenResolved && !cw.valueTypesEqual =>
+ logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
+ val commonType = cw.valueTypes.reduce { (v1, v2) =>
+ findTightestCommonType(v1, v2).getOrElse(sys.error(
+ s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
+ }
+ val transformedBranches = cw.branches.sliding(2, 2).map {
+ case Seq(when, value) if value.dataType != commonType =>
+ Seq(when, Cast(value, commonType))
+ case Seq(elseVal) if elseVal.dataType != commonType =>
+ Seq(Cast(elseVal, commonType))
+ case s => s
+ }.reduce(_ ++ _)
+ cw match {
+ case _: CaseWhen =>
+ CaseWhen(transformedBranches)
+ case CaseKeyWhen(key, _) =>
+ CaseKeyWhen(key, transformedBranches)
}
}
}
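The rule now reads: for any `CaseWhenLike` whose branch value types disagree, fold `findTightestCommonType` across the value types and cast each deviating THEN/ELSE value. A toy, runnable version of that fold over a simplified Int < Long < Double lattice (names illustrative, not Spark's):

```scala
object CoercionSketch {
  // Widening order, narrowest first.
  private val widening = List("IntegerType", "LongType", "DoubleType")

  def findTightestCommonType(a: String, b: String): Option[String] =
    if (widening.contains(a) && widening.contains(b))
      Some(widening(math.max(widening.indexOf(a), widening.indexOf(b))))
    else None

  def main(args: Array[String]): Unit = {
    // Value types of two THEN branches and an ELSE branch.
    val valueTypes = Seq("IntegerType", "LongType", "IntegerType")
    val commonType = valueTypes.reduce { (v1, v2) =>
      findTightestCommonType(v1, v2).getOrElse(sys.error(
        s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
    }
    println(commonType) // LongType: the Integer-typed values get a Cast(_, LongType)
  }
}
```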
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 4fd1bc4dd6..0837a3179d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -64,7 +64,7 @@ abstract class Expression extends TreeNode[Expression] {
* Returns true if all the children of this expression have been resolved to a specific schema
* and false if any still contains any unresolved placeholders.
*/
- def childrenResolved: Boolean = !children.exists(!_.resolved)
+ def childrenResolved: Boolean = children.forall(_.resolved)
/**
* Returns a string representation of this expression that does not have developer centric
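This one-liner is a pure readability change: by De Morgan's law, "no child is unresolved" and "every child is resolved" are the same predicate, including on the empty collection. A quick check:

```scala
object ForallSketch {
  def main(args: Array[String]): Unit = {
    val resolved = (n: Int) => n > 0
    for (children <- Seq(Seq(1, 2, 3), Seq(1, -2), Seq.empty[Int])) {
      assert(!children.exists(c => !resolved(c)) == children.forall(resolved))
    }
    println("both forms agree, including the vacuous empty case")
  }
}
```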
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 26c38c56c0..50b0f3ee5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -353,79 +353,134 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
}
+trait CaseWhenLike extends Expression {
+ self: Product =>
+
+ type EvaluatedType = Any
+
+ // Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
+ // element is the value for the default catch-all case (if provided).
+ // Hence, `branches` consists of at least two elements, and can have an odd or even length.
+ def branches: Seq[Expression]
+
+ @transient lazy val whenList =
+ branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
+ @transient lazy val thenList =
+ branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
+ val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
+
+ // both then and else val should be considered.
+ def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
+ def valueTypesEqual: Boolean = valueTypes.distinct.size <= 1
+
+ override def dataType: DataType = {
+ if (!resolved) {
+ throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
+ }
+ valueTypes.head
+ }
+
+ override def nullable: Boolean = {
+ // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
+ thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
+ }
+}
+
// scalastyle:off
/**
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
* Refer to this link for the corresponding semantics:
* https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
- *
- * The other form of case statements "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END" gets
- * translated to this form at parsing time. Namely, such a statement gets translated to
- * "CASE WHEN a=b THEN c [WHEN a=d THEN e]* [ELSE f] END".
- *
- * Note that `branches` are considered in consecutive pairs (cond, val), and the optional last
- * element is the value for the default catch-all case (if provided). Hence, `branches` consists of
- * at least two elements, and can have an odd or even length.
*/
// scalastyle:on
-case class CaseWhen(branches: Seq[Expression]) extends Expression {
- type EvaluatedType = Any
+case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
+
+ // Use private[this] Array to speed up evaluation.
+ @transient private[this] lazy val branchesArr = branches.toArray
override def children: Seq[Expression] = branches
- override def dataType: DataType = {
- if (!resolved) {
- throw new UnresolvedException(this, "cannot resolve due to differing types in some branches")
+ override lazy val resolved: Boolean =
+ childrenResolved &&
+ whenList.forall(_.dataType == BooleanType) &&
+ valueTypesEqual
+
+ /** Written in imperative fashion for performance considerations. */
+ override def eval(input: Row): Any = {
+ val len = branchesArr.length
+ var i = 0
+ // If all branches fail and an elseVal is not provided, the whole statement
+ // defaults to null, according to Hive's semantics.
+ while (i < len - 1) {
+ if (branchesArr(i).eval(input) == true) {
+ return branchesArr(i + 1).eval(input)
+ }
+ i += 2
+ }
+ var res: Any = null
+ if (i == len - 1) {
+ res = branchesArr(i).eval(input)
}
- branches(1).dataType
+ return res
}
+ override def toString: String = {
+ "CASE" + branches.sliding(2, 2).map {
+ case Seq(cond, value) => s" WHEN $cond THEN $value"
+ case Seq(elseValue) => s" ELSE $elseValue"
+ }.mkString
+ }
+}
+
+// scalastyle:off
+/**
+ * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
+ * Refer to this link for the corresponding semantics:
+ * https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF#LanguageManualUDF-ConditionalFunctions
+ */
+// scalastyle:on
+case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseWhenLike {
+
+ // Use private[this] Array to speed up evaluation.
@transient private[this] lazy val branchesArr = branches.toArray
- @transient private[this] lazy val predicates =
- branches.sliding(2, 2).collect { case Seq(cond, _) => cond }.toSeq
- @transient private[this] lazy val values =
- branches.sliding(2, 2).collect { case Seq(_, value) => value }.toSeq
- @transient private[this] lazy val elseValue =
- if (branches.length % 2 == 0) None else Option(branches.last)
- override def nullable: Boolean = {
- // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
- values.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true))
- }
+ override def children: Seq[Expression] = key +: branches
- override lazy val resolved: Boolean = {
- if (!childrenResolved) {
- false
- } else {
- val allCondBooleans = predicates.forall(_.dataType == BooleanType)
- // both then and else val should be considered.
- val dataTypesEqual = (values ++ elseValue).map(_.dataType).distinct.size <= 1
- allCondBooleans && dataTypesEqual
- }
- }
+ override lazy val resolved: Boolean =
+ childrenResolved && valueTypesEqual
/** Written in imperative fashion for performance considerations. */
override def eval(input: Row): Any = {
+ val evaluatedKey = key.eval(input)
val len = branchesArr.length
var i = 0
// If all branches fail and an elseVal is not provided, the whole statement
// defaults to null, according to Hive's semantics.
- var res: Any = null
while (i < len - 1) {
- if (branchesArr(i).eval(input) == true) {
- res = branchesArr(i + 1).eval(input)
- return res
+ if (equalNullSafe(evaluatedKey, branchesArr(i).eval(input))) {
+ return branchesArr(i + 1).eval(input)
}
i += 2
}
+ var res: Any = null
if (i == len - 1) {
res = branchesArr(i).eval(input)
}
- res
+ return res
+ }
+
+ private def equalNullSafe(l: Any, r: Any) = {
+ if (l == null && r == null) {
+ true
+ } else if (l == null || r == null) {
+ false
+ } else {
+ l == r
+ }
}
override def toString: String = {
- "CASE" + branches.sliding(2, 2).map {
+ s"CASE $key" + branches.sliding(2, 2).map {
case Seq(cond, value) => s" WHEN $cond THEN $value"
case Seq(elseValue) => s" ELSE $elseValue"
}.mkString
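A runnable distillation of the two eval loops above: the key is evaluated once, the (WHEN, THEN) pairs are walked two at a time, matching is null-safe, and an odd trailing element serves as the ELSE (otherwise the result is null, matching Hive's semantics). Rows and expressions are collapsed to plain values here; catalyst would call `branchesArr(i).eval(input)` instead:

```scala
object CaseKeyWhenEvalSketch {
  def equalNullSafe(l: Any, r: Any): Boolean =
    if (l == null && r == null) true
    else if (l == null || r == null) false
    else l == r

  // `branches` holds already-evaluated (when, then)* values plus an
  // optional trailing else value.
  def eval(key: Any, branches: Array[Any]): Any = {
    val len = branches.length
    var i = 0
    while (i < len - 1) {
      if (equalNullSafe(key, branches(i))) return branches(i + 1)
      i += 2
    }
    if (i == len - 1) branches(i) else null // odd length => trailing ELSE
  }

  def main(args: Array[String]): Unit = {
    println(eval(2, Array(1, "one", 2, "two", "other"))) // two
    println(eval(9, Array(1, "one", 2, "two", "other"))) // other
    println(eval(9, Array(1, "one", 2, "two")))          // null (no ELSE)
    println(eval(null, Array(null, "matched-null")))     // matched-null
  }
}
```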
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index faaa55aa5e..88d36d153c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -850,6 +850,32 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
}
+ test("case key when") {
+ val row = create_row(null, 1, 2, "a", "b", "c")
+ val c1 = 'a.int.at(0)
+ val c2 = 'a.int.at(1)
+ val c3 = 'a.int.at(2)
+ val c4 = 'a.string.at(3)
+ val c5 = 'a.string.at(4)
+ val c6 = 'a.string.at(5)
+
+ val literalNull = Literal.create(null, BooleanType)
+ val literalInt = Literal(1)
+ val literalString = Literal("a")
+
+ checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, c5)), "b", row)
+ checkEvaluation(CaseKeyWhen(c1, Seq(c2, c4, literalNull, c5, c6)), "b", row)
+ checkEvaluation(CaseKeyWhen(c2, Seq(literalInt, c4, c5)), "a", row)
+ checkEvaluation(CaseKeyWhen(c2, Seq(c1, c4, c5)), "b", row)
+ checkEvaluation(CaseKeyWhen(c4, Seq(literalString, c2, c3)), 1, row)
+ checkEvaluation(CaseKeyWhen(c4, Seq(c1, c3, c5, c2, Literal(3))), 3, row)
+
+ checkEvaluation(CaseKeyWhen(literalInt, Seq(c2, c4, c5)), "a", row)
+ checkEvaluation(CaseKeyWhen(literalString, Seq(c5, c2, c4, c3)), 2, row)
+ checkEvaluation(CaseKeyWhen(literalInt, Seq(c5, c2, c4, c3)), null, row)
+ checkEvaluation(CaseKeyWhen(literalNull, Seq(c5, c2, c1, c3)), 2, row)
+ }
+
test("complex type") {
val row = create_row(
"^Ba*n", // 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 481ed49248..4a54120ba8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -357,11 +357,12 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
* TODO: This can be optimized to use broadcast join when replacementMap is large.
*/
private def replaceCol(col: StructField, replacementMap: Map[_, _]): Column = {
- val branches: Seq[Expression] = replacementMap.flatMap { case (source, target) =>
- df.col(col.name).equalTo(lit(source).cast(col.dataType)).expr ::
- lit(target).cast(col.dataType).expr :: Nil
+ val keyExpr = df.col(col.name).expr
+ def buildExpr(v: Any) = Cast(Literal(v), keyExpr.dataType)
+ val branches = replacementMap.flatMap { case (source, target) =>
+ Seq(buildExpr(source), buildExpr(target))
}.toSeq
- new Column(CaseWhen(branches ++ Seq(df.col(col.name).expr))).as(col.name)
+ new Column(CaseKeyWhen(keyExpr, branches :+ keyExpr)).as(col.name)
}
private def convertToDouble(v: Any): Double = v match {
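For context, this private helper backs the public `DataFrame.na.replace` API. A hedged usage sketch (written against the modern SparkSession entry point rather than the SQLContext of this era; the shape of the generated expression is the point):

```scala
import org.apache.spark.sql.SparkSession

object NaReplaceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("na-replace").getOrCreate()
    import spark.implicits._

    val df = Seq("UNKNOWN", "Alice", "N/A").toDF("name")
    // Each (source -> target) pair becomes one WHEN/THEN arm of a single
    // CaseKeyWhen keyed on the column, with the untouched column value as
    // the trailing ELSE, roughly:
    //   CASE name WHEN 'UNKNOWN' THEN 'missing' WHEN 'N/A' THEN 'missing' ELSE name END
    df.na.replace("name", Map("UNKNOWN" -> "missing", "N/A" -> "missing")).show()

    spark.stop()
  }
}
```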
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 4e51473979..6176aee25e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -1246,16 +1246,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
CaseWhen(branches.map(nodeToExpr))
case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
- val transformed = branches.drop(1).sliding(2, 2).map {
- case Seq(condVal, value) =>
- // FIXME (SPARK-2155): the key will get evaluated for multiple times in CaseWhen's eval().
- // Hence effectful / non-deterministic key expressions are *not* supported at the moment.
- // We should consider adding new Expressions to get around this.
- Seq(EqualTo(nodeToExpr(branches(0)), nodeToExpr(condVal)),
- nodeToExpr(value))
- case Seq(elseVal) => Seq(nodeToExpr(elseVal))
- }.toSeq.reduce(_ ++ _)
- CaseWhen(transformed)
+ val keyExpr = nodeToExpr(branches.head)
+ CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
/* Complex datatype manipulation */
case Token("[", child :: ordinal :: Nil) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 616352d223..c605f10175 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -751,4 +751,11 @@ class SQLQuerySuite extends QueryTest {
(6, "c", 0, 6)
).map(i => Row(i._1, i._2, i._3, i._4)))
}
+
+ test("test case key when") {
+ (1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t")
+ checkAnswer(
+ sql("SELECT CASE k WHEN 2 THEN 22 WHEN 4 THEN 44 ELSE 0 END, v FROM t"),
+ Row(0, "1") :: Row(22, "2") :: Row(0, "3") :: Row(44, "4") :: Row(0, "5") :: Nil)
+ }
}