aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-05-12 21:43:34 -0700
committerReynold Xin <rxin@databricks.com>2015-05-12 21:43:34 -0700
commit97dee313f23b00f15638cb72a4a80c1f197f8a9d (patch)
treeb6718297822929afef06cb8550e765c8ad637efe /python/pyspark/sql
parent8fd55358b7fc1c7545d823bef7b39769f731c1ee (diff)
downloadspark-97dee313f23b00f15638cb72a4a80c1f197f8a9d.tar.gz
spark-97dee313f23b00f15638cb72a4a80c1f197f8a9d.tar.bz2
spark-97dee313f23b00f15638cb72a4a80c1f197f8a9d.zip
[SPARK-7321][SQL] Add Column expression for conditional statements (when/otherwise)
This builds on https://github.com/apache/spark/pull/5932 and should close https://github.com/apache/spark/pull/5932 as well. As an example: ```python df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect() ``` Author: Reynold Xin <rxin@databricks.com> Author: kaka1992 <kaka_1992@163.com> Closes #6072 from rxin/when-expr and squashes the following commits: 8f49201 [Reynold Xin] Throw exception if otherwise is applied twice. 0455eda [Reynold Xin] Reset run-tests. bfb9d9f [Reynold Xin] Updated documentation and test cases. 762f6a5 [Reynold Xin] Merge pull request #5932 from kaka1992/IFCASE 95724c6 [kaka1992] Update 8218d0a [kaka1992] Update 801009e [kaka1992] Update 76d6346 [kaka1992] [SPARK-7321][SQL] Add Column expression for conditional statements (if, case)
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--python/pyspark/sql/__init__.py2
-rw-r--r--python/pyspark/sql/dataframe.py31
-rw-r--r--python/pyspark/sql/functions.py26
3 files changed, 57 insertions, 2 deletions
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index b60b991dd4..7192c89b3d 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -32,6 +32,8 @@ Important classes of Spark SQL and DataFrames:
Aggregation methods, returned by :func:`DataFrame.groupBy`.
- L{DataFrameNaFunctions}
Methods for handling missing data (null values).
+ - L{DataFrameStatFunctions}
+ Methods for statistics functionality.
- L{functions}
List of built-in functions available for :class:`DataFrame`.
- L{types}
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 078acfdf7e..82cb1c2fdb 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1546,6 +1546,37 @@ class Column(object):
"""
return (self >= lowerBound) & (self <= upperBound)
+ @ignore_unicode_prefix
+ def when(self, condition, value):
+ """Evaluates a list of conditions and returns one of multiple possible result expressions.
+ If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+ See :func:`pyspark.sql.functions.when` for example usage.
+
+ :param condition: a boolean :class:`Column` expression.
+ :param value: a literal value, or a :class:`Column` expression.
+
+ """
+ sc = SparkContext._active_spark_context
+ if not isinstance(condition, Column):
+ raise TypeError("condition should be a Column")
+ v = value._jc if isinstance(value, Column) else value
+ jc = sc._jvm.functions.when(condition._jc, v)
+ return Column(jc)
+
+ @ignore_unicode_prefix
+ def otherwise(self, value):
+ """Evaluates a list of conditions and returns one of multiple possible result expressions.
+ If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+ See :func:`pyspark.sql.functions.when` for example usage.
+
+ :param value: a literal value, or a :class:`Column` expression.
+ """
+ v = value._jc if isinstance(value, Column) else value
+ jc = self._jc.otherwise(value)
+ return Column(jc)
+
def __repr__(self):
return 'Column<%s>' % self._jc.toString().encode('utf8')
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 38a043a3c5..d91265ee0b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -32,13 +32,14 @@ from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
__all__ = [
'approxCountDistinct',
+ 'coalesce',
'countDistinct',
'monotonicallyIncreasingId',
'rand',
'randn',
'sparkPartitionId',
- 'coalesce',
- 'udf']
+ 'udf',
+ 'when']
def _create_function(name, doc=""):
@@ -291,6 +292,27 @@ def struct(*cols):
return Column(jc)
+def when(condition, value):
+ """Evaluates a list of conditions and returns one of multiple possible result expressions.
+ If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+ :param condition: a boolean :class:`Column` expression.
+ :param value: a literal value, or a :class:`Column` expression.
+
+ >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
+ [Row(age=3), Row(age=4)]
+
+ >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
+ [Row(age=3), Row(age=None)]
+ """
+ sc = SparkContext._active_spark_context
+ if not isinstance(condition, Column):
+ raise TypeError("condition should be a Column")
+ v = value._jc if isinstance(value, Column) else value
+ jc = sc._jvm.functions.when(condition._jc, v)
+ return Column(jc)
+
+
class UserDefinedFunction(object):
"""
User defined function in Python