From 97dee313f23b00f15638cb72a4a80c1f197f8a9d Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Tue, 12 May 2015 21:43:34 -0700
Subject: [SPARK-7321][SQL] Add Column expression for conditional statements
 (when/otherwise)

This builds on https://github.com/apache/spark/pull/5932 and should close
https://github.com/apache/spark/pull/5932 as well.

As an example:

```python
df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
```

Author: Reynold Xin
Author: kaka1992

Closes #6072 from rxin/when-expr and squashes the following commits:

8f49201 [Reynold Xin] Throw exception if otherwise is applied twice.
0455eda [Reynold Xin] Reset run-tests.
bfb9d9f [Reynold Xin] Updated documentation and test cases.
762f6a5 [Reynold Xin] Merge pull request #5932 from kaka1992/IFCASE
95724c6 [kaka1992] Update
8218d0a [kaka1992] Update
801009e [kaka1992] Update
76d6346 [kaka1992] [SPARK-7321][SQL] Add Column expression for conditional statements (if, case)
---
 python/pyspark/sql/__init__.py  |  2 ++
 python/pyspark/sql/dataframe.py | 30 ++++++++++++++++++++++++++++++
 python/pyspark/sql/functions.py | 26 ++++++++++++++++++++++++--
 3 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index b60b991dd4..7192c89b3d 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -32,6 +32,8 @@ Important classes of Spark SQL and DataFrames:
         Aggregation methods, returned by :func:`DataFrame.groupBy`.
     - L{DataFrameNaFunctions}
         Methods for handling missing data (null values).
+    - L{DataFrameStatFunctions}
+        Methods for statistics functionality.
     - L{functions}
         List of built-in functions available for :class:`DataFrame`.
     - L{types}
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 078acfdf7e..82cb1c2fdb 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1546,6 +1546,36 @@ class Column(object):
         """
         return (self >= lowerBound) & (self <= upperBound)
 
+    @ignore_unicode_prefix
+    def when(self, condition, value):
+        """Evaluates a list of conditions and returns one of multiple possible result expressions.
+        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+        See :func:`pyspark.sql.functions.when` for example usage.
+
+        :param condition: a boolean :class:`Column` expression.
+        :param value: a literal value, or a :class:`Column` expression.
+
+        """
+        if not isinstance(condition, Column):
+            raise TypeError("condition should be a Column")
+        v = value._jc if isinstance(value, Column) else value
+        jc = self._jc.when(condition._jc, v)
+        return Column(jc)
+
+    @ignore_unicode_prefix
+    def otherwise(self, value):
+        """Evaluates a list of conditions and returns one of multiple possible result expressions.
+        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+        See :func:`pyspark.sql.functions.when` for example usage.
+
+        :param value: a literal value, or a :class:`Column` expression.
+ """ + v = value._jc if isinstance(value, Column) else value + jc = self._jc.otherwise(value) + return Column(jc) + def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 38a043a3c5..d91265ee0b 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -32,13 +32,14 @@ from pyspark.sql.dataframe import Column, _to_java_column, _to_seq __all__ = [ 'approxCountDistinct', + 'coalesce', 'countDistinct', 'monotonicallyIncreasingId', 'rand', 'randn', 'sparkPartitionId', - 'coalesce', - 'udf'] + 'udf', + 'when'] def _create_function(name, doc=""): @@ -291,6 +292,27 @@ def struct(*cols): return Column(jc) +def when(condition, value): + """Evaluates a list of conditions and returns one of multiple possible result expressions. + If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions. + + :param condition: a boolean :class:`Column` expression. + :param value: a literal value, or a :class:`Column` expression. + + >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect() + [Row(age=3), Row(age=4)] + + >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect() + [Row(age=3), Row(age=None)] + """ + sc = SparkContext._active_spark_context + if not isinstance(condition, Column): + raise TypeError("condition should be a Column") + v = value._jc if isinstance(value, Column) else value + jc = sc._jvm.functions.when(condition._jc, v) + return Column(jc) + + class UserDefinedFunction(object): """ User defined function in Python -- cgit v1.2.3