aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2015-12-11 20:55:16 -0800
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2015-12-11 20:55:16 -0800
commit1e3526c2d3de723225024fedd45753b556e18fc6 (patch)
tree49655ead89d5782307e76df088a6fceef36ef278 /R
parent1e799d617a28cd0eaa8f22d103ea8248c4655ae5 (diff)
downloadspark-1e3526c2d3de723225024fedd45753b556e18fc6.tar.gz
spark-1e3526c2d3de723225024fedd45753b556e18fc6.tar.bz2
spark-1e3526c2d3de723225024fedd45753b556e18fc6.zip
[SPARK-12158][SPARKR][SQL] Fix 'sample' functions that break R unit test cases
The existing sample functions miss the parameter `seed`, however, the corresponding function interface in `generics` has such a parameter. Thus, although the function caller can call the function with the 'seed', we are not using the value. This could cause SparkR unit tests failed. For example, I hit it in another PR: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/47213/consoleFull Author: gatorsmile <gatorsmile@gmail.com> Closes #10160 from gatorsmile/sampleR.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/R/DataFrame.R17
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R4
2 files changed, 15 insertions, 6 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 975b058c0a..764597d1e3 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -662,6 +662,7 @@ setMethod("unique",
#' @param x A SparkSQL DataFrame
#' @param withReplacement Sampling with replacement or not
#' @param fraction The (rough) sample target fraction
+#' @param seed Randomness seed value
#'
#' @family DataFrame functions
#' @rdname sample
@@ -677,13 +678,17 @@ setMethod("unique",
#' collect(sample(df, TRUE, 0.5))
#'}
setMethod("sample",
- # TODO : Figure out how to send integer as java.lang.Long to JVM so
- # we can send seed as an argument through callJMethod
signature(x = "DataFrame", withReplacement = "logical",
fraction = "numeric"),
- function(x, withReplacement, fraction) {
+ function(x, withReplacement, fraction, seed) {
if (fraction < 0.0) stop(cat("Negative fraction value:", fraction))
- sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction)
+ if (!missing(seed)) {
+ # TODO : Figure out how to send integer as java.lang.Long to JVM so
+ # we can send seed as an argument through callJMethod
+ sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction, as.integer(seed))
+ } else {
+ sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction)
+ }
dataFrame(sdf)
})
@@ -692,8 +697,8 @@ setMethod("sample",
setMethod("sample_frac",
signature(x = "DataFrame", withReplacement = "logical",
fraction = "numeric"),
- function(x, withReplacement, fraction) {
- sample(x, withReplacement, fraction)
+ function(x, withReplacement, fraction, seed) {
+ sample(x, withReplacement, fraction, seed)
})
#' nrow
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index ed9b2c9d4d..071fd310fd 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -724,6 +724,10 @@ test_that("sample on a DataFrame", {
sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result
expect_true(count(sampled2) < 3)
+ count1 <- count(sample(df, FALSE, 0.1, 0))
+ count2 <- count(sample(df, FALSE, 0.1, 0))
+ expect_equal(count1, count2)
+
# Also test sample_frac
sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result
expect_true(count(sampled3) < 3)