diff options
author | Nick Buroojy <nick.buroojy@civitaslearning.com> | 2015-11-09 14:30:37 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-11-09 14:30:52 -0800 |
commit | f138cb873335654476d1cd1070900b552dd8b21a (patch) | |
tree | 91bcc549fe561c4f100197f42bd8ce0ad03062be /python/pyspark/sql | |
parent | b7720fa45525cff6e812fa448d0841cb41f6c8a5 (diff) | |
download | spark-f138cb873335654476d1cd1070900b552dd8b21a.tar.gz spark-f138cb873335654476d1cd1070900b552dd8b21a.tar.bz2 spark-f138cb873335654476d1cd1070900b552dd8b21a.zip |
[SPARK-9301][SQL] Add collect_set and collect_list aggregate functions
For now they are thin wrappers around the corresponding Hive UDAFs.
One limitation with these in Hive 0.13.0 is they only support aggregating primitive types.
I chose snake_case here instead of camelCase because it seems to be used in the majority of the multi-word fns.
Do we also want to add these to `functions.py`?
This approach was recommended here: https://github.com/apache/spark/pull/8592#issuecomment-154247089
marmbrus rxin
Author: Nick Buroojy <nick.buroojy@civitaslearning.com>
Closes #9526 from nburoojy/nick/udaf-alias.
(cherry picked from commit a6ee4f989d020420dd08b97abb24802200ff23b2)
Signed-off-by: Michael Armbrust <michael@databricks.com>
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- | python/pyspark/sql/functions.py | 25 | ||||
-rw-r--r-- | python/pyspark/sql/tests.py | 17 |
2 files changed, 31 insertions, 11 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2f7c2f4aac..962f676d40 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -124,17 +124,20 @@ _functions_1_4 = { _functions_1_6 = { # unary math functions - "stddev": "Aggregate function: returns the unbiased sample standard deviation of" + - " the expression in a group.", - "stddev_samp": "Aggregate function: returns the unbiased sample standard deviation of" + - " the expression in a group.", - "stddev_pop": "Aggregate function: returns population standard deviation of" + - " the expression in a group.", - "variance": "Aggregate function: returns the population variance of the values in a group.", - "var_samp": "Aggregate function: returns the unbiased variance of the values in a group.", - "var_pop": "Aggregate function: returns the population variance of the values in a group.", - "skewness": "Aggregate function: returns the skewness of the values in a group.", - "kurtosis": "Aggregate function: returns the kurtosis of the values in a group." + 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_samp': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_pop': 'Aggregate function: returns population standard deviation of' + + ' the expression in a group.', + 'variance': 'Aggregate function: returns the population variance of the values in a group.', + 'var_samp': 'Aggregate function: returns the unbiased variance of the values in a group.', + 'var_pop': 'Aggregate function: returns the population variance of the values in a group.', + 'skewness': 'Aggregate function: returns the skewness of the values in a group.', + 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', + 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', + 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + + ' eliminated.' } # math functions that take two arguments as input diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4c03a0d4ff..e224574bcb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1230,6 +1230,23 @@ class HiveContextSQLTests(ReusedPySparkTestCase): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_collect_functions(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql import functions + + self.assertEqual( + sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r), + [1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r), + [1, 1, 1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r), + ["1", "2"]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), + ["1", "2", "2", "2"]) + if __name__ == "__main__": if xmlrunner: |