aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSun Rui <sunrui2016@gmail.com>2016-05-12 17:50:55 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2016-05-12 17:50:55 -0700
commitb3930f74a0929b2cdcbbe5cbe34f0b1d35eb01cc (patch)
tree6c08e7b8ca13d7e73f5c667b94bb6039211fd2bc
parentbb1362eb3b36b553dca246b95f59ba7fd8adcc8a (diff)
downloadspark-b3930f74a0929b2cdcbbe5cbe34f0b1d35eb01cc.tar.gz
spark-b3930f74a0929b2cdcbbe5cbe34f0b1d35eb01cc.tar.bz2
spark-b3930f74a0929b2cdcbbe5cbe34f0b1d35eb01cc.zip
[SPARK-15202][SPARKR] add dapplyCollect() method for DataFrame in SparkR.
## What changes were proposed in this pull request? dapplyCollect() applies an R function on each partition of a SparkDataFrame and collects the result back to R as a data.frame. ``` dapplyCollect(df, function(ldf) {...}) ``` ## How was this patch tested? SparkR unit tests. Author: Sun Rui <sunrui2016@gmail.com> Closes #12989 from sun-rui/SPARK-15202.
-rw-r--r--R/pkg/NAMESPACE1
-rw-r--r--R/pkg/R/DataFrame.R86
-rw-r--r--R/pkg/R/generics.R4
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R21
4 files changed, 95 insertions, 17 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 1432ab8a9d..239ad065d0 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -47,6 +47,7 @@ exportMethods("arrange",
"covar_pop",
"crosstab",
"dapply",
+ "dapplyCollect",
"describe",
"dim",
"distinct",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 43c46b8474..0c2a194483 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1153,9 +1153,27 @@ setMethod("summarize",
agg(x, ...)
})
+dapplyInternal <- function(x, func, schema) {
+ packageNamesArr <- serialize(.sparkREnv[[".packages"]],
+ connection = NULL)
+
+ broadcastArr <- lapply(ls(.broadcastNames),
+ function(name) { get(name, .broadcastNames) })
+
+ sdf <- callJStatic(
+ "org.apache.spark.sql.api.r.SQLUtils",
+ "dapply",
+ x@sdf,
+ serialize(cleanClosure(func), connection = NULL),
+ packageNamesArr,
+ broadcastArr,
+ if (is.null(schema)) { schema } else { schema$jobj })
+ dataFrame(sdf)
+}
+
#' dapply
#'
-#' Apply a function to each partition of a DataFrame.
+#' Apply a function to each partition of a SparkDataFrame.
#'
#' @param x A SparkDataFrame
#' @param func A function to be applied to each partition of the SparkDataFrame.
@@ -1197,21 +1215,57 @@ setMethod("summarize",
setMethod("dapply",
signature(x = "SparkDataFrame", func = "function", schema = "structType"),
function(x, func, schema) {
- packageNamesArr <- serialize(.sparkREnv[[".packages"]],
- connection = NULL)
-
- broadcastArr <- lapply(ls(.broadcastNames),
- function(name) { get(name, .broadcastNames) })
-
- sdf <- callJStatic(
- "org.apache.spark.sql.api.r.SQLUtils",
- "dapply",
- x@sdf,
- serialize(cleanClosure(func), connection = NULL),
- packageNamesArr,
- broadcastArr,
- schema$jobj)
- dataFrame(sdf)
+ dapplyInternal(x, func, schema)
+ })
+
+#' dapplyCollect
+#'
+#' Apply a function to each partition of a SparkDataFrame and collect the result back
+#' to R as a data.frame.
+#'
+#' @param x A SparkDataFrame
+#' @param func A function to be applied to each partition of the SparkDataFrame.
+#' func should have only one parameter, to which a data.frame corresponding
+#' to each partition will be passed.
+#' The output of func should be a data.frame.
+#' @family SparkDataFrame functions
+#' @rdname dapply
+#' @name dapplyCollect
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame (sqlContext, iris)
+#' ldf <- dapplyCollect(df, function(x) { x })
+#'
+#' # filter and add a column
+#' df <- createDataFrame (
+#' sqlContext,
+#' list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
+#' c("a", "b", "c"))
+#' ldf <- dapplyCollect(
+#' df,
+#' function(x) {
+#' y <- x[x[1] > 1, ]
+#' y <- cbind(y, y[1] + 1L)
+#' })
+#' # the result
+#' # a b c d
+#' # 2 2 2 3
+#' # 3 3 3 4
+#' }
+setMethod("dapplyCollect",
+ signature(x = "SparkDataFrame", func = "function"),
+ function(x, func) {
+ df <- dapplyInternal(x, func, NULL)
+
+ content <- callJMethod(df@sdf, "collect")
+ # content is a list of items of struct type. Each item has a single field
+ # which is a serialized data.frame corresponding to one partition of the
+ # SparkDataFrame.
+ ldfs <- lapply(content, function(x) { unserialize(x[[1]]) })
+ ldf <- do.call(rbind, ldfs)
+ row.names(ldf) <- NULL
+ ldf
})
############################## RDD Map Functions ##################################
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 8563be1e64..ed76ad6b73 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -450,6 +450,10 @@ setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") })
#' @export
setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
+#' @rdname dapply
+#' @export
+setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect") })
+
#' @rdname summary
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 0f67bc2e33..6a99b43e5a 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2043,7 +2043,7 @@ test_that("Histogram", {
expect_equal(histogram(df, "x")$counts, c(4, 0, 0, 0, 0, 0, 0, 0, 0, 1))
})
-test_that("dapply() on a DataFrame", {
+test_that("dapply() and dapplyCollect() on a DataFrame", {
df <- createDataFrame (
sqlContext,
list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
@@ -2053,6 +2053,8 @@ test_that("dapply() on a DataFrame", {
result <- collect(df1)
expect_identical(ldf, result)
+ result <- dapplyCollect(df, function(x) { x })
+ expect_identical(ldf, result)
# Filter and add a column
schema <- structType(structField("a", "integer"), structField("b", "double"),
@@ -2070,6 +2072,16 @@ test_that("dapply() on a DataFrame", {
rownames(expected) <- NULL
expect_identical(expected, result)
+ result <- dapplyCollect(
+ df,
+ function(x) {
+ y <- x[x$a > 1, ]
+ y <- cbind(y, y$a + 1L)
+ })
+ expected1 <- expected
+ names(expected1) <- names(result)
+ expect_identical(expected1, result)
+
# Remove the added column
df2 <- dapply(
df1,
@@ -2080,6 +2092,13 @@ test_that("dapply() on a DataFrame", {
result <- collect(df2)
expected <- expected[, c("a", "b", "c")]
expect_identical(expected, result)
+
+ result <- dapplyCollect(
+ df1,
+ function(x) {
+ x[, c("a", "b", "c")]
+ })
+ expect_identical(expected, result)
})
test_that("repartition by columns on DataFrame", {