 python/pyspark/sql/column.py                                                                       |  24
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala                         |   2
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala                          |   3
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala          |  26
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala | 137
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala                     |   9
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala             |   2
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala |   4
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala     |  15
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala |  51
 sql/core/src/main/scala/org/apache/spark/sql/Column.scala                                          |  19
 sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                       |   2
 12 files changed, 156 insertions(+), 138 deletions(-)
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index 900def59d2..320451c52c 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -368,12 +368,12 @@ class Column(object):
>>> from pyspark.sql import functions as F
>>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
- +-----+--------------------------------------------------------+
- | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0|
- +-----+--------------------------------------------------------+
- |Alice| -1|
- | Bob| 1|
- +-----+--------------------------------------------------------+
+ +-----+------------------------------------------------------------+
+ | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
+ +-----+------------------------------------------------------------+
+ |Alice| -1|
+ | Bob| 1|
+ +-----+------------------------------------------------------------+
"""
if not isinstance(condition, Column):
raise TypeError("condition should be a Column")
@@ -393,12 +393,12 @@ class Column(object):
>>> from pyspark.sql import functions as F
>>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
- +-----+---------------------------------+
- | name|CASE WHEN (age > 3) THEN 1 ELSE 0|
- +-----+---------------------------------+
- |Alice| 0|
- | Bob| 1|
- +-----+---------------------------------+
+ +-----+-------------------------------------+
+ | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
+ +-----+-------------------------------------+
+ |Alice| 0|
+ | Bob| 1|
+ +-----+-------------------------------------+
"""
v = value._jc if isinstance(value, Column) else value
jc = self._jc.otherwise(v)
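
The renamed string representation is visible from the Scala API as well. A minimal sketch, assuming a DataFrame df with name and age columns as in the doctest above:

    import org.apache.spark.sql.functions.when

    // The auto-generated column name now carries the closing END keyword,
    // matching standard SQL CASE syntax:
    df.select(df("name"), when(df("age") > 4, 1).when(df("age") < 3, -1).otherwise(0)).show()
    // +-----+------------------------------------------------------------+
    // | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
    // +-----+------------------------------------------------------------+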
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index c87b6c8e95..d0fbdacf6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -752,7 +752,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
/* Case statements */
case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
- CaseWhen(branches.map(nodeToExpr))
+ CaseWhen.createFromParser(branches.map(nodeToExpr))
case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
val keyExpr = nodeToExpr(branches.head)
CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 6ec408a673..85ff4ea0c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -305,7 +305,8 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
throw new AnalysisException(s"invalid function approximate($s) $udfName")
}
}
- | CASE ~> whenThenElse ^^ CaseWhen
+ | CASE ~> whenThenElse ^^
+ { case branches => CaseWhen.createFromParser(branches) }
| CASE ~> expression ~ whenThenElse ^^
{ case keyPart ~ branches => CaseKeyWhen(keyPart, branches) }
)
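
whenThenElse still produces a flat Seq[Expression]; CaseWhen.createFromParser is what pairs it up into the new shape. A sketch of the conversion, with illustrative literals:

    import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Literal}

    // CASE WHEN true THEN 1 WHEN false THEN 2 ELSE 3 END reaches the factory
    // as a flat sequence; a trailing odd element becomes the ELSE value:
    val cw = CaseWhen.createFromParser(
      Seq(Literal(true), Literal(1), Literal(false), Literal(2), Literal(3)))
    // cw.branches  == Seq((Literal(true), Literal(1)), (Literal(false), Literal(2)))
    // cw.elseValue == Some(Literal(3))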
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 980b5d52fa..2737fe32cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -621,14 +621,24 @@ object HiveTypeCoercion {
case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
val maybeCommonType = findWiderCommonType(c.valueTypes)
maybeCommonType.map { commonType =>
- val castedBranches = c.branches.grouped(2).map {
- case Seq(when, value) if value.dataType != commonType =>
- Seq(when, Cast(value, commonType))
- case Seq(elseVal) if elseVal.dataType != commonType =>
- Seq(Cast(elseVal, commonType))
- case other => other
- }.reduce(_ ++ _)
- CaseWhen(castedBranches)
+ var changed = false
+ val newBranches = c.branches.map { case (condition, value) =>
+ if (value.dataType.sameType(commonType)) {
+ (condition, value)
+ } else {
+ changed = true
+ (condition, Cast(value, commonType))
+ }
+ }
+ val newElseValue = c.elseValue.map { value =>
+ if (value.dataType.sameType(commonType)) {
+ value
+ } else {
+ changed = true
+ Cast(value, commonType)
+ }
+ }
+ if (changed) CaseWhen(newBranches, newElseValue) else c
}.getOrElse(c)
}
}
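
The rewritten rule casts only the branch values whose type differs from the wider common type, and returns the original expression untouched when no cast is needed. A sketch of its effect, mirroring the HiveTypeCoercionSuite change further below:

    import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Cast, Literal}
    import org.apache.spark.sql.types.{DecimalType, DoubleType}

    // Branch value is DoubleType, else value is DecimalType(7, 2); the wider
    // common type is DoubleType, so only the else value gains a Cast:
    CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2)))
    // is rewritten by CaseWhenCoercion into
    CaseWhen(Seq((Literal(true), Literal(1.2))),
      Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))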
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 5a1462433d..8cc7bc1da2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -81,44 +81,39 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
/**
* Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
* When a = true, returns b; when c = true, returns d; else returns e.
+ *
+ * @param branches seq of (branch condition, branch value)
+ * @param elseValue optional value for the else branch
*/
-case class CaseWhen(branches: Seq[Expression]) extends Expression {
-
- // Use private[this] Array to speed up evaluation.
- @transient private[this] lazy val branchesArr = branches.toArray
-
- override def children: Seq[Expression] = branches
-
- @transient lazy val whenList =
- branches.sliding(2, 2).collect { case Seq(whenExpr, _) => whenExpr }.toSeq
-
- @transient lazy val thenList =
- branches.sliding(2, 2).collect { case Seq(_, thenExpr) => thenExpr }.toSeq
+case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
+ extends Expression {
- val elseValue = if (branches.length % 2 == 0) None else Option(branches.last)
+ override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
// both then and else expressions should be considered.
- def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
+ def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
+
def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall {
case Seq(dt1, dt2) => dt1.sameType(dt2)
}
- override def dataType: DataType = thenList.head.dataType
+ override def dataType: DataType = branches.head._2.dataType
override def nullable: Boolean = {
- // If no value is nullable and no elseValue is provided, the whole statement defaults to null.
- thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true)
+ // Result is nullable if any of the branch values is nullable, or if the else value is nullable
+ branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
}
override def checkInputDataTypes(): TypeCheckResult = {
+ // Make sure all branch conditions are boolean types.
if (valueTypesEqual) {
- if (whenList.forall(_.dataType == BooleanType)) {
+ if (branches.forall(_._1.dataType == BooleanType)) {
TypeCheckResult.TypeCheckSuccess
} else {
- val index = whenList.indexWhere(_.dataType != BooleanType)
+ val index = branches.indexWhere(_._1.dataType != BooleanType)
TypeCheckResult.TypeCheckFailure(
s"WHEN expressions in CaseWhen should all be boolean type, " +
- s"but the ${index + 1}th when expression's type is ${whenList(index)}")
+ s"but the ${index + 1}th when expression's type is ${branches(index)._1}")
}
} else {
TypeCheckResult.TypeCheckFailure(
@@ -127,31 +122,26 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
}
override def eval(input: InternalRow): Any = {
- // Written in imperative fashion for performance considerations
- val len = branchesArr.length
var i = 0
- // If all branches fail and an elseVal is not provided, the whole statement
- // defaults to null, according to Hive's semantics.
- while (i < len - 1) {
- if (branchesArr(i).eval(input) == true) {
- return branchesArr(i + 1).eval(input)
+ while (i < branches.size) {
+ if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) {
+ return branches(i)._2.eval(input)
}
- i += 2
+ i += 1
}
- var res: Any = null
- if (i == len - 1) {
- res = branchesArr(i).eval(input)
+ if (elseValue.isDefined) {
+ return elseValue.get.eval(input)
+ } else {
+ return null
}
- return res
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val len = branchesArr.length
val got = ctx.freshName("got")
- val cases = (0 until len/2).map { i =>
- val cond = branchesArr(i * 2).gen(ctx)
- val res = branchesArr(i * 2 + 1).gen(ctx)
+ val cases = branches.map { case (condition, value) =>
+ val cond = condition.gen(ctx)
+ val res = value.gen(ctx)
s"""
if (!$got) {
${cond.code}
@@ -165,17 +155,19 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
"""
}.mkString("\n")
- val other = if (len % 2 == 1) {
- val res = branchesArr(len - 1).gen(ctx)
- s"""
+ val elseCase = {
+ if (elseValue.isDefined) {
+ val res = elseValue.get.gen(ctx)
+ s"""
if (!$got) {
${res.code}
${ev.isNull} = ${res.isNull};
${ev.value} = ${res.value};
}
- """
- } else {
- ""
+ """
+ } else {
+ ""
+ }
}
s"""
@@ -183,32 +175,42 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$cases
- $other
+ $elseCase
"""
}
override def toString: String = {
- "CASE" + branches.sliding(2, 2).map {
- case Seq(cond, value) => s" WHEN $cond THEN $value"
- case Seq(elseValue) => s" ELSE $elseValue"
- }.mkString
+ val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
+ val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
+ "CASE" + cases + elseCase + " END"
}
override def sql: String = {
- val branchesSQL = branches.map(_.sql)
- val (cases, maybeElse) = if (branches.length % 2 == 0) {
- (branchesSQL, None)
- } else {
- (branchesSQL.init, Some(branchesSQL.last))
- }
+ val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
+ val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
+ "CASE" + cases + elseCase + " END"
+ }
+}
- val head = s"CASE "
- val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END"
- val body = cases.grouped(2).map {
- case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr"
- }.mkString(" ")
+/** Factory methods for CaseWhen. */
+object CaseWhen {
- head + body + tail
+ def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
+ CaseWhen(branches, Option(elseValue))
+ }
+
+ /**
+ * A factory method to facilitate the creation of this expression when used in parsers.
+ * @param branches Expressions at even positions are branch conditions, and expressions at odd
+ * positions are branch values.
+ */
+ def createFromParser(branches: Seq[Expression]): CaseWhen = {
+ val cases = branches.grouped(2).flatMap {
+ case cond :: value :: Nil => Some((cond, value))
+ case value :: Nil => None
+ }.toArray.toSeq // force materialization to make the seq serializable
+ val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
+ CaseWhen(cases, elseValue)
}
}
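
A brief sketch of the new representation's runtime semantics; literal values are illustrative and results are shown as comments:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Literal}

    val cw = CaseWhen(Seq((Literal(false), Literal(1))), Literal(2))
    cw.eval(InternalRow.empty)  // 2: the only condition is false, so the ELSE fires
    cw.toString                 // CASE WHEN false THEN 1 ELSE 2 END

    // Without an ELSE, an unmatched row evaluates to null, per Hive semantics:
    CaseWhen(Seq((Literal(false), Literal(1)))).eval(InternalRow.empty)  // null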
@@ -218,17 +220,12 @@ case class CaseWhen(branches: Seq[Expression]) extends Expression {
*/
object CaseKeyWhen {
def apply(key: Expression, branches: Seq[Expression]): CaseWhen = {
- val newBranches = branches.zipWithIndex.map { case (expr, i) =>
- if (i % 2 == 0 && i != branches.size - 1) {
- // If this expression is at even position, then it is either a branch condition, or
- // the very last value that is the "else value". The "i != branches.size - 1" makes
- // sure we are not adding an EqualTo to the "else value".
- EqualTo(key, expr)
- } else {
- expr
- }
- }
- CaseWhen(newBranches)
+ val cases = branches.grouped(2).flatMap {
+ case cond :: value :: Nil => Some((EqualTo(key, cond), value))
+ case value :: Nil => None
+ }.toArray.toSeq // force materialization to make the seq serializable
+ val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
+ CaseWhen(cases, elseValue)
}
}
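
CaseKeyWhen remains a pure factory: CASE key WHEN ... desugars into the plain form with EqualTo conditions. A sketch with illustrative literals:

    import org.apache.spark.sql.catalyst.expressions.{CaseKeyWhen, Literal}

    // CASE 5 WHEN 1 THEN 'a' WHEN 2 THEN 'b' ELSE 'c' END
    CaseKeyWhen(Literal(5), Seq(Literal(1), Literal("a"), Literal(2), Literal("b"), Literal("c")))
    // == CaseWhen(
    //      Seq((EqualTo(Literal(5), Literal(1)), Literal("a")),
    //          (EqualTo(Literal(5), Literal(2)), Literal("b"))),
    //      Some(Literal("c")))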
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d4be545a35..d0b29aa01f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -315,6 +315,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
} else {
arg
}
+ case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
+ val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
+ val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
+ if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
+ changed = true
+ (newChild1, newChild2)
+ } else {
+ tuple
+ }
case other => other
}
case nonChild: AnyRef => nonChild
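
This new TreeNode case is needed because CaseWhen's branches are now a Seq of (Expression, Expression) tuples, so child expressions surface through productIterator wrapped in tuples. A hedged sketch of what it enables, with an illustrative rewrite rule:

    import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Literal}
    import org.apache.spark.sql.types.IntegerType

    // transform can now rewrite expressions nested inside the branch tuples:
    val cw = CaseWhen(Seq((Literal(true), Literal(1))), Literal(2))
    val bumped = cw.transform { case Literal(i: Int, IntegerType) => Literal(i + 1) }
    // bumped == CaseWhen(Seq((Literal(true), Literal(2))), Some(Literal(3)))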
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index cf84855885..975cd87d09 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -239,7 +239,7 @@ class AnalysisSuite extends AnalysisTest {
test("SPARK-12102: Ignore nullablity when comparing two sides of case") {
val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false)))
- val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val"))
+ val plan = relation.select(CaseWhen(Seq((Literal(true), 'a.attr)), 'b).as("val"))
assertAnalysisSuccess(plan)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 0521ed848c..59549e3998 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -132,13 +132,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
assertError(
- CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'mapField)),
+ CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('booleanField.attr, 'mapField.attr))),
"THEN and ELSE expressions should all be same type or coercible to a common type")
assertError(
CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'mapField)),
"THEN and ELSE expressions should all be same type or coercible to a common type")
assertError(
- CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
+ CaseWhen(Seq(('booleanField.attr, 'intField.attr), ('intField.attr, 'intField.attr))),
"WHEN expressions in CaseWhen should all be boolean type")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 40378c6727..b1f6c0b802 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -308,15 +308,14 @@ class HiveTypeCoercionSuite extends PlanTest {
CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a")))
)
ruleTest(HiveTypeCoercion.CaseWhenCoercion,
- CaseWhen(Seq(Literal(true), Literal(1.2), Literal.create(1, DecimalType(7, 2)))),
- CaseWhen(Seq(
- Literal(true), Literal(1.2), Cast(Literal.create(1, DecimalType(7, 2)), DoubleType)))
+ CaseWhen(Seq((Literal(true), Literal(1.2))), Literal.create(1, DecimalType(7, 2))),
+ CaseWhen(Seq((Literal(true), Literal(1.2))),
+ Cast(Literal.create(1, DecimalType(7, 2)), DoubleType))
)
ruleTest(HiveTypeCoercion.CaseWhenCoercion,
- CaseWhen(Seq(Literal(true), Literal(100L), Literal.create(1, DecimalType(7, 2)))),
- CaseWhen(Seq(
- Literal(true), Cast(Literal(100L), DecimalType(22, 2)),
- Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2))))
+ CaseWhen(Seq((Literal(true), Literal(100L))), Literal.create(1, DecimalType(7, 2))),
+ CaseWhen(Seq((Literal(true), Cast(Literal(100L), DecimalType(22, 2)))),
+ Cast(Literal.create(1, DecimalType(7, 2)), DecimalType(22, 2)))
)
}
@@ -452,7 +451,7 @@ class HiveTypeCoercionSuite extends PlanTest {
val expectedTypes = Seq(DecimalType(10, 5), DecimalType(10, 5), DecimalType(15, 5),
DecimalType(25, 5), DoubleType, DoubleType)
- rightTypes.zip(expectedTypes).map { case (rType, expectedType) =>
+ rightTypes.zip(expectedTypes).foreach { case (rType, expectedType) =>
val plan2 = LocalRelation(
AttributeReference("r", rType)())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 4029da5925..3c581ecdaf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -80,38 +80,39 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
val c5 = 'a.string.at(4)
val c6 = 'a.string.at(5)
- checkEvaluation(CaseWhen(Seq(c1, c4, c6)), "c", row)
- checkEvaluation(CaseWhen(Seq(c2, c4, c6)), "c", row)
- checkEvaluation(CaseWhen(Seq(c3, c4, c6)), "a", row)
- checkEvaluation(CaseWhen(Seq(Literal.create(null, BooleanType), c4, c6)), "c", row)
- checkEvaluation(CaseWhen(Seq(Literal.create(false, BooleanType), c4, c6)), "c", row)
- checkEvaluation(CaseWhen(Seq(Literal.create(true, BooleanType), c4, c6)), "a", row)
-
- checkEvaluation(CaseWhen(Seq(c3, c4, c2, c5, c6)), "a", row)
- checkEvaluation(CaseWhen(Seq(c2, c4, c3, c5, c6)), "b", row)
- checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5, c6)), "c", row)
- checkEvaluation(CaseWhen(Seq(c1, c4, c2, c5)), null, row)
-
- assert(CaseWhen(Seq(c2, c4, c6)).nullable === true)
- assert(CaseWhen(Seq(c2, c4, c3, c5, c6)).nullable === true)
- assert(CaseWhen(Seq(c2, c4, c3, c5)).nullable === true)
+ checkEvaluation(CaseWhen(Seq((c1, c4)), c6), "c", row)
+ checkEvaluation(CaseWhen(Seq((c2, c4)), c6), "c", row)
+ checkEvaluation(CaseWhen(Seq((c3, c4)), c6), "a", row)
+ checkEvaluation(CaseWhen(Seq((Literal.create(null, BooleanType), c4)), c6), "c", row)
+ checkEvaluation(CaseWhen(Seq((Literal.create(false, BooleanType), c4)), c6), "c", row)
+ checkEvaluation(CaseWhen(Seq((Literal.create(true, BooleanType), c4)), c6), "a", row)
+
+ checkEvaluation(CaseWhen(Seq((c3, c4), (c2, c5)), c6), "a", row)
+ checkEvaluation(CaseWhen(Seq((c2, c4), (c3, c5)), c6), "b", row)
+ checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5)), c6), "c", row)
+ checkEvaluation(CaseWhen(Seq((c1, c4), (c2, c5))), null, row)
+
+ assert(CaseWhen(Seq((c2, c4)), c6).nullable === true)
+ assert(CaseWhen(Seq((c2, c4), (c3, c5)), c6).nullable === true)
+ assert(CaseWhen(Seq((c2, c4), (c3, c5))).nullable === true)
val c4_notNull = 'a.boolean.notNull.at(3)
val c5_notNull = 'a.boolean.notNull.at(4)
val c6_notNull = 'a.boolean.notNull.at(5)
- assert(CaseWhen(Seq(c2, c4_notNull, c6_notNull)).nullable === false)
- assert(CaseWhen(Seq(c2, c4, c6_notNull)).nullable === true)
- assert(CaseWhen(Seq(c2, c4_notNull, c6)).nullable === true)
+ assert(CaseWhen(Seq((c2, c4_notNull)), c6_notNull).nullable === false)
+ assert(CaseWhen(Seq((c2, c4)), c6_notNull).nullable === true)
+ assert(CaseWhen(Seq((c2, c4_notNull))).nullable === true)
+ assert(CaseWhen(Seq((c2, c4_notNull)), c6).nullable === true)
- assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6_notNull)).nullable === false)
- assert(CaseWhen(Seq(c2, c4, c3, c5_notNull, c6_notNull)).nullable === true)
- assert(CaseWhen(Seq(c2, c4_notNull, c3, c5, c6_notNull)).nullable === true)
- assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull, c6)).nullable === true)
+ assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6_notNull).nullable === false)
+ assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull)), c6_notNull).nullable === true)
+ assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5)), c6_notNull).nullable === true)
+ assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull)), c6).nullable === true)
- assert(CaseWhen(Seq(c2, c4_notNull, c3, c5_notNull)).nullable === true)
- assert(CaseWhen(Seq(c2, c4, c3, c5_notNull)).nullable === true)
- assert(CaseWhen(Seq(c2, c4_notNull, c3, c5)).nullable === true)
+ assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5_notNull))).nullable === true)
+ assert(CaseWhen(Seq((c2, c4), (c3, c5_notNull))).nullable === true)
+ assert(CaseWhen(Seq((c2, c4_notNull), (c3, c5))).nullable === true)
}
test("case key when") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e8c61d6e01..6a020f9f28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -437,8 +437,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @since 1.4.0
*/
def when(condition: Column, value: Any): Column = this.expr match {
- case CaseWhen(branches: Seq[Expression]) =>
- withExpr { CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) }
+ case CaseWhen(branches, None) =>
+ withExpr { CaseWhen(branches :+ (condition.expr, lit(value).expr)) }
+ case CaseWhen(branches, Some(_)) =>
+ throw new IllegalArgumentException(
+ "when() cannot be applied once otherwise() is applied")
case _ =>
throw new IllegalArgumentException(
"when() can only be applied on a Column previously generated by when() function")
@@ -466,13 +469,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @since 1.4.0
*/
def otherwise(value: Any): Column = this.expr match {
- case CaseWhen(branches: Seq[Expression]) =>
- if (branches.size % 2 == 0) {
- withExpr { CaseWhen(branches :+ lit(value).expr) }
- } else {
- throw new IllegalArgumentException(
- "otherwise() can only be applied once on a Column previously generated by when()")
- }
+ case CaseWhen(branches, None) =>
+ withExpr { CaseWhen(branches, Option(lit(value).expr)) }
+ case CaseWhen(branches, Some(_)) =>
+ throw new IllegalArgumentException(
+ "otherwise() can only be applied once on a Column previously generated by when()")
case _ =>
throw new IllegalArgumentException(
"otherwise() can only be applied on a Column previously generated by when()")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 71fea2716b..b8ea2261e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1042,7 +1042,7 @@ object functions extends LegacyFunctions {
* @since 1.4.0
*/
def when(condition: Column, value: Any): Column = withExpr {
- CaseWhen(Seq(condition.expr, lit(value).expr))
+ CaseWhen(Seq((condition.expr, lit(value).expr)))
}
/**
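
For completeness, a usage sketch of this entry point, assuming a DataFrame df with an age column; until otherwise() supplies an ELSE, unmatched rows evaluate to null:

    import org.apache.spark.sql.functions.when

    df.select(when(df("age") > 3, 1).as("flag"))               // null where age <= 3
    df.select(when(df("age") > 3, 1).otherwise(0).as("flag"))  // 0 where age <= 3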