-rw-r--r--  R/pkg/NAMESPACE                            1
-rw-r--r--  R/pkg/R/DataFrame.R                       37
-rw-r--r--  R/pkg/R/generics.R                         4
-rw-r--r--  R/pkg/inst/tests/testthat/test_sparkSQL.R 18
4 files changed, 60 insertions, 0 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 5db43ae649..9412ec3f9e 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -81,6 +81,7 @@ exportMethods("arrange",
              "orderBy",
              "persist",
              "printSchema",
+              "randomSplit",
              "rbind",
              "registerTempTable",
              "rename",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 231e4f0f4e..4e044565f4 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2934,3 +2934,40 @@ setMethod("write.jdbc",
            write <- callJMethod(write, "mode", jmode)
            invisible(callJMethod(write, "jdbc", url, tableName, jprops))
          })
+
+#' randomSplit
+#'
+#' Return a list of SparkDataFrames obtained by randomly splitting the input SparkDataFrame with the provided weights.
+#'
+#' @param x A SparkDataFrame
+#' @param weights A vector of weights for splits; they will be normalized if they do not sum to 1
+#' @param seed A seed to use for the random split
+#'
+#' @family SparkDataFrame functions
+#' @rdname randomSplit
+#' @name randomSplit
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' df <- createDataFrame(data.frame(id = 1:1000))
+#' df_list <- randomSplit(df, c(2, 3, 5), 0)
+#' # df_list contains 3 SparkDataFrames with about 200, 300, and 500 rows respectively
+#' sapply(df_list, count)
+#' }
+#' @note randomSplit since 2.0.0
+setMethod("randomSplit",
+ signature(x = "SparkDataFrame", weights = "numeric"),
+ function(x, weights, seed) {
+ if (!all(sapply(weights, function(c) { c >= 0 }))) {
+ stop("all weight values should not be negative")
+ }
+ normalized_list <- as.list(weights / sum(weights))
+ if (!missing(seed)) {
+ sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list, as.integer(seed))
+ } else {
+ sdfs <- callJMethod(x@sdf, "randomSplit", normalized_list)
+ }
+ sapply(sdfs, dataFrame)
+ })
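A minimal usage sketch of the new method (assuming a local SparkR session set up as in the roxygen example above; the train/test variable names are illustrative, not part of the change):

sc <- sparkR.init()
sqlContext <- sparkRSQL.init(sc)
df <- createDataFrame(data.frame(id = 1:1000))
# weights are normalized internally, so c(4, 1) behaves like c(0.8, 0.2)
splits <- randomSplit(df, c(4, 1), seed = 12)
train <- splits[[1]]
test <- splits[[2]]
sapply(splits, count)  # roughly 800 and 200 rows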
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 594bf2eadc..6e754afab6 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -679,6 +679,10 @@ setGeneric("withColumnRenamed",
#' @export
setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
+#' @rdname randomSplit
+#' @export
+setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") })
+
###################### Column Methods ##########################

#' @rdname column
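The generic above is what routes a randomSplit(...) call to the SparkDataFrame method via S4 dispatch. A stripped-down local analogue of the same setGeneric/setMethod pattern (a hypothetical mySplit over a plain data.frame, not part of this change; it uses multinomial row assignment, which approximates but is not identical to Spark's sampling):

setGeneric("mySplit", function(x, weights, seed) { standardGeneric("mySplit") })
setMethod("mySplit",
          signature(x = "data.frame", weights = "numeric"),
          function(x, weights, seed) {
            if (!missing(seed)) set.seed(seed)
            # assign each row to a split with probability proportional to its weight
            groups <- sample(seq_along(weights), nrow(x), replace = TRUE,
                             prob = weights / sum(weights))
            split(x, factor(groups, levels = seq_along(weights)))
          })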
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 7aa03a9048..607bd9c12f 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2280,6 +2280,24 @@ test_that("createDataFrame sqlContext parameter backward compatibility", {
  expect_equal(collect(before), collect(after))
})

+test_that("randomSplit", {
+ num <- 4000
+ df <- createDataFrame(data.frame(id = 1:num))
+
+ weights <- c(2, 3, 5)
+ df_list <- randomSplit(df, weights)
+ expect_equal(length(weights), length(df_list))
+ counts <- sapply(df_list, count)
+ expect_equal(num, sum(counts))
+ expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 })))
+
+ df_list <- randomSplit(df, weights, 0)
+ expect_equal(length(weights), length(df_list))
+ counts <- sapply(df_list, count)
+ expect_equal(num, sum(counts))
+ expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 })))
+})
+
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
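As a quick sanity check of the seed handling (a sketch, same session assumptions as the usage example above): repeating the call with the same seed should reproduce the split, while omitting the seed lets Spark pick one.

df <- createDataFrame(data.frame(id = 1:100))
a <- randomSplit(df, c(1, 1), seed = 42)
b <- randomSplit(df, c(1, 1), seed = 42)
# same seed on the same input: the split sizes (and rows) match
stopifnot(identical(sapply(a, count), sapply(b, count)))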