aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala20
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala15
2 files changed, 33 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 0462758988..3f0b24b68b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -175,6 +175,26 @@ object functions {
def avg(columnName: String): Column = avg(Column(columnName))
/**
+ * Aggregate function: returns a list of objects with duplicates.
+ *
+ * For now this is an alias for the collect_list Hive UDAF.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def collect_list(e: Column): Column = callUDF("collect_list", e)
+
+ /**
+ * Aggregate function: returns a set of objects with duplicate elements eliminated.
+ *
+ * For now this is an alias for the collect_set Hive UDAF.
+ *
+ * @group agg_funcs
+ * @since 1.6.0
+ */
+ def collect_set(e: Column): Column = callUDF("collect_set", e)
+
+ /**
* Aggregate function: returns the Pearson Correlation Coefficient for two columns.
*
* @group agg_funcs
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
index 2e5cae415e..9864acf765 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.scalatest.BeforeAndAfterAll
@@ -32,7 +32,7 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with
private var testData: DataFrame = _
override def beforeAll() {
- testData = Seq((1, 2), (2, 4)).toDF("a", "b")
+ testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b")
hiveContext.registerDataFrameAsTable(testData, "mytable")
}
@@ -52,6 +52,17 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with
)
}
+ test("collect functions") {
+ checkAnswer(
+ testData.select(collect_list($"a"), collect_list($"b")),
+ Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4)))
+ )
+ checkAnswer(
+ testData.select(collect_set($"a"), collect_set($"b")),
+ Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
+ )
+ }
+
test("cube") {
checkAnswer(
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),