aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorSun Rui <rui.sun@intel.com>2016-04-29 16:41:07 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2016-04-29 16:41:07 -0700
commit4ae9fe091c2cb8388c581093d62d3deaef40993e (patch)
treefd84ce605c0ea8bd9d0b2e307119bd5d8651c9f5 /R
parentd78fbcc3cc9c379b4a548ebc816c6f71cc71a16e (diff)
downloadspark-4ae9fe091c2cb8388c581093d62d3deaef40993e.tar.gz
spark-4ae9fe091c2cb8388c581093d62d3deaef40993e.tar.bz2
spark-4ae9fe091c2cb8388c581093d62d3deaef40993e.zip
[SPARK-12919][SPARKR] Implement dapply() on DataFrame in SparkR.
## What changes were proposed in this pull request? dapply() applies an R function on each partition of a DataFrame and returns a new DataFrame. The function signature is: dapply(df, function(localDF) {}, schema = NULL) R function input: local data.frame from the partition on local node R function output: local data.frame Schema specifies the Row format of the resulting DataFrame. It must match the R function's output. If schema is not specified, each partition of the result DataFrame will be serialized in R into a single byte array. Such resulting DataFrame can be processed by successive calls to dapply(). ## How was this patch tested? SparkR unit tests. Author: Sun Rui <rui.sun@intel.com> Author: Sun Rui <sunrui2016@gmail.com> Closes #12493 from sun-rui/SPARK-12919.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/NAMESPACE1
-rw-r--r--R/pkg/R/DataFrame.R61
-rw-r--r--R/pkg/R/generics.R4
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R40
-rw-r--r--R/pkg/inst/worker/worker.R36
5 files changed, 141 insertions, 1 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 002e469efb..647db22747 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -45,6 +45,7 @@ exportMethods("arrange",
"covar_samp",
"covar_pop",
"crosstab",
+ "dapply",
"describe",
"dim",
"distinct",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index a741fdf709..9e30fa0dbf 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -21,6 +21,7 @@
NULL
setOldClass("jobj")
+setOldClass("structType")
#' @title S4 class that represents a SparkDataFrame
#' @description DataFrames can be created using functions like \link{createDataFrame},
@@ -1125,6 +1126,66 @@ setMethod("summarize",
agg(x, ...)
})
+#' dapply
+#'
+#' Apply a function to each partition of a DataFrame.
+#'
+#' @param x A SparkDataFrame
+#' @param func A function to be applied to each partition of the SparkDataFrame.
+#' func should have only one parameter, to which a data.frame corresponds
+#' to each partition will be passed.
+#' The output of func should be a data.frame.
+#' @param schema The schema of the resulting DataFrame after the function is applied.
+#' It must match the output of func.
+#' @family SparkDataFrame functions
+#' @rdname dapply
+#' @name dapply
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- createDataFrame (sqlContext, iris)
+#' df1 <- dapply(df, function(x) { x }, schema(df))
+#' collect(df1)
+#'
+#' # filter and add a column
+#' df <- createDataFrame (
+#' sqlContext,
+#' list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
+#' c("a", "b", "c"))
+#' schema <- structType(structField("a", "integer"), structField("b", "double"),
+#' structField("c", "string"), structField("d", "integer"))
+#' df1 <- dapply(
+#' df,
+#' function(x) {
+#' y <- x[x[1] > 1, ]
+#' y <- cbind(y, y[1] + 1L)
+#' },
+#' schema)
+#' collect(df1)
+#' # the result
+#' # a b c d
+#' # 1 2 2 2 3
+#' # 2 3 3 3 4
+#' }
+setMethod("dapply",
+ signature(x = "SparkDataFrame", func = "function", schema = "structType"),
+ function(x, func, schema) {
+ packageNamesArr <- serialize(.sparkREnv[[".packages"]],
+ connection = NULL)
+
+ broadcastArr <- lapply(ls(.broadcastNames),
+ function(name) { get(name, .broadcastNames) })
+
+ sdf <- callJStatic(
+ "org.apache.spark.sql.api.r.SQLUtils",
+ "dapply",
+ x@sdf,
+ serialize(cleanClosure(func), connection = NULL),
+ packageNamesArr,
+ broadcastArr,
+ schema$jobj)
+ dataFrame(sdf)
+ })
############################## RDD Map Functions ##################################
# All of the following functions mirror the existing RDD map functions, #
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 62907118ef..3db8925730 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -446,6 +446,10 @@ setGeneric("covar_samp", function(col1, col2) {standardGeneric("covar_samp") })
#' @export
setGeneric("covar_pop", function(col1, col2) {standardGeneric("covar_pop") })
+#' @rdname dapply
+#' @export
+setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") })
+
#' @rdname summary
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 7058265ea3..5cf9dc405b 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2043,6 +2043,46 @@ test_that("Histogram", {
df <- as.DataFrame(sqlContext, data.frame(x = c(1, 2, 3, 4, 100)))
expect_equal(histogram(df, "x")$counts, c(4, 0, 0, 0, 0, 0, 0, 0, 0, 1))
})
+
+test_that("dapply() on a DataFrame", {
+ df <- createDataFrame (
+ sqlContext,
+ list(list(1L, 1, "1"), list(2L, 2, "2"), list(3L, 3, "3")),
+ c("a", "b", "c"))
+ ldf <- collect(df)
+ df1 <- dapply(df, function(x) { x }, schema(df))
+ result <- collect(df1)
+ expect_identical(ldf, result)
+
+
+ # Filter and add a column
+ schema <- structType(structField("a", "integer"), structField("b", "double"),
+ structField("c", "string"), structField("d", "integer"))
+ df1 <- dapply(
+ df,
+ function(x) {
+ y <- x[x$a > 1, ]
+ y <- cbind(y, y$a + 1L)
+ },
+ schema)
+ result <- collect(df1)
+ expected <- ldf[ldf$a > 1, ]
+ expected$d <- expected$a + 1L
+ rownames(expected) <- NULL
+ expect_identical(expected, result)
+
+ # Remove the added column
+ df2 <- dapply(
+ df1,
+ function(x) {
+ x[, c("a", "b", "c")]
+ },
+ schema(df))
+ result <- collect(df2)
+ expected <- expected[, c("a", "b", "c")]
+ expect_identical(expected, result)
+})
+
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index b6784dbae3..40cda0c5ef 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -84,6 +84,13 @@ broadcastElap <- elapsedSecs()
# as number of partitions to create.
numPartitions <- SparkR:::readInt(inputCon)
+isDataFrame <- as.logical(SparkR:::readInt(inputCon))
+
+# If isDataFrame, then read column names
+if (isDataFrame) {
+ colNames <- SparkR:::readObject(inputCon)
+}
+
isEmpty <- SparkR:::readInt(inputCon)
if (isEmpty != 0) {
@@ -100,7 +107,34 @@ if (isEmpty != 0) {
# Timing reading input data for execution
inputElap <- elapsedSecs()
- output <- computeFunc(partition, data)
+ if (isDataFrame) {
+ if (deserializer == "row") {
+ # Transform the list of rows into a data.frame
+ # Note that the optional argument stringsAsFactors for rbind is
+ # available since R 3.2.4. So we set the global option here.
+ oldOpt <- getOption("stringsAsFactors")
+ options(stringsAsFactors = FALSE)
+ data <- do.call(rbind.data.frame, data)
+ options(stringsAsFactors = oldOpt)
+
+ names(data) <- colNames
+ } else {
+ # Check to see if data is a valid data.frame
+ stopifnot(deserializer == "byte")
+ stopifnot(class(data) == "data.frame")
+ }
+ output <- computeFunc(data)
+ if (serializer == "row") {
+ # Transform the result data.frame back to a list of rows
+ output <- split(output, seq(nrow(output)))
+ } else {
+ # Serialize the ouput to a byte array
+ stopifnot(serializer == "byte")
+ }
+ } else {
+ output <- computeFunc(partition, data)
+ }
+
# Timing computing
computeElap <- elapsedSecs()