diff options
author | Reynold Xin <rxin@databricks.com> | 2015-05-12 21:43:34 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-05-12 21:43:34 -0700 |
commit | 97dee313f23b00f15638cb72a4a80c1f197f8a9d (patch) | |
tree | b6718297822929afef06cb8550e765c8ad637efe | |
parent | 8fd55358b7fc1c7545d823bef7b39769f731c1ee (diff) | |
download | spark-97dee313f23b00f15638cb72a4a80c1f197f8a9d.tar.gz spark-97dee313f23b00f15638cb72a4a80c1f197f8a9d.tar.bz2 spark-97dee313f23b00f15638cb72a4a80c1f197f8a9d.zip |
[SPARK-7321][SQL] Add Column expression for conditional statements (when/otherwise)
This builds on https://github.com/apache/spark/pull/5932 and should close that pull request as well.
As an example:
```python
df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
```
Author: Reynold Xin <rxin@databricks.com>
Author: kaka1992 <kaka_1992@163.com>
Closes #6072 from rxin/when-expr and squashes the following commits:
8f49201 [Reynold Xin] Throw exception if otherwise is applied twice.
0455eda [Reynold Xin] Reset run-tests.
bfb9d9f [Reynold Xin] Updated documentation and test cases.
762f6a5 [Reynold Xin] Merge pull request #5932 from kaka1992/IFCASE
95724c6 [kaka1992] Update
8218d0a [kaka1992] Update
801009e [kaka1992] Update
76d6346 [kaka1992] [SPARK-7321][SQL] Add Column expression for conditional statements (if, case)
-rw-r--r-- | python/pyspark/sql/__init__.py | 2 | ||||
-rw-r--r-- | python/pyspark/sql/dataframe.py | 31 | ||||
-rw-r--r-- | python/pyspark/sql/functions.py | 26 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 61 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 24 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 21 |
6 files changed, 163 insertions, 2 deletions
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index b60b991dd4..7192c89b3d 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -32,6 +32,8 @@ Important classes of Spark SQL and DataFrames: Aggregation methods, returned by :func:`DataFrame.groupBy`. - L{DataFrameNaFunctions} Methods for handling missing data (null values). + - L{DataFrameStatFunctions} + Methods for statistics functionality. - L{functions} List of built-in functions available for :class:`DataFrame`. - L{types} diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 078acfdf7e..82cb1c2fdb 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1546,6 +1546,37 @@ class Column(object): """ return (self >= lowerBound) & (self <= upperBound) + @ignore_unicode_prefix + def when(self, condition, value): + """Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param condition: a boolean :class:`Column` expression. + :param value: a literal value, or a :class:`Column` expression. + + """ + sc = SparkContext._active_spark_context + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + v = value._jc if isinstance(value, Column) else value + jc = sc._jvm.functions.when(condition._jc, v) + return Column(jc) + + @ignore_unicode_prefix + def otherwise(self, value): + """Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + See :func:`pyspark.sql.functions.when` for example usage. + + :param value: a literal value, or a :class:`Column` expression. 
+ """ + v = value._jc if isinstance(value, Column) else value + jc = self._jc.otherwise(value) + return Column(jc) + def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 38a043a3c5..d91265ee0b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -32,13 +32,14 @@ from pyspark.sql.dataframe import Column, _to_java_column, _to_seq __all__ = [ 'approxCountDistinct', + 'coalesce', 'countDistinct', 'monotonicallyIncreasingId', 'rand', 'randn', 'sparkPartitionId', - 'coalesce', - 'udf'] + 'udf', + 'when'] def _create_function(name, doc=""): @@ -291,6 +292,27 @@ def struct(*cols): return Column(jc) +def when(condition, value): + """Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + :param condition: a boolean :class:`Column` expression. + :param value: a literal value, or a :class:`Column` expression. 
+ + >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect() + [Row(age=3), Row(age=4)] + + >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect() + [Row(age=3), Row(age=None)] + """ + sc = SparkContext._active_spark_context + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + v = value._jc if isinstance(value, Column) else value + jc = sc._jvm.functions.when(condition._jc, v) + return Column(jc) + + class UserDefinedFunction(object): """ User defined function in Python diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 4773dedf72..42f5bcda49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -328,6 +328,67 @@ class Column(protected[sql] val expr: Expression) extends Logging { def eqNullSafe(other: Any): Column = this <=> other /** + * Evaluates a list of conditions and returns one of multiple possible result expressions. + * If otherwise is not defined at the end, null is returned for unmatched conditions. + * + * {{{ + * // Example: encoding gender string column into integer. + * + * // Scala: + * people.select(when(people("gender") === "male", 0) + * .when(people("gender") === "female", 1) + * .otherwise(2)) + * + * // Java: + * people.select(when(col("gender").equalTo("male"), 0) + * .when(col("gender").equalTo("female"), 1) + * .otherwise(2)) + * }}} + * + * @group expr_ops + */ + def when(condition: Column, value: Any):Column = this.expr match { + case CaseWhen(branches: Seq[Expression]) => + CaseWhen(branches ++ Seq(lit(condition).expr, lit(value).expr)) + case _ => + throw new IllegalArgumentException( + "when() can only be applied on a Column previously generated by when() function") + } + + /** + * Evaluates a list of conditions and returns one of multiple possible result expressions. 
+ * If otherwise is not defined at the end, null is returned for unmatched conditions. + * + * {{{ + * // Example: encoding gender string column into integer. + * + * // Scala: + * people.select(when(people("gender") === "male", 0) + * .when(people("gender") === "female", 1) + * .otherwise(2)) + * + * // Java: + * people.select(when(col("gender").equalTo("male"), 0) + * .when(col("gender").equalTo("female"), 1) + * .otherwise(2)) + * }}} + * + * @group expr_ops + */ + def otherwise(value: Any):Column = this.expr match { + case CaseWhen(branches: Seq[Expression]) => + if (branches.size % 2 == 0) { + CaseWhen(branches :+ lit(value).expr) + } else { + throw new IllegalArgumentException( + "otherwise() can only be applied once on a Column previously generated by when()") + } + case _ => + throw new IllegalArgumentException( + "otherwise() can only be applied on a Column previously generated by when()") + } + + /** * True if the current column is between the lower bound and upper bound, inclusive. * * @group java_expr_ops diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 215787e40b..099e1d8f03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -420,6 +420,30 @@ object functions { def not(e: Column): Column = !e /** + * Evaluates a list of conditions and returns one of multiple possible result expressions. + * If otherwise is not defined at the end, null is returned for unmatched conditions. + * + * {{{ + * // Example: encoding gender string column into integer. 
+ * + * // Scala: + * people.select(when(people("gender") === "male", 0) + * .when(people("gender") === "female", 1) + * .otherwise(2)) + * + * // Java: + * people.select(when(col("gender").equalTo("male"), 0) + * .when(col("gender").equalTo("female"), 1) + * .otherwise(2)) + * }}} + * + * @group normal_funcs + */ + def when(condition: Column, value: Any): Column = { + CaseWhen(Seq(condition.expr, lit(value).expr)) + } + + /** * Generate a random column with i.i.d. samples from U[0.0, 1.0]. * * @group normal_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index d96186c268..269e185543 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -255,6 +255,27 @@ class ColumnExpressionSuite extends QueryTest { Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil) } + test("SPARK-7321 when conditional statements") { + val testData = (1 to 3).map(i => (i, i.toString)).toDF("key", "value") + + checkAnswer( + testData.select(when($"key" === 1, -1).when($"key" === 2, -2).otherwise(0)), + Seq(Row(-1), Row(-2), Row(0)) + ) + + // Without the ending otherwise, return null for unmatched conditions. + // Also test putting a non-literal value in the expression. + checkAnswer( + testData.select(when($"key" === 1, lit(0) - $"key").when($"key" === 2, -2)), + Seq(Row(-1), Row(-2), Row(null)) + ) + + // Test error handling for invalid expressions. + intercept[IllegalArgumentException] { $"key".when($"key" === 1, -1) } + intercept[IllegalArgumentException] { $"key".otherwise(-1) } + intercept[IllegalArgumentException] { when($"key" === 1, -1).otherwise(-1).otherwise(-1) } + } + test("sqrt") { checkAnswer( testData.select(sqrt('key)).orderBy('key.asc), |