author     Nick Buroojy <nick.buroojy@civitaslearning.com>  2015-11-09 14:30:37 -0800
committer  Michael Armbrust <michael@databricks.com>        2015-11-09 14:30:52 -0800
commit     f138cb873335654476d1cd1070900b552dd8b21a (patch)
tree       91bcc549fe561c4f100197f42bd8ce0ad03062be
parent     b7720fa45525cff6e812fa448d0841cb41f6c8a5 (diff)
[SPARK-9301][SQL] Add collect_set and collect_list aggregate functions
For now they are thin wrappers around the corresponding Hive UDAFs. One limitation with these in Hive 0.13.0 is that they only support aggregating primitive types.

I chose snake_case here instead of camelCase because it seems to be used in the majority of the multi-word functions.

Do we also want to add these to `functions.py`?

This approach was recommended here: https://github.com/apache/spark/pull/8592#issuecomment-154247089

marmbrus rxin

Author: Nick Buroojy <nick.buroojy@civitaslearning.com>

Closes #9526 from nburoojy/nick/udaf-alias.

(cherry picked from commit a6ee4f989d020420dd08b97abb24802200ff23b2)
Signed-off-by: Michael Armbrust <michael@databricks.com>
-rw-r--r--  python/pyspark/sql/functions.py                                                       | 25
-rw-r--r--  python/pyspark/sql/tests.py                                                           | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                         | 20
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala  | 15
4 files changed, 64 insertions(+), 13 deletions(-)
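
For reference, a minimal PySpark usage sketch of the two new aggregates (hypothetical app name and data; assumes a HiveContext, since in 1.6 these delegate to Hive UDAFs):

    from pyspark import SparkContext
    from pyspark.sql import HiveContext
    from pyspark.sql import functions as F

    # collect_list/collect_set are Hive UDAF wrappers, so a HiveContext is required.
    sc = SparkContext(appName="collect-demo")
    sqlContext = HiveContext(sc)

    df = sqlContext.createDataFrame(
        [(1, "a"), (1, "b"), (1, "b")], ["key", "value"])

    # collect_list keeps duplicates; collect_set drops them.
    # Note: element order in the result is not guaranteed.
    rows = df.groupBy("key").agg(
        F.collect_list("value").alias("all_values"),
        F.collect_set("value").alias("distinct_values")).collect()

    print(rows)
    # e.g. [Row(key=1, all_values=[u'a', u'b', u'b'], distinct_values=[u'a', u'b'])]
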
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 2f7c2f4aac..962f676d40 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -124,17 +124,20 @@ _functions_1_4 = {
_functions_1_6 = {
# unary math functions
- "stddev": "Aggregate function: returns the unbiased sample standard deviation of" +
- " the expression in a group.",
- "stddev_samp": "Aggregate function: returns the unbiased sample standard deviation of" +
- " the expression in a group.",
- "stddev_pop": "Aggregate function: returns population standard deviation of" +
- " the expression in a group.",
- "variance": "Aggregate function: returns the population variance of the values in a group.",
- "var_samp": "Aggregate function: returns the unbiased variance of the values in a group.",
- "var_pop": "Aggregate function: returns the population variance of the values in a group.",
- "skewness": "Aggregate function: returns the skewness of the values in a group.",
- "kurtosis": "Aggregate function: returns the kurtosis of the values in a group."
+ 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' +
+ ' the expression in a group.',
+ 'stddev_samp': 'Aggregate function: returns the unbiased sample standard deviation of' +
+ ' the expression in a group.',
+ 'stddev_pop': 'Aggregate function: returns population standard deviation of' +
+ ' the expression in a group.',
+ 'variance': 'Aggregate function: returns the population variance of the values in a group.',
+ 'var_samp': 'Aggregate function: returns the unbiased variance of the values in a group.',
+ 'var_pop': 'Aggregate function: returns the population variance of the values in a group.',
+ 'skewness': 'Aggregate function: returns the skewness of the values in a group.',
+ 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.',
+ 'collect_list': 'Aggregate function: returns a list of objects with duplicates.',
+ 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' +
+ ' eliminated.'
}
# math functions that take two arguments as input
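
Aside: the entries in `_functions_1_6` are docstrings, not functions. At import time, pyspark expands each (name, docstring) pair into a module-level wrapper that forwards to the JVM-side `org.apache.spark.sql.functions` method of the same name. A rough sketch of that mechanism, modeled on the 1.6-era `_create_function` helper (details simplified):

    from pyspark import SparkContext
    from pyspark.sql.column import Column

    def _create_function(name, doc=""):
        """Create a wrapper that calls the JVM-side function of the same name."""
        def _(col):
            sc = SparkContext._active_spark_context
            # Pass the underlying Java column for Column args,
            # or the raw string for column names.
            jc = getattr(sc._jvm.functions, name)(
                col._jc if isinstance(col, Column) else col)
            return Column(jc)
        _.__name__ = name
        _.__doc__ = doc
        return _

    # Each entry becomes e.g. functions.collect_list / functions.collect_set.
    for _name, _doc in _functions_1_6.items():
        globals()[_name] = _create_function(_name, _doc)
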
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4c03a0d4ff..e224574bcb 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1230,6 +1230,23 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
for r, ex in zip(rs, expected):
self.assertEqual(tuple(r), ex[:len(r)])
+ def test_collect_functions(self):
+ df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ from pyspark.sql import functions
+
+ self.assertEqual(
+ sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r),
+ [1, 2])
+ self.assertEqual(
+ sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r),
+ [1, 1, 1, 2])
+ self.assertEqual(
+ sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r),
+ ["1", "2"])
+ self.assertEqual(
+ sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
+ ["1", "2", "2", "2"])
+
if __name__ == "__main__":
if xmlrunner:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0462758988..3f0b24b68b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -175,6 +175,26 @@ object functions {
def avg(columnName: String): Column = avg(Column(columnName))
/**
+ * Aggregate function: returns a list of objects with duplicates.
+ *
+ * For now this is an alias for the collect_list Hive UDAF.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def collect_list(e: Column): Column = callUDF("collect_list", e)
+
+ /**
+ * Aggregate function: returns a set of objects with duplicate elements eliminated.
+ *
+ * For now this is an alias for the collect_set Hive UDAF.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def collect_set(e: Column): Column = callUDF("collect_set", e)
+
+ /**
* Aggregate function: returns the Pearson Correlation Coefficient for two columns.
*
* @group agg_funcs
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
index 2e5cae415e..9864acf765 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.scalatest.BeforeAndAfterAll
@@ -32,7 +32,7 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with
private var testData: DataFrame = _
override def beforeAll() {
- testData = Seq((1, 2), (2, 4)).toDF("a", "b")
+ testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
hiveContext.registerDataFrameAsTable(testData, "mytable")
}
@@ -52,6 +52,17 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with
)
}
+ test("collect functions") {
+ checkAnswer(
+ testData.select(collect_list($"a"), collect_list($"b")),
+ Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
+ )
+ checkAnswer(
+ testData.select(collect_set($"a"), collect_set($"b")),
+ Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
+ )
+ }
+
test("cube") {
checkAnswer(
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),