Diffstat (limited to 'R/pkg/inst/tests')
-rw-r--r-- | R/pkg/inst/tests/testthat/jarTest.R                  |  4
-rw-r--r-- | R/pkg/inst/tests/testthat/packageInAJarTest.R        |  4
-rw-r--r-- | R/pkg/inst/tests/testthat/test_Serde.R               |  2
-rw-r--r-- | R/pkg/inst/tests/testthat/test_binaryFile.R          |  3
-rw-r--r-- | R/pkg/inst/tests/testthat/test_binary_function.R     |  3
-rw-r--r-- | R/pkg/inst/tests/testthat/test_broadcast.R           |  3
-rw-r--r-- | R/pkg/inst/tests/testthat/test_context.R             | 41
-rw-r--r-- | R/pkg/inst/tests/testthat/test_includePackage.R      |  3
-rw-r--r-- | R/pkg/inst/tests/testthat/test_mllib.R               |  5
-rw-r--r-- | R/pkg/inst/tests/testthat/test_parallelize_collect.R |  3
-rw-r--r-- | R/pkg/inst/tests/testthat/test_rdd.R                 |  3
-rw-r--r-- | R/pkg/inst/tests/testthat/test_shuffle.R             |  3
-rw-r--r-- | R/pkg/inst/tests/testthat/test_sparkSQL.R            | 86
-rw-r--r-- | R/pkg/inst/tests/testthat/test_take.R                | 17
-rw-r--r-- | R/pkg/inst/tests/testthat/test_textFile.R            |  3
-rw-r--r-- | R/pkg/inst/tests/testthat/test_utils.R               | 16
16 files changed, 138 insertions, 61 deletions
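Editor's note: the diff below applies one recurring migration across the test suite. The
deprecated SparkR 1.x entry points (sparkR.init, sparkRSQL.init, sparkR.stop) give way to
the unified sparkR.session API, and tests that still need a JavaSparkContext recover it
from the session via a SQLUtils helper. As a minimal standalone sketch of that pattern
(not part of any test file; it assumes SparkR is installed and a local master is reachable):

    library(SparkR)

    # Before (1.x, now deprecated -- these calls emit warnings):
    #   sc <- sparkR.init()
    #   sqlContext <- sparkRSQL.init(sc)
    #   ...
    #   sparkR.stop()

    # After (2.0): a single SparkSession handle drives SQL and config.
    sparkSession <- sparkR.session()

    # RDD-level tests still need the JavaSparkContext; the diff recovers it
    # from the session through the SQLUtils helper (callJStatic is internal
    # to the package, hence the ::: here).
    sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                               "getJavaSparkContext", sparkSession)

    # DataFrame creation no longer takes a sqlContext argument.
    df <- createDataFrame(data.frame(a = 1:20))
    stopifnot(count(df) == 20)

    sparkR.session.stop()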
diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R
index d68bb20950..84e4845f18 100644
--- a/R/pkg/inst/tests/testthat/jarTest.R
+++ b/R/pkg/inst/tests/testthat/jarTest.R
@@ -16,7 +16,7 @@
 #
 library(SparkR)

-sc <- sparkR.init()
+sparkSession <- sparkR.session()

 helloTest <- SparkR:::callJStatic("sparkR.test.hello",
                                   "helloWorld",
@@ -27,6 +27,6 @@ basicFunction <- SparkR:::callJStatic("sparkR.test.basicFunction",
                                       2L,
                                       2L)

-sparkR.stop()
+sparkR.session.stop()

 output <- c(helloTest, basicFunction)
 writeLines(output)
diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R
index c26b28b78d..940c91f376 100644
--- a/R/pkg/inst/tests/testthat/packageInAJarTest.R
+++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R
@@ -17,13 +17,13 @@
 library(SparkR)
 library(sparkPackageTest)

-sc <- sparkR.init()
+sparkSession <- sparkR.session()

 run1 <- myfunc(5L)

 run2 <- myfunc(-4L)

-sparkR.stop()
+sparkR.session.stop()

 if (run1 != 6) quit(save = "no", status = 1)
diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R
index dddce54d70..96fb6dda26 100644
--- a/R/pkg/inst/tests/testthat/test_Serde.R
+++ b/R/pkg/inst/tests/testthat/test_Serde.R
@@ -17,7 +17,7 @@

 context("SerDe functionality")

-sc <- sparkR.init()
+sparkSession <- sparkR.session()

 test_that("SerDe of primitive types", {
   x <- callJStatic("SparkRHandler", "echo", 1L)
diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R
index 976a7558a8..b69f017de8 100644
--- a/R/pkg/inst/tests/testthat/test_binaryFile.R
+++ b/R/pkg/inst/tests/testthat/test_binaryFile.R
@@ -18,7 +18,8 @@
 context("functions on binary files")

 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 mockFile <- c("Spark is pretty.", "Spark is awesome.")
diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R
index 7bad4d2a7e..6f51d20687 100644
--- a/R/pkg/inst/tests/testthat/test_binary_function.R
+++ b/R/pkg/inst/tests/testthat/test_binary_function.R
@@ -18,7 +18,8 @@
 context("binary functions")

 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 # Data
 nums <- 1:10
diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R
index 8be6efc3db..cf1d432771 100644
--- a/R/pkg/inst/tests/testthat/test_broadcast.R
+++ b/R/pkg/inst/tests/testthat/test_broadcast.R
@@ -18,7 +18,8 @@
 context("broadcast variables")

 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 # Partitioned data
 nums <- 1:2
diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R
index 126484c995..f123187adf 100644
--- a/R/pkg/inst/tests/testthat/test_context.R
+++ b/R/pkg/inst/tests/testthat/test_context.R
@@ -56,31 +56,33 @@ test_that("Check masked functions", {

 test_that("repeatedly starting and stopping SparkR", {
   for (i in 1:4) {
-    sc <- sparkR.init()
+    sc <- suppressWarnings(sparkR.init())
     rdd <- parallelize(sc, 1:20, 2L)
     expect_equal(count(rdd), 20)
-    sparkR.stop()
+    suppressWarnings(sparkR.stop())
   }
 })

-test_that("repeatedly starting and stopping SparkR SQL", {
-  for (i in 1:4) {
-    sc <- sparkR.init()
-    sqlContext <- sparkRSQL.init(sc)
-    df <- createDataFrame(data.frame(a = 1:20))
-    expect_equal(count(df), 20)
-    sparkR.stop()
-  }
-})
+# Does not work consistently even with Hive off
+# nolint start
+# test_that("repeatedly starting and stopping SparkR", {
+#   for (i in 1:4) {
+#     sparkR.session(enableHiveSupport = FALSE)
+#     df <- createDataFrame(data.frame(dummy=1:i))
+#     expect_equal(count(df), i)
+#     sparkR.session.stop()
+#     Sys.sleep(5) # Need more time to shutdown Hive metastore
+#   }
+# })
+# nolint end

 test_that("rdd GC across sparkR.stop", {
-  sparkR.stop()
-  sc <- sparkR.init() # sc should get id 0
+  sc <- sparkR.sparkContext() # sc should get id 0
   rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1
   rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2
-  sparkR.stop()
+  sparkR.session.stop()

-  sc <- sparkR.init() # sc should get id 0 again
+  sc <- sparkR.sparkContext() # sc should get id 0 again

   # GC rdd1 before creating rdd3 and rdd2 after
   rm(rdd1)
@@ -97,15 +99,17 @@ test_that("rdd GC across sparkR.stop", {
 })

 test_that("job group functions can be called", {
-  sc <- sparkR.init()
+  sc <- sparkR.sparkContext()
   setJobGroup(sc, "groupId", "job description", TRUE)
   cancelJobGroup(sc, "groupId")
   clearJobGroup(sc)
+  sparkR.session.stop()
 })

 test_that("utility function can be called", {
-  sc <- sparkR.init()
+  sc <- sparkR.sparkContext()
   setLogLevel(sc, "ERROR")
+  sparkR.session.stop()
 })

 test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", {
@@ -156,7 +160,8 @@ test_that("sparkJars sparkPackages as comma-separated strings", {
 })

 test_that("spark.lapply should perform simple transforms", {
-  sc <- sparkR.init()
+  sc <- sparkR.sparkContext()
   doubled <- spark.lapply(sc, 1:10, function(x) { 2 * x })
   expect_equal(doubled, as.list(2 * 1:10))
+  sparkR.session.stop()
 })
diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R
index 8152b448d0..d6a3766539 100644
--- a/R/pkg/inst/tests/testthat/test_includePackage.R
+++ b/R/pkg/inst/tests/testthat/test_includePackage.R
@@ -18,7 +18,8 @@
 context("include R packages")

 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 # Partitioned data
 nums <- 1:2
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 59ef15c1e9..c8c5ef2476 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -20,10 +20,7 @@ library(testthat)
 context("MLlib functions")

 # Tests for MLlib functions in SparkR
-
-sc <- sparkR.init()
-
-sqlContext <- sparkRSQL.init(sc)
+sparkSession <- sparkR.session()

 test_that("formula of spark.glm", {
   training <- suppressWarnings(createDataFrame(iris))
diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R
index 2552127cc5..f79a8a70aa 100644
--- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R
+++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R
@@ -33,7 +33,8 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3))
 strPairs <- list(list(strList, strList), list(strList, strList))

 # JavaSparkContext handle
-jsc <- sparkR.init()
+sparkSession <- sparkR.session()
+jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 # Tests
diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R
index b6c8e1dc6c..429311d292 100644
--- a/R/pkg/inst/tests/testthat/test_rdd.R
+++ b/R/pkg/inst/tests/testthat/test_rdd.R
@@ -18,7 +18,8 @@
 context("basic RDD functions")

 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 # Data
 nums <- 1:10
diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R
index d3d0f8a24d..7d4f342016 100644
--- a/R/pkg/inst/tests/testthat/test_shuffle.R
+++ b/R/pkg/inst/tests/testthat/test_shuffle.R
@@ -18,7 +18,8 @@
 context("partitionBy, groupByKey, reduceByKey etc.")

 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 # Data
 intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200))
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 607bd9c12f..fcc2ab3ed6 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -33,26 +33,35 @@ markUtf8 <- function(s) {
 }

 setHiveContext <- function(sc) {
-  ssc <- callJMethod(sc, "sc")
-  hiveCtx <- tryCatch({
-    newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
-  },
-  error = function(err) {
-    skip("Hive is not build with SparkSQL, skipped")
-  })
-  assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv)
-  hiveCtx
+  if (exists(".testHiveSession", envir = .sparkREnv)) {
+    hiveSession <- get(".testHiveSession", envir = .sparkREnv)
+  } else {
+    # initialize once and reuse
+    ssc <- callJMethod(sc, "sc")
+    hiveCtx <- tryCatch({
+      newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
+    },
+    error = function(err) {
+      skip("Hive is not build with SparkSQL, skipped")
+    })
+    hiveSession <- callJMethod(hiveCtx, "sparkSession")
+  }
+  previousSession <- get(".sparkRsession", envir = .sparkREnv)
+  assign(".sparkRsession", hiveSession, envir = .sparkREnv)
+  assign(".prevSparkRsession", previousSession, envir = .sparkREnv)
+  hiveSession
 }

 unsetHiveContext <- function() {
-  remove(".sparkRHivesc", envir = .sparkREnv)
+  previousSession <- get(".prevSparkRsession", envir = .sparkREnv)
+  assign(".sparkRsession", previousSession, envir = .sparkREnv)
+  remove(".prevSparkRsession", envir = .sparkREnv)
 }

 # Tests for SparkSQL functions in SparkR

-sc <- sparkR.init()
-
-sqlContext <- sparkRSQL.init(sc)
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 mockLines <- c("{\"name\":\"Michael\"}",
                "{\"name\":\"Andy\", \"age\":30}",
@@ -79,7 +88,16 @@ complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
 writeLines(mockLinesComplexType, complexTypeJsonPath)

 test_that("calling sparkRSQL.init returns existing SQL context", {
-  expect_equal(sparkRSQL.init(sc), sqlContext)
+  sqlContext <- suppressWarnings(sparkRSQL.init(sc))
+  expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext)
+})
+
+test_that("calling sparkRSQL.init returns existing SparkSession", {
+  expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession)
+})
+
+test_that("calling sparkR.session returns existing SparkSession", {
+  expect_equal(sparkR.session(), sparkSession)
 })

 test_that("infer types and check types", {
@@ -431,6 +449,7 @@ test_that("read/write json files", {
 })

 test_that("jsonRDD() on a RDD with json string", {
+  sqlContext <- suppressWarnings(sparkRSQL.init(sc))
   rdd <- parallelize(sc, mockLines)
   expect_equal(count(rdd), 3)
   df <- suppressWarnings(jsonRDD(sqlContext, rdd))
@@ -2228,7 +2247,6 @@ test_that("gapply() on a DataFrame", {
 })

 test_that("Window functions on a DataFrame", {
-  setHiveContext(sc)
   df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")),
                         schema = c("key", "value"))
   ws <- orderBy(window.partitionBy("key"), "value")
@@ -2253,10 +2271,10 @@ test_that("Window functions on a DataFrame", {
   result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws)))
   names(result) <- c("key", "value")
   expect_equal(result, expected)
-  unsetHiveContext()
 })

 test_that("createDataFrame sqlContext parameter backward compatibility", {
+  sqlContext <- suppressWarnings(sparkRSQL.init(sc))
   a <- 1:3
   b <- c("a", "b", "c")
   ldf <- data.frame(a, b)
@@ -2283,7 +2301,6 @@ test_that("createDataFrame sqlContext parameter backward compatibility", {

 test_that("randomSplit", {
   num <- 4000
   df <- createDataFrame(data.frame(id = 1:num))
-
   weights <- c(2, 3, 5)
   df_list <- randomSplit(df, weights)
   expect_equal(length(weights), length(df_list))
@@ -2298,6 +2315,41 @@
   expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 })))
 })

+test_that("Change config on SparkSession", {
+  # first, set it to a random but known value
+  conf <- callJMethod(sparkSession, "conf")
+  property <- paste0("spark.testing.", as.character(runif(1)))
+  value1 <- as.character(runif(1))
+  callJMethod(conf, "set", property, value1)
+
+  # next, change the same property to the new value
+  value2 <- as.character(runif(1))
+  l <- list(value2)
+  names(l) <- property
+  sparkR.session(sparkConfig = l)
+
+  conf <- callJMethod(sparkSession, "conf")
+  newValue <- callJMethod(conf, "get", property, "")
+  expect_equal(value2, newValue)
+
+  value <- as.character(runif(1))
+  sparkR.session(spark.app.name = "sparkSession test", spark.testing.r.session.r = value)
+  conf <- callJMethod(sparkSession, "conf")
+  appNameValue <- callJMethod(conf, "get", "spark.app.name", "")
+  testValue <- callJMethod(conf, "get", "spark.testing.r.session.r", "")
+  expect_equal(appNameValue, "sparkSession test")
+  expect_equal(testValue, value)
+})
+
+test_that("enableHiveSupport on SparkSession", {
+  setHiveContext(sc)
+  unsetHiveContext()
+  # if we are still here, it must be built with hive
+  conf <- callJMethod(sparkSession, "conf")
+  value <- callJMethod(conf, "get", "spark.sql.catalogImplementation", "")
+  expect_equal(value, "hive")
+})
+
 unlink(parquetPath)
 unlink(jsonPath)
 unlink(jsonPathNa)
diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R
index c2c724cdc7..daf5e41abe 100644
--- a/R/pkg/inst/tests/testthat/test_take.R
+++ b/R/pkg/inst/tests/testthat/test_take.R
@@ -30,10 +30,11 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ",
                 "raising me. But they're both dead now. I didn't kill them. Honest.")

 # JavaSparkContext handle
-jsc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 test_that("take() gives back the original elements in correct count and order", {
-  numVectorRDD <- parallelize(jsc, numVector, 10)
+  numVectorRDD <- parallelize(sc, numVector, 10)
   # case: number of elements to take is less than the size of the first partition
   expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1)))
   # case: number of elements to take is the same as the size of the first partition
@@ -42,20 +43,20 @@ test_that("take() gives back the original elements in correct count and order",
   expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector))
   expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector))

-  numListRDD <- parallelize(jsc, numList, 1)
-  numListRDD2 <- parallelize(jsc, numList, 4)
+  numListRDD <- parallelize(sc, numList, 1)
+  numListRDD2 <- parallelize(sc, numList, 4)
   expect_equal(take(numListRDD, 3), take(numListRDD2, 3))
   expect_equal(take(numListRDD, 5), take(numListRDD2, 5))
   expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1)))
   expect_equal(take(numListRDD2, 999), numList)

-  strVectorRDD <- parallelize(jsc, strVector, 2)
-  strVectorRDD2 <- parallelize(jsc, strVector, 3)
+  strVectorRDD <- parallelize(sc, strVector, 2)
+  strVectorRDD2 <- parallelize(sc, strVector, 3)
   expect_equal(take(strVectorRDD, 4), as.list(strVector))
   expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2)))

-  strListRDD <- parallelize(jsc, strList, 4)
-  strListRDD2 <- parallelize(jsc, strList, 1)
+  strListRDD <- parallelize(sc, strList, 4)
+  strListRDD2 <- parallelize(sc, strList, 1)
   expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3)))
   expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1)))
diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R
index e64ef1bb31..7b2cc74753 100644
--- a/R/pkg/inst/tests/testthat/test_textFile.R
+++ b/R/pkg/inst/tests/testthat/test_textFile.R
@@ -18,7 +18,8 @@
 context("the textFile() function")

 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 mockFile <- c("Spark is pretty.", "Spark is awesome.")
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index 54d2eca50e..21a119a06b 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -18,7 +18,8 @@
 context("functions in utils.R")

 # JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)

 test_that("convertJListToRList() gives back (deserializes) the original JLists
           of strings and integers", {
@@ -168,3 +169,16 @@ test_that("convertToJSaveMode", {
 test_that("hashCode", {
   expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA)
 })
+
+test_that("overrideEnvs", {
+  config <- new.env()
+  config[["spark.master"]] <- "foo"
+  config[["config_only"]] <- "ok"
+  param <- new.env()
+  param[["spark.master"]] <- "local"
+  param[["param_only"]] <- "blah"
+  overrideEnvs(config, param)
+  expect_equal(config[["spark.master"]], "local")
+  expect_equal(config[["param_only"]], "blah")
+  expect_equal(config[["config_only"]], "ok")
+})