about | summary | refs | log | tree | commit | diff
path: root/python/pyspark
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark')
-rw-r--r-- python/pyspark/sql/functions.py | 25
-rw-r--r-- python/pyspark/sql/tests.py | 17
2 files changed, 31 insertions, 11 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 2f7c2f4aac..962f676d40 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -124,17 +124,20 @@ _functions_1_4 = {
_functions_1_6 = {
# unary math functions
- "stddev": "Aggregate function: returns the unbiased sample standard deviation of" +
- " the expression in a group.",
- "stddev_samp": "Aggregate function: returns the unbiased sample standard deviation of" +
- " the expression in a group.",
- "stddev_pop": "Aggregate function: returns population standard deviation of" +
- " the expression in a group.",
- "variance": "Aggregate function: returns the population variance of the values in a group.",
- "var_samp": "Aggregate function: returns the unbiased variance of the values in a group.",
- "var_pop": "Aggregate function: returns the population variance of the values in a group.",
- "skewness": "Aggregate function: returns the skewness of the values in a group.",
- "kurtosis": "Aggregate function: returns the kurtosis of the values in a group."
+ 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' +
+ ' the expression in a group.',
+ 'stddev_samp': 'Aggregate function: returns the unbiased sample standard deviation of' +
+ ' the expression in a group.',
+ 'stddev_pop': 'Aggregate function: returns population standard deviation of' +
+ ' the expression in a group.',
+ 'variance': 'Aggregate function: returns the unbiased variance of the values in a group.',
+ 'var_samp': 'Aggregate function: returns the unbiased variance of the values in a group.',
+ 'var_pop': 'Aggregate function: returns the population variance of the values in a group.',
+ 'skewness': 'Aggregate function: returns the skewness of the values in a group.',
+ 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.',
+ 'collect_list': 'Aggregate function: returns a list of objects with duplicates.',
+ 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' +
+ ' eliminated.'
}
# math functions that take two arguments as input
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4c03a0d4ff..e224574bcb 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1230,6 +1230,23 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
for r, ex in zip(rs, expected):
self.assertEqual(tuple(r), ex[:len(r)])
+ def test_collect_functions(self):
+ df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ from pyspark.sql import functions
+
+ self.assertEqual(
+ sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r),
+ [1, 2])
+ self.assertEqual(
+ sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r),
+ [1, 1, 1, 2])
+ self.assertEqual(
+ sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r),
+ ["1", "2"])
+ self.assertEqual(
+ sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
+ ["1", "2", "2", "2"])
+
if __name__ == "__main__":
if xmlrunner: