aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author zero323 <zero323@users.noreply.github.com> 2017-04-21 12:06:21 -0700
committer Felix Cheung <felixcheung@apache.org> 2017-04-21 12:06:21 -0700
commit fd648bff63f91a30810910dfc5664eea0ff5e6f9 (patch)
tree f00aacc693efa151d857684b3cb0907ce74fd775
parent eb00378f0eed6afbf328ae6cd541cc202d14c1f0 (diff)
download spark-fd648bff63f91a30810910dfc5664eea0ff5e6f9.tar.gz
spark-fd648bff63f91a30810910dfc5664eea0ff5e6f9.tar.bz2
spark-fd648bff63f91a30810910dfc5664eea0ff5e6f9.zip
[SPARK-20371][R] Add wrappers for collect_list and collect_set
## What changes were proposed in this pull request?

Adds wrappers for `collect_list` and `collect_set`.

## How was this patch tested?

Unit tests, `check-cran.sh`

Author: zero323 <zero323@users.noreply.github.com>

Closes #17672 from zero323/SPARK-20371.
-rw-r--r-- R/pkg/NAMESPACE 2
-rw-r--r-- R/pkg/R/functions.R 40
-rw-r--r-- R/pkg/R/generics.R 9
-rw-r--r-- R/pkg/inst/tests/testthat/test_sparkSQL.R 22
4 files changed, 73 insertions, 0 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index b6b559adf0..e804e30e14 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -203,6 +203,8 @@ exportMethods("%in%",
"cbrt",
"ceil",
"ceiling",
+ "collect_list",
+ "collect_set",
"column",
"concat",
"concat_ws",
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index f854df11e5..e7decb9186 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -3705,3 +3705,43 @@ setMethod("create_map",
jc <- callJStatic("org.apache.spark.sql.functions", "map", jcols)
column(jc)
})
+
+#' collect_list
+#'
+#' Aggregate function: collects the values of \code{x} into a list, keeping duplicates.
+#'
+#' @param x Column whose values are collected
+#'
+#' @rdname collect_list
+#' @name collect_list
+#' @family agg_funcs
+#' @aliases collect_list,Column-method
+#' @export
+#' @examples \dontrun{collect_list(df$x)}
+#' @note collect_list since 2.3.0
+setMethod("collect_list",
+ signature(x = "Column"),
+ function(x) {  # forwards x@jc to org.apache.spark.sql.functions.collect_list
+ jc <- callJStatic("org.apache.spark.sql.functions", "collect_list", x@jc)
+ column(jc)
+ })
+
+#' collect_set
+#'
+#' Aggregate function: collects the values of \code{x} into a list with duplicates removed.
+#'
+#' @param x Column whose values are collected
+#'
+#' @rdname collect_set
+#' @name collect_set
+#' @family agg_funcs
+#' @aliases collect_set,Column-method
+#' @export
+#' @examples \dontrun{collect_set(df$x)}
+#' @note collect_set since 2.3.0
+setMethod("collect_set",
+ signature(x = "Column"),
+ function(x) {  # forwards x@jc to org.apache.spark.sql.functions.collect_set
+ jc <- callJStatic("org.apache.spark.sql.functions", "collect_set", x@jc)
+ column(jc)
+ })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index da46823f52..61d248ebd2 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -918,6 +918,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
#' @export
setGeneric("ceil", function(x) { standardGeneric("ceil") })
+#' @rdname collect_list
+#' @export
+setGeneric("collect_list", function(x) { standardGeneric("collect_list") })  # S4 dispatch on x
+
+#' @rdname collect_set
+#' @export
+setGeneric("collect_set", function(x) { standardGeneric("collect_set") })  # S4 dispatch on x
+
#' @rdname column
#' @export
setGeneric("column", function(x) { standardGeneric("column") })
@@ -1358,6 +1366,7 @@ setGeneric("window", function(x, ...) { standardGeneric("window") })
#' @export
setGeneric("year", function(x) { standardGeneric("year") })
+
###################### Spark.ML Methods ##########################
#' @rdname fitted
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 9e87a47106..bf2093fdc4 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1731,6 +1731,28 @@ test_that("group by, agg functions", {
expect_true(abs(sd(1:2) - 0.7071068) < 1e-6)
expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6)
+ # Test collect_list and collect_set
+ gd3_collections_local <- collect(
+ agg(gd3, collect_set(df8$age), collect_list(df8$age))
+ )
+
+ expect_equal(
+ unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 2]),
+ c(30)
+ )
+
+ expect_equal(
+ unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 3]),
+ c(30, 30)
+ )
+
+ expect_equal(
+ sort(unlist(
+ gd3_collections_local[gd3_collections_local$name == "Justin", 3]
+ )),
+ c(1, 19)
+ )
+
unlink(jsonPath2)
unlink(jsonPath3)
})