author    Reynold Xin <rxin@databricks.com>    2016-01-13 12:44:35 -0800
committer Reynold Xin <rxin@databricks.com>    2016-01-13 12:44:35 -0800
commit    cbbcd8e4250aeec700f04c231f8be2f787243f1f (patch)
tree      87741859e4b7ca40feac2829f164e3e03bb4b167 /sql/core
parent    c2ea79f96acd076351b48162644ed1cff4c8e090 (diff)
download  spark-cbbcd8e4250aeec700f04c231f8be2f787243f1f.tar.gz
          spark-cbbcd8e4250aeec700f04c231f8be2f787243f1f.tar.bz2
          spark-cbbcd8e4250aeec700f04c231f8be2f787243f1f.zip
[SPARK-12791][SQL] Simplify CaseWhen by breaking "branches" into "conditions" and "values"
This pull request rewrites the CaseWhen expression to break the single, monolithic "branches" field into a sequence of (condition, value) tuples (Seq[(condition, value)]) and an explicit optional elseValue field. Prior to this pull request, each even position in "branches" held the condition for a branch, and each odd position held the corresponding value. Using the flat encoding was quite confusing and required a lot of sliding-window or grouped(2) calls.

Author: Reynold Xin <rxin@databricks.com>

Closes #10734 from rxin/simplify-case.
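For context, a minimal sketch of the before/after shape of the expression, using a stand-in Expression type rather than the real Catalyst classes:

// Stand-in type for illustration only; the real Catalyst Expression
// hierarchy is far richer.
trait Expression

// Before: one flat sequence. Even indices held conditions, odd indices
// held values, and a trailing element (odd total length) was the else
// value.
case class OldCaseWhen(branches: Seq[Expression])

// After: each branch is an explicit (condition, value) tuple and the
// else value is a separate Option, so no positional decoding is needed.
case class NewCaseWhen(
    branches: Seq[(Expression, Expression)],
    elseValue: Option[Expression] = None)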
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala    | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala |  2
2 files changed, 11 insertions, 10 deletions
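To see the decoding burden the old layout imposed, here is a hedged sketch of consumer-side code, continuing the stand-in types above (these pretty-printing helpers are hypothetical, not Catalyst code):

// Before: reassemble pairs positionally with grouped(2); a final
// one-element group means the last element was the else value.
def describeOld(e: OldCaseWhen): Seq[String] =
  e.branches.grouped(2).toSeq.map {
    case Seq(cond, value) => s"WHEN $cond THEN $value"
    case Seq(elseValue)   => s"ELSE $elseValue"
  }

// After: the branches are already (condition, value) tuples.
def describeNew(e: NewCaseWhen): Seq[String] =
  e.branches.map { case (c, v) => s"WHEN $c THEN $v" } ++
    e.elseValue.map(ev => s"ELSE $ev")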
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index e8c61d6e01..6a020f9f28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -437,8 +437,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @since 1.4.0
*/
def when(condition: Column, value: Any): Column = this.expr match {
- case CaseWhen(branches: Seq[Expression]) =>
- withExpr { CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) }
+ case CaseWhen(branches, None) =>
+ withExpr { CaseWhen(branches :+ (condition.expr, lit(value).expr)) }
+ case CaseWhen(branches, Some(_)) =>
+ throw new IllegalArgumentException(
+ "when() cannot be applied once otherwise() is applied")
case _ =>
throw new IllegalArgumentException(
"when() can only be applied on a Column previously generated by when() function")
@@ -466,13 +469,11 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* @since 1.4.0
*/
def otherwise(value: Any): Column = this.expr match {
- case CaseWhen(branches: Seq[Expression]) =>
- if (branches.size % 2 == 0) {
- withExpr { CaseWhen(branches :+ lit(value).expr) }
- } else {
- throw new IllegalArgumentException(
- "otherwise() can only be applied once on a Column previously generated by when()")
- }
+ case CaseWhen(branches, None) =>
+ withExpr { CaseWhen(branches, Option(lit(value).expr)) }
+ case CaseWhen(branches, Some(_)) =>
+ throw new IllegalArgumentException(
+ "otherwise() can only be applied once on a Column previously generated by when()")
case _ =>
throw new IllegalArgumentException(
"otherwise() can only be applied on a Column previously generated by when()")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 71fea2716b..b8ea2261e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1042,7 +1042,7 @@ object functions extends LegacyFunctions {
* @since 1.4.0
*/
def when(condition: Column, value: Any): Column = withExpr {
- CaseWhen(Seq(condition.expr, lit(value).expr))
+ CaseWhen(Seq((condition.expr, lit(value).expr)))
}
/**
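As a usage sketch of the resulting public API behavior (illustrative column names and values, not taken from the commit):

import org.apache.spark.sql.functions.{col, when}

// Chained when() calls append (condition, value) tuples to the same
// CaseWhen; otherwise() fills in the elseValue exactly once.
val grade = when(col("score") >= 90, "A")
  .when(col("score") >= 80, "B")
  .otherwise("C")

// With this change, both misuses below fail fast with
// IllegalArgumentException:
//   when(col("score") >= 90, "A").otherwise("C").when(col("score") >= 80, "B")
//   when(col("score") >= 90, "A").otherwise("C").otherwise("D")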