diff options
author | gatorsmile <gatorsmile@gmail.com> | 2015-12-11 20:55:16 -0800 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@cs.berkeley.edu> | 2015-12-11 20:55:16 -0800 |
commit | 1e3526c2d3de723225024fedd45753b556e18fc6 (patch) | |
tree | 49655ead89d5782307e76df088a6fceef36ef278 /R/pkg | |
parent | 1e799d617a28cd0eaa8f22d103ea8248c4655ae5 (diff) | |
download | spark-1e3526c2d3de723225024fedd45753b556e18fc6.tar.gz spark-1e3526c2d3de723225024fedd45753b556e18fc6.tar.bz2 spark-1e3526c2d3de723225024fedd45753b556e18fc6.zip |
[SPARK-12158][SPARKR][SQL] Fix 'sample' functions that break R unit test cases
The existing `sample` functions are missing the `seed` parameter, even though the corresponding function interface in `generics` declares it. As a result, a caller can pass a `seed` to the function, but the value is silently ignored.
This can cause SparkR unit tests to fail. For example, I hit it in another PR:
https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/47213/consoleFull
Author: gatorsmile <gatorsmile@gmail.com>
Closes #10160 from gatorsmile/sampleR.
Diffstat (limited to 'R/pkg')
-rw-r--r-- | R/pkg/R/DataFrame.R | 17 | ||||
-rw-r--r-- | R/pkg/inst/tests/testthat/test_sparkSQL.R | 4 |
2 files changed, 15 insertions, 6 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 975b058c0a..764597d1e3 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -662,6 +662,7 @@ setMethod("unique", #' @param x A SparkSQL DataFrame #' @param withReplacement Sampling with replacement or not #' @param fraction The (rough) sample target fraction +#' @param seed Randomness seed value #' #' @family DataFrame functions #' @rdname sample @@ -677,13 +678,17 @@ setMethod("unique", #' collect(sample(df, TRUE, 0.5)) #'} setMethod("sample", - # TODO : Figure out how to send integer as java.lang.Long to JVM so - # we can send seed as an argument through callJMethod signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { + function(x, withReplacement, fraction, seed) { if (fraction < 0.0) stop(cat("Negative fraction value:", fraction)) - sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + if (!missing(seed)) { + # TODO : Figure out how to send integer as java.lang.Long to JVM so + # we can send seed as an argument through callJMethod + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed)) + } else { + sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction) + } dataFrame(sdf) }) @@ -692,8 +697,8 @@ setMethod("sample", setMethod("sample_frac", signature(x = "DataFrame", withReplacement = "logical", fraction = "numeric"), - function(x, withReplacement, fraction) { - sample(x, withReplacement, fraction) + function(x, withReplacement, fraction, seed) { + sample(x, withReplacement, fraction, seed) }) #' nrow diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ed9b2c9d4d..071fd310fd 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -724,6 +724,10 @@ test_that("sample on a DataFrame", { sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled2) 
< 3) + count1 <- count(sample(df, FALSE, 0.1, 0)) + count2 <- count(sample(df, FALSE, 0.1, 0)) + expect_equal(count1, count2) + # Also test sample_frac sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) |