aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala61
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala24
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala21
3 files changed, 106 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 4773dedf72..42f5bcda49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -328,6 +328,67 @@ class Column(protected[sql] val expr: Expression) extends Logging {
def eqNullSafe(other: Any): Column = this <=> other
/**
+ * Evaluates a list of conditions and returns one of multiple possible result expressions.
+ * If otherwise is not defined at the end, null is returned for unmatched conditions.
+ *
+ * {{{
+ * // Example: encoding gender string column into integer.
+ *
+ * // Scala:
+ * people.select(when(people("gender") === "male", 0)
+ * .when(people("gender") === "female", 1)
+ * .otherwise(2))
+ *
+ * // Java:
+ * people.select(when(col("gender").equalTo("male"), 0)
+ * .when(col("gender").equalTo("female"), 1)
+ * .otherwise(2))
+ * }}}
+ *
+ * @group expr_ops
+ */
+ def when(condition: Column, value: Any):Column = this.expr match {
+ case CaseWhen(branches: Seq[Expression]) =>
+ CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr))
+ case _ =>
+ throw new IllegalArgumentException(
+ "when() can only be applied on a Column previously generated by when() function")
+ }
+
+ /**
+ * Evaluates a list of conditions and returns one of multiple possible result expressions.
+ * If otherwise is not defined at the end, null is returned for unmatched conditions.
+ *
+ * {{{
+ * // Example: encoding gender string column into integer.
+ *
+ * // Scala:
+ * people.select(when(people("gender") === "male", 0)
+ * .when(people("gender") === "female", 1)
+ * .otherwise(2))
+ *
+ * // Java:
+ * people.select(when(col("gender").equalTo("male"), 0)
+ * .when(col("gender").equalTo("female"), 1)
+ * .otherwise(2))
+ * }}}
+ *
+ * @group expr_ops
+ */
+ def otherwise(value: Any):Column = this.expr match {
+ case CaseWhen(branches: Seq[Expression]) =>
+ if (branches.size % 2 == 0) {
+ CaseWhen(branches :+ lit(value).expr)
+ } else {
+ throw new IllegalArgumentException(
+ "otherwise() can only be applied once on a Column previously generated by when()")
+ }
+ case _ =>
+ throw new IllegalArgumentException(
+ "otherwise() can only be applied on a Column previously generated by when()")
+ }
+
+ /**
* True if the current column is between the lower bound and upper bound, inclusive.
*
* @group java_expr_ops
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 215787e40b..099e1d8f03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -420,6 +420,30 @@ object functions {
def not(e: Column): Column = !e
/**
+ * Evaluates a list of conditions and returns one of multiple possible result expressions.
+ * If otherwise is not defined at the end, null is returned for unmatched conditions.
+ *
+ * {{{
+ * // Example: encoding gender string column into integer.
+ *
+ * // Scala:
+ * people.select(when(people("gender") === "male", 0)
+ * .when(people("gender") === "female", 1)
+ * .otherwise(2))
+ *
+ * // Java:
+ * people.select(when(col("gender").equalTo("male"), 0)
+ * .when(col("gender").equalTo("female"), 1)
+ * .otherwise(2))
+ * }}}
+ *
+ * @group normal_funcs
+ */
+ def when(condition: Column, value: Any): Column = {
+ CaseWhen(Seq(condition.expr, lit(value).expr))
+ }
+
+ /**
* Generate a random column with i.i.d. samples from U[0.0, 1.0].
*
* @group normal_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index d96186c268..269e185543 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -255,6 +255,27 @@ class ColumnExpressionSuite extends QueryTest {
Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil)
}
+ test("SPARK-7321 when conditional statements") {
+ val testData = (1 to 3).map(i => (i, i.toString)).toDF("key", "value")
+
+ checkAnswer(
+ testData.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)),
+ Seq(Row(-1), Row(-2), Row(0))
+ )
+
+ // Without the ending otherwise, return null for unmatched conditions.
+ // Also test putting a non-literal value in the expression.
+ checkAnswer(
+ testData.select(when($"key" === 1, lit(0) - $"key").when($"key" === 2, -2)),
+ Seq(Row(-1), Row(-2), Row(null))
+ )
+
+ // Test error handling for invalid expressions.
+ intercept[IllegalArgumentException] { $"key".when($"key" === 1, -1) }
+ intercept[IllegalArgumentException] { $"key".otherwise(-1) }
+ intercept[IllegalArgumentException] { when($"key" === 1, -1).otherwise(-1).otherwise(-1) }
+ }
+
test("sqrt") {
checkAnswer(
testData.select(sqrt('key)).orderBy('key.asc),