author    Sun Rui <rui.sun@intel.com>    2015-10-13 22:31:23 -0700
committer Shivaram Venkataraman <shivaram@cs.berkeley.edu>    2015-10-13 22:31:23 -0700
commit    390b22fad69a33eb6daee25b6b858a2e768670a5 (patch)
tree      34a5ada11ef2823c4b082b604b3a9d903647e325 /R
parent    8b32885704502ab2a715cf5142d7517181074428 (diff)
[SPARK-10996] [SPARKR] Implement sampleBy() in DataFrameStatFunctions.
Author: Sun Rui <rui.sun@intel.com>

Closes #9023 from sun-rui/SPARK-10996.
Diffstat (limited to 'R')
 R/pkg/NAMESPACE                  |  3
 R/pkg/R/DataFrame.R              | 14
 R/pkg/R/generics.R               |  6
 R/pkg/R/sparkR.R                 | 12
 R/pkg/R/stats.R                  | 32
 R/pkg/R/utils.R                  | 18
 R/pkg/inst/tests/test_sparkSQL.R | 10
 7 files changed, 76 insertions(+), 19 deletions(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index ed9cd94e03..52f7a0106a 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -65,6 +65,7 @@ exportMethods("arrange",
"repartition",
"sample",
"sample_frac",
+ "sampleBy",
"saveAsParquetFile",
"saveAsTable",
"saveDF",
@@ -254,4 +255,4 @@ export("structField",
"structType.structField",
"print.structType")
-export("as.data.frame") \ No newline at end of file
+export("as.data.frame")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index b7f5f978eb..993be82a47 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1831,17 +1831,15 @@ setMethod("fillna",
if (length(colNames) == 0 || !all(colNames != "")) {
stop("value should be an a named list with each name being a column name.")
}
-
- # Convert to the named list to an environment to be passed to JVM
- valueMap <- new.env()
- for (col in colNames) {
- # Check each item in the named list is of valid type
- v <- value[[col]]
+ # Check each item in the named list is of valid type
+ lapply(value, function(v) {
if (!(class(v) %in% c("integer", "numeric", "character"))) {
stop("Each item in value should be an integer, numeric or charactor.")
}
- valueMap[[col]] <- v
- }
+ })
+
+ # Convert to the named list to an environment to be passed to JVM
+ valueMap <- convertNamedListToEnv(value)
# When value is a named list, caller is expected not to pass in cols
if (!is.null(cols)) {
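For reference, a hedged sketch of the fillna() call path this refactor touches, assuming an existing DataFrame `df` with columns "age" and "name" (neither name is taken from the patch):

    # Sketch only: each value in the named list is checked to be integer, numeric
    # or character, then the list is turned into an environment before the JVM call.
    filled <- fillna(df, list("age" = 0, "name" = "unknown"))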
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index c106a00245..4a419f785e 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -509,6 +509,10 @@ setGeneric("sample",
setGeneric("sample_frac",
function(x, withReplacement, fraction, seed) { standardGeneric("sample_frac") })
+#' @rdname statfunctions
+#' @export
+setGeneric("sampleBy", function(x, col, fractions, seed) { standardGeneric("sampleBy") })
+
#' @rdname saveAsParquetFile
#' @export
setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") })
@@ -1006,4 +1010,4 @@ setGeneric("as.data.frame")
#' @rdname attach
#' @export
-setGeneric("attach") \ No newline at end of file
+setGeneric("attach")
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index cc47110f54..9cf2f1a361 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -163,19 +163,13 @@ sparkR.init <- function(
sparkHome <- suppressWarnings(normalizePath(sparkHome))
}
- sparkEnvirMap <- new.env()
- for (varname in names(sparkEnvir)) {
- sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
- }
+ sparkEnvirMap <- convertNamedListToEnv(sparkEnvir)
- sparkExecutorEnvMap <- new.env()
- if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) {
+ sparkExecutorEnvMap <- convertNamedListToEnv(sparkExecutorEnv)
+ if(is.null(sparkExecutorEnvMap$LD_LIBRARY_PATH)) {
sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <-
paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
}
- for (varname in names(sparkExecutorEnv)) {
- sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]]
- }
nonEmptyJars <- Filter(function(x) { x != "" }, jars)
localJarPaths <- lapply(nonEmptyJars,
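The sparkR.init() change above only swaps the hand-written list-to-environment loops for convertNamedListToEnv(); the user-facing call is unchanged. A hedged sketch with illustrative configuration values:

    # Sketch only: both named lists are converted to environments with
    # convertNamedListToEnv() before being passed to the JVM backend.
    sc <- sparkR.init(master = "local[2]",
                      sparkEnvir = list(spark.executor.memory = "1g"),
                      sparkExecutorEnv = list(LD_LIBRARY_PATH = "/opt/native/lib"))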
diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R
index 4928cf4d43..f79329b115 100644
--- a/R/pkg/R/stats.R
+++ b/R/pkg/R/stats.R
@@ -127,3 +127,35 @@ setMethod("freqItems", signature(x = "DataFrame", cols = "character"),
sct <- callJMethod(statFunctions, "freqItems", as.list(cols), support)
collect(dataFrame(sct))
})
+
+#' sampleBy
+#'
+#' Returns a stratified sample without replacement based on the fraction given on each stratum.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param col column that defines strata
+#' @param fractions A named list giving sampling fraction for each stratum. If a stratum is
+#' not specified, we treat its fraction as zero.
+#' @param seed random seed
+#' @return A new DataFrame that represents the stratified sample
+#'
+#' @rdname statfunctions
+#' @name sampleBy
+#' @export
+#' @examples
+#'\dontrun{
+#' df <- jsonFile(sqlContext, "/path/to/file.json")
+#' sample <- sampleBy(df, "key", fractions, 36)
+#' }
+setMethod("sampleBy",
+ signature(x = "DataFrame", col = "character",
+ fractions = "list", seed = "numeric"),
+ function(x, col, fractions, seed) {
+ fractionsEnv <- convertNamedListToEnv(fractions)
+
+ statFunctions <- callJMethod(x@sdf, "stat")
+ # Seed is expected to be Long on Scala side, here convert it to an integer
+ # due to SerDe limitation now.
+ sdf <- callJMethod(statFunctions, "sampleBy", col, fractionsEnv, as.integer(seed))
+ dataFrame(sdf)
+ })
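The \dontrun example in the roxygen block above leaves `fractions` undefined; a hedged, self-contained variant (the JSON path and strata names are illustrative):

    # Sketch only: strata missing from `fractions` are sampled with fraction 0,
    # and the numeric seed is narrowed to an integer before crossing into the JVM.
    df <- jsonFile(sqlContext, "/path/to/file.json")
    fractions <- list("clicked" = 0.1, "ignored" = 0.2)
    sampled <- sampleBy(df, "key", fractions, 36)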
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 94f16c7ac5..0b9e2957fe 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -605,3 +605,21 @@ structToList <- function(struct) {
class(struct) <- "list"
struct
}
+
+# Convert a named list to an environment to be passed to JVM
+convertNamedListToEnv <- function(namedList) {
+ # Make sure each item in the list has a name
+ names <- names(namedList)
+ stopifnot(
+ if (is.null(names)) {
+ length(namedList) == 0
+ } else {
+ !any(is.na(names))
+ })
+
+ env <- new.env()
+ for (name in names) {
+ env[[name]] <- namedList[[name]]
+ }
+ env
+}
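convertNamedListToEnv() is plain R with no Spark dependency, so its behavior can be sanity-checked directly; a hedged sketch:

    # Sketch only: an unnamed, non-empty list fails the stopifnot() guard; a fully
    # named list is copied element by element into a fresh environment.
    e <- convertNamedListToEnv(list(a = 1, b = "two"))
    ls(e)                           # "a" "b"
    e[["b"]]                        # "two"
    convertNamedListToEnv(list())   # an empty list is allowed and yields an empty environment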
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 46cab7646d..e1b42b0804 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1416,6 +1416,16 @@ test_that("freqItems() on a DataFrame", {
expect_identical(result[[2]], list(list(-1, -99)))
})
+test_that("sampleBy() on a DataFrame", {
+ l <- lapply(c(0:99), function(i) { as.character(i %% 3) })
+ df <- createDataFrame(sqlContext, l, "key")
+ fractions <- list("0" = 0.1, "1" = 0.2)
+ sample <- sampleBy(df, "key", fractions, 0)
+ result <- collect(orderBy(count(groupBy(sample, "key")), "key"))
+ expect_identical(as.list(result[1, ]), list(key = "0", count = 2))
+ expect_identical(as.list(result[2, ]), list(key = "1", count = 10))
+})
+
test_that("SQL error message is returned from JVM", {
retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e)
expect_equal(grepl("Table Not Found: blah", retError), TRUE)