Diffstat (limited to 'python/pyspark/sql/functions.py')
-rw-r--r--  python/pyspark/sql/functions.py  26
1 file changed, 24 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 38a043a3c5..d91265ee0b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -32,13 +32,14 @@ from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
__all__ = [
'approxCountDistinct',
+ 'coalesce',
'countDistinct',
'monotonicallyIncreasingId',
'rand',
'randn',
'sparkPartitionId',
- 'coalesce',
- 'udf']
+ 'udf',
+ 'when']
def _create_function(name, doc=""):
@@ -291,6 +292,27 @@ def struct(*cols):
return Column(jc)
+def when(condition, value):
+ """Evaluates a list of conditions and returns one of multiple possible result expressions.
+ If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.
+
+ :param condition: a boolean :class:`Column` expression.
+ :param value: a literal value, or a :class:`Column` expression.
+
+ >>> df.select(when(df['age'] == 2, 3).otherwise(4).alias("age")).collect()
+ [Row(age=3), Row(age=4)]
+
+ >>> df.select(when(df.age == 2, df.age + 1).alias("age")).collect()
+ [Row(age=3), Row(age=None)]
+ """
+ sc = SparkContext._active_spark_context
+ if not isinstance(condition, Column):
+ raise TypeError("condition should be a Column")
+ v = value._jc if isinstance(value, Column) else value
+ jc = sc._jvm.functions.when(condition._jc, v)
+ return Column(jc)
+
+
class UserDefinedFunction(object):
"""
User defined function in Python