diff options
author | Reynold Xin <rxin@databricks.com> | 2016-01-13 12:44:35 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-01-13 12:44:35 -0800 |
commit | cbbcd8e4250aeec700f04c231f8be2f787243f1f (patch) | |
tree | 87741859e4b7ca40feac2829f164e3e03bb4b167 /sql/core | |
parent | c2ea79f96acd076351b48162644ed1cff4c8e090 (diff) | |
download | spark-cbbcd8e4250aeec700f04c231f8be2f787243f1f.tar.gz spark-cbbcd8e4250aeec700f04c231f8be2f787243f1f.tar.bz2 spark-cbbcd8e4250aeec700f04c231f8be2f787243f1f.zip |
[SPARK-12791][SQL] Simplify CaseWhen by breaking "branches" into "conditions" and "values"
This pull request rewrites CaseWhen expression to break the single, monolithic "branches" field into a sequence of tuples (Seq[(condition, value)]) and an explicit optional elseValue field.
Prior to this pull request, each even position in "branches" represents the condition for each branch, and each odd position represents the value for each branch. The use of them has been pretty confusing, with a lot of sliding-window or grouped(2) calls.
Author: Reynold Xin <rxin@databricks.com>
Closes #10734 from rxin/simplify-case.
Diffstat (limited to 'sql/core')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 19 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2 |
2 files changed, 11 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index e8c61d6e01..6a020f9f28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -437,8 +437,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.4.0 */ def when(condition: Column, value: Any): Column = this.expr match { - case CaseWhen(branches: Seq[Expression]) => - withExpr { CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) } + case CaseWhen(branches, None) => + withExpr { CaseWhen(branches :+ (condition.expr, lit(value).expr)) } + case CaseWhen(branches, Some(_)) => + throw new IllegalArgumentException( + "when() cannot be applied once otherwise() is applied") case _ => throw new IllegalArgumentException( "when() can only be applied on a Column previously generated by when() function") @@ -466,13 +469,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @since 1.4.0 */ def otherwise(value: Any): Column = this.expr match { - case CaseWhen(branches: Seq[Expression]) => - if (branches.size % 2 == 0) { - withExpr { CaseWhen(branches :+ lit(value).expr) } - } else { - throw new IllegalArgumentException( - "otherwise() can only be applied once on a Column previously generated by when()") - } + case CaseWhen(branches, None) => + withExpr { CaseWhen(branches, Option(lit(value).expr)) } + case CaseWhen(branches, Some(_)) => + throw new IllegalArgumentException( + "otherwise() can only be applied once on a Column previously generated by when()") case _ => throw new IllegalArgumentException( "otherwise() can only be applied on a Column previously generated by when()") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 71fea2716b..b8ea2261e9 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1042,7 +1042,7 @@ object functions extends LegacyFunctions { * @since 1.4.0 */ def when(condition: Column, value: Any): Column = withExpr { - CaseWhen(Seq(condition.expr, lit(value).expr)) + CaseWhen(Seq((condition.expr, lit(value).expr))) } /** |