author    Herman van Hovell <hvanhovell@questtec.nl>  2016-01-31 13:56:13 -0800
committer Reynold Xin <rxin@databricks.com>           2016-01-31 13:56:13 -0800
commit    5a8b978fabb60aa178274f86432c63680c8b351a (patch)
tree      dd3de9b6cd79870813ccc8ca898da182f2bd881b /python
parent    0e6d92d042b0a2920d8df5959d5913ba0166a678 (diff)
[SPARK-13049] Add First/last with ignore nulls to functions.scala
This PR adds the ability to specify the ```ignoreNulls``` option in the functions DSL, e.g.: ```df.select($"id", last($"value", ignoreNulls = true).over(Window.partitionBy($"id").orderBy($"other")))```

This PR is somewhere between a bug fix (see the JIRA) and a new feature. I am not sure if we should backport to 1.6.

cc yhuai

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #10957 from hvanhovell/SPARK-13049.
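As a quick illustration of what the Python side of this change enables, here is a minimal sketch (not part of this commit; the ```sqlContext``` handle and the ```k```/```v``` column names are assumed for illustration):

```python
# Minimal sketch of the new ignorenulls flag; assumes an existing
# SQLContext named `sqlContext`, as in the PySpark shell.
from pyspark.sql import Row, functions as F

df = sqlContext.createDataFrame(
    [Row(k=1, v=None), Row(k=1, v=10), Row(k=1, v=None)])

# Plain first/last return whatever value comes first/last, nulls included;
# with ignorenulls=True they skip nulls, returning null only if all values
# in the group are null.
df.groupBy(df.k).agg(
    F.first(df.v).alias('first_any'),                    # may be None here
    F.first(df.v, ignorenulls=True).alias('first_nn'),   # 10
    F.last(df.v, ignorenulls=True).alias('last_nn'),     # 10
).show()
```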
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/functions.py  26
-rw-r--r--  python/pyspark/sql/tests.py      10
2 files changed, 34 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 719eca8f55..0d57085267 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -81,8 +81,6 @@ _functions = {
 
     'max': 'Aggregate function: returns the maximum value of the expression in a group.',
     'min': 'Aggregate function: returns the minimum value of the expression in a group.',
-    'first': 'Aggregate function: returns the first value in a group.',
-    'last': 'Aggregate function: returns the last value in a group.',
     'count': 'Aggregate function: returns the number of items in a group.',
     'sum': 'Aggregate function: returns the sum of all values in the expression.',
     'avg': 'Aggregate function: returns the average of the values in a group.',
@@ -278,6 +276,18 @@ def countDistinct(col, *cols):
     return Column(jc)
 
 
+@since(1.3)
+def first(col, ignorenulls=False):
+    """Aggregate function: returns the first value in a group.
+
+    The function by default returns the first value it sees. It will return the first non-null
+    value it sees when ``ignorenulls`` is set to true. If all values are null, then null is returned.
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.first(_to_java_column(col), ignorenulls)
+    return Column(jc)
+
+
 @since(1.6)
 def input_file_name():
     """Creates a string column for the file name of the current Spark task.
@@ -310,6 +320,18 @@ def isnull(col):
     return Column(sc._jvm.functions.isnull(_to_java_column(col)))
 
 
+@since(1.3)
+def last(col, ignorenulls=False):
+    """Aggregate function: returns the last value in a group.
+
+    The function by default returns the last value it sees. It will return the last non-null
+    value it sees when ``ignorenulls`` is set to true. If all values are null, then null is returned.
+    """
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.last(_to_java_column(col), ignorenulls)
+    return Column(jc)
+
+
 @since(1.6)
 def monotonically_increasing_id():
     """A column that generates monotonically increasing 64-bit integers.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 410efbafe0..e30aa0a796 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -641,6 +641,16 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
         self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
 
+    def test_first_last_ignorenulls(self):
+        from pyspark.sql import functions
+        df = self.sqlCtx.range(0, 100)
+        df2 = df.select(functions.when(df.id % 3 == 0, None).otherwise(df.id).alias("id"))
+        df3 = df2.select(functions.first(df2.id, False).alias('a'),
+                         functions.first(df2.id, True).alias('b'),
+                         functions.last(df2.id, False).alias('c'),
+                         functions.last(df2.id, True).alias('d'))
+        self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
+
     def test_corr(self):
         import math
         df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
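For reference, the expected row in test_first_last_ignorenulls follows directly from the data: ids 0 through 99 with every multiple of 3 replaced by null, so the plain first/last pick up the nulls at positions 0 and 99, while the ignore-nulls variants return 1 and 98.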