diff options
author | Herman van Hovell <hvanhovell@questtec.nl> | 2016-01-31 13:56:13 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-01-31 13:56:13 -0800 |
commit | 5a8b978fabb60aa178274f86432c63680c8b351a (patch) | |
tree | dd3de9b6cd79870813ccc8ca898da182f2bd881b /python/pyspark | |
parent | 0e6d92d042b0a2920d8df5959d5913ba0166a678 (diff) | |
download | spark-5a8b978fabb60aa178274f86432c63680c8b351a.tar.gz spark-5a8b978fabb60aa178274f86432c63680c8b351a.tar.bz2 spark-5a8b978fabb60aa178274f86432c63680c8b351a.zip |
[SPARK-13049] Add First/last with ignore nulls to functions.scala
This PR adds the ability to specify the ```ignoreNulls``` option to the functions dsl, e.g:
```df.select($"id", last($"value", ignoreNulls = true).over(Window.partitionBy($"id").orderBy($"other"))```
This PR is some where between a bug fix (see the JIRA) and a new feature. I am not sure if we should backport to 1.6.
cc yhuai
Author: Herman van Hovell <hvanhovell@questtec.nl>
Closes #10957 from hvanhovell/SPARK-13049.
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/sql/functions.py | 26 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 10 |
2 files changed, 34 insertions, 2 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 719eca8f55..0d57085267 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -81,8 +81,6 @@ _functions = { 'max': 'Aggregate function: returns the maximum value of the expression in a group.', 'min': 'Aggregate function: returns the minimum value of the expression in a group.', - 'first': 'Aggregate function: returns the first value in a group.', - 'last': 'Aggregate function: returns the last value in a group.', 'count': 'Aggregate function: returns the number of items in a group.', 'sum': 'Aggregate function: returns the sum of all values in the expression.', 'avg': 'Aggregate function: returns the average of the values in a group.', @@ -278,6 +276,18 @@ def countDistinct(col, *cols): return Column(jc) +@since(1.3) +def first(col, ignorenulls=False): + """Aggregate function: returns the first value in a group. + + The function by default returns the first values it sees. It will return the first non-null + value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls) + return Column(jc) + + @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. @@ -310,6 +320,18 @@ def isnull(col): return Column(sc._jvm.functions.isnull(_to_java_column(col))) +@since(1.3) +def last(col, ignorenulls=False): + """Aggregate function: returns the last value in a group. + + The function by default returns the last values it sees. It will return the last non-null + value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls) + return Column(jc) + + @since(1.6) def monotonically_increasing_id(): """A column that generates monotonically increasing 64-bit integers. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 410efbafe0..e30aa0a796 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -641,6 +641,16 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0]) self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + def test_first_last_ignorenulls(self): + from pyspark.sql import functions + df = self.sqlCtx.range(0, 100) + df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id")) + df3 = df2.select(functions.first(df2.id, False).alias('a'), + functions.first(df2.id, True).alias('b'), + functions.last(df2.id, False).alias('c'), + functions.last(df2.id, True).alias('d')) + self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect()) + def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() |