aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/tests/testthat
diff options
context:
space:
mode:
Diffstat (limited to 'R/pkg/inst/tests/testthat')
-rw-r--r--R/pkg/inst/tests/testthat/jarTest.R4
-rw-r--r--R/pkg/inst/tests/testthat/packageInAJarTest.R4
-rw-r--r--R/pkg/inst/tests/testthat/test_Serde.R2
-rw-r--r--R/pkg/inst/tests/testthat/test_binaryFile.R3
-rw-r--r--R/pkg/inst/tests/testthat/test_binary_function.R3
-rw-r--r--R/pkg/inst/tests/testthat/test_broadcast.R3
-rw-r--r--R/pkg/inst/tests/testthat/test_context.R41
-rw-r--r--R/pkg/inst/tests/testthat/test_includePackage.R3
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R5
-rw-r--r--R/pkg/inst/tests/testthat/test_parallelize_collect.R3
-rw-r--r--R/pkg/inst/tests/testthat/test_rdd.R3
-rw-r--r--R/pkg/inst/tests/testthat/test_shuffle.R3
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R86
-rw-r--r--R/pkg/inst/tests/testthat/test_take.R17
-rw-r--r--R/pkg/inst/tests/testthat/test_textFile.R3
-rw-r--r--R/pkg/inst/tests/testthat/test_utils.R16
16 files changed, 138 insertions, 61 deletions
diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R
index d68bb20950..84e4845f18 100644
--- a/R/pkg/inst/tests/testthat/jarTest.R
+++ b/R/pkg/inst/tests/testthat/jarTest.R
@@ -16,7 +16,7 @@
#
library(SparkR)
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
helloTest <- SparkR:::callJStatic("sparkR.test.hello",
"helloWorld",
@@ -27,6 +27,6 @@ basicFunction <- SparkR:::callJStatic("sparkR.test.basicFunction",
2L,
2L)
-sparkR.stop()
+sparkR.session.stop()
output <- c(helloTest, basicFunction)
writeLines(output)
diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R
index c26b28b78d..940c91f376 100644
--- a/R/pkg/inst/tests/testthat/packageInAJarTest.R
+++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R
@@ -17,13 +17,13 @@
library(SparkR)
library(sparkPackageTest)
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
run1 <- myfunc(5L)
run2 <- myfunc(-4L)
-sparkR.stop()
+sparkR.session.stop()
if (run1 != 6) quit(save = "no", status = 1)
diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R
index dddce54d70..96fb6dda26 100644
--- a/R/pkg/inst/tests/testthat/test_Serde.R
+++ b/R/pkg/inst/tests/testthat/test_Serde.R
@@ -17,7 +17,7 @@
context("SerDe functionality")
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
test_that("SerDe of primitive types", {
x <- callJStatic("SparkRHandler", "echo", 1L)
diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R
index 976a7558a8..b69f017de8 100644
--- a/R/pkg/inst/tests/testthat/test_binaryFile.R
+++ b/R/pkg/inst/tests/testthat/test_binaryFile.R
@@ -18,7 +18,8 @@
context("functions on binary files")
# JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
mockFile <- c("Spark is pretty.", "Spark is awesome.")
diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R
index 7bad4d2a7e..6f51d20687 100644
--- a/R/pkg/inst/tests/testthat/test_binary_function.R
+++ b/R/pkg/inst/tests/testthat/test_binary_function.R
@@ -18,7 +18,8 @@
context("binary functions")
# JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Data
nums <- 1:10
diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R
index 8be6efc3db..cf1d432771 100644
--- a/R/pkg/inst/tests/testthat/test_broadcast.R
+++ b/R/pkg/inst/tests/testthat/test_broadcast.R
@@ -18,7 +18,8 @@
context("broadcast variables")
# JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Partitioned data
nums <- 1:2
diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R
index 126484c995..f123187adf 100644
--- a/R/pkg/inst/tests/testthat/test_context.R
+++ b/R/pkg/inst/tests/testthat/test_context.R
@@ -56,31 +56,33 @@ test_that("Check masked functions", {
test_that("repeatedly starting and stopping SparkR", {
for (i in 1:4) {
- sc <- sparkR.init()
+ sc <- suppressWarnings(sparkR.init())
rdd <- parallelize(sc, 1:20, 2L)
expect_equal(count(rdd), 20)
- sparkR.stop()
+ suppressWarnings(sparkR.stop())
}
})
-test_that("repeatedly starting and stopping SparkR SQL", {
- for (i in 1:4) {
- sc <- sparkR.init()
- sqlContext <- sparkRSQL.init(sc)
- df <- createDataFrame(data.frame(a = 1:20))
- expect_equal(count(df), 20)
- sparkR.stop()
- }
-})
+# Does not work consistently even with Hive off
+# nolint start
+# test_that("repeatedly starting and stopping SparkR", {
+# for (i in 1:4) {
+# sparkR.session(enableHiveSupport = FALSE)
+# df <- createDataFrame(data.frame(dummy=1:i))
+# expect_equal(count(df), i)
+# sparkR.session.stop()
+# Sys.sleep(5) # Need more time to shutdown Hive metastore
+# }
+# })
+# nolint end
test_that("rdd GC across sparkR.stop", {
- sparkR.stop()
- sc <- sparkR.init() # sc should get id 0
+ sc <- sparkR.sparkContext() # sc should get id 0
rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1
rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2
- sparkR.stop()
+ sparkR.session.stop()
- sc <- sparkR.init() # sc should get id 0 again
+ sc <- sparkR.sparkContext() # sc should get id 0 again
# GC rdd1 before creating rdd3 and rdd2 after
rm(rdd1)
@@ -97,15 +99,17 @@ test_that("rdd GC across sparkR.stop", {
})
test_that("job group functions can be called", {
- sc <- sparkR.init()
+ sc <- sparkR.sparkContext()
setJobGroup(sc, "groupId", "job description", TRUE)
cancelJobGroup(sc, "groupId")
clearJobGroup(sc)
+ sparkR.session.stop()
})
test_that("utility function can be called", {
- sc <- sparkR.init()
+ sc <- sparkR.sparkContext()
setLogLevel(sc, "ERROR")
+ sparkR.session.stop()
})
test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", {
@@ -156,7 +160,8 @@ test_that("sparkJars sparkPackages as comma-separated strings", {
})
test_that("spark.lapply should perform simple transforms", {
- sc <- sparkR.init()
+ sc <- sparkR.sparkContext()
doubled <- spark.lapply(sc, 1:10, function(x) { 2 * x })
expect_equal(doubled, as.list(2 * 1:10))
+ sparkR.session.stop()
})
diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R
index 8152b448d0..d6a3766539 100644
--- a/R/pkg/inst/tests/testthat/test_includePackage.R
+++ b/R/pkg/inst/tests/testthat/test_includePackage.R
@@ -18,7 +18,8 @@
context("include R packages")
# JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Partitioned data
nums <- 1:2
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 59ef15c1e9..c8c5ef2476 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -20,10 +20,7 @@ library(testthat)
context("MLlib functions")
# Tests for MLlib functions in SparkR
-
-sc <- sparkR.init()
-
-sqlContext <- sparkRSQL.init(sc)
+sparkSession <- sparkR.session()
test_that("formula of spark.glm", {
training <- suppressWarnings(createDataFrame(iris))
diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R
index 2552127cc5..f79a8a70aa 100644
--- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R
+++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R
@@ -33,7 +33,8 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3))
strPairs <- list(list(strList, strList), list(strList, strList))
# JavaSparkContext handle
-jsc <- sparkR.init()
+sparkSession <- sparkR.session()
+jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Tests
diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R
index b6c8e1dc6c..429311d292 100644
--- a/R/pkg/inst/tests/testthat/test_rdd.R
+++ b/R/pkg/inst/tests/testthat/test_rdd.R
@@ -18,7 +18,8 @@
context("basic RDD functions")
# JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Data
nums <- 1:10
diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R
index d3d0f8a24d..7d4f342016 100644
--- a/R/pkg/inst/tests/testthat/test_shuffle.R
+++ b/R/pkg/inst/tests/testthat/test_shuffle.R
@@ -18,7 +18,8 @@
context("partitionBy, groupByKey, reduceByKey etc.")
# JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
# Data
intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200))
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 607bd9c12f..fcc2ab3ed6 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -33,26 +33,35 @@ markUtf8 <- function(s) {
}
setHiveContext <- function(sc) {
- ssc <- callJMethod(sc, "sc")
- hiveCtx <- tryCatch({
- newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
- },
- error = function(err) {
- skip("Hive is not build with SparkSQL, skipped")
- })
- assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv)
- hiveCtx
+ if (exists(".testHiveSession", envir = .sparkREnv)) {
+ hiveSession <- get(".testHiveSession", envir = .sparkREnv)
+ } else {
+ # initialize once and reuse
+ ssc <- callJMethod(sc, "sc")
+ hiveCtx <- tryCatch({
+ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
+ },
+ error = function(err) {
+ skip("Hive is not build with SparkSQL, skipped")
+ })
+ hiveSession <- callJMethod(hiveCtx, "sparkSession")
+ }
+ previousSession <- get(".sparkRsession", envir = .sparkREnv)
+ assign(".sparkRsession", hiveSession, envir = .sparkREnv)
+ assign(".prevSparkRsession", previousSession, envir = .sparkREnv)
+ hiveSession
}
unsetHiveContext <- function() {
- remove(".sparkRHivesc", envir = .sparkREnv)
+ previousSession <- get(".prevSparkRsession", envir = .sparkREnv)
+ assign(".sparkRsession", previousSession, envir = .sparkREnv)
+ remove(".prevSparkRsession", envir = .sparkREnv)
}
# Tests for SparkSQL functions in SparkR
-sc <- sparkR.init()
-
-sqlContext <- sparkRSQL.init(sc)
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
mockLines <- c("{\"name\":\"Michael\"}",
"{\"name\":\"Andy\", \"age\":30}",
@@ -79,7 +88,16 @@ complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
writeLines(mockLinesComplexType, complexTypeJsonPath)
test_that("calling sparkRSQL.init returns existing SQL context", {
- expect_equal(sparkRSQL.init(sc), sqlContext)
+ sqlContext <- suppressWarnings(sparkRSQL.init(sc))
+ expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext)
+})
+
+test_that("calling sparkRSQL.init returns existing SparkSession", {
+ expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession)
+})
+
+test_that("calling sparkR.session returns existing SparkSession", {
+ expect_equal(sparkR.session(), sparkSession)
})
test_that("infer types and check types", {
@@ -431,6 +449,7 @@ test_that("read/write json files", {
})
test_that("jsonRDD() on a RDD with json string", {
+ sqlContext <- suppressWarnings(sparkRSQL.init(sc))
rdd <- parallelize(sc, mockLines)
expect_equal(count(rdd), 3)
df <- suppressWarnings(jsonRDD(sqlContext, rdd))
@@ -2228,7 +2247,6 @@ test_that("gapply() on a DataFrame", {
})
test_that("Window functions on a DataFrame", {
- setHiveContext(sc)
df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")),
schema = c("key", "value"))
ws <- orderBy(window.partitionBy("key"), "value")
@@ -2253,10 +2271,10 @@ test_that("Window functions on a DataFrame", {
result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws)))
names(result) <- c("key", "value")
expect_equal(result, expected)
- unsetHiveContext()
})
test_that("createDataFrame sqlContext parameter backward compatibility", {
+ sqlContext <- suppressWarnings(sparkRSQL.init(sc))
a <- 1:3
b <- c("a", "b", "c")
ldf <- data.frame(a, b)
@@ -2283,7 +2301,6 @@ test_that("createDataFrame sqlContext parameter backward compatibility", {
test_that("randomSplit", {
num <- 4000
df <- createDataFrame(data.frame(id = 1:num))
-
weights <- c(2, 3, 5)
df_list <- randomSplit(df, weights)
expect_equal(length(weights), length(df_list))
@@ -2298,6 +2315,41 @@ test_that("randomSplit", {
expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 })))
})
+test_that("Change config on SparkSession", {
+ # first, set it to a random but known value
+ conf <- callJMethod(sparkSession, "conf")
+ property <- paste0("spark.testing.", as.character(runif(1)))
+ value1 <- as.character(runif(1))
+ callJMethod(conf, "set", property, value1)
+
+ # next, change the same property to the new value
+ value2 <- as.character(runif(1))
+ l <- list(value2)
+ names(l) <- property
+ sparkR.session(sparkConfig = l)
+
+ conf <- callJMethod(sparkSession, "conf")
+ newValue <- callJMethod(conf, "get", property, "")
+ expect_equal(value2, newValue)
+
+ value <- as.character(runif(1))
+ sparkR.session(spark.app.name = "sparkSession test", spark.testing.r.session.r = value)
+ conf <- callJMethod(sparkSession, "conf")
+ appNameValue <- callJMethod(conf, "get", "spark.app.name", "")
+ testValue <- callJMethod(conf, "get", "spark.testing.r.session.r", "")
+ expect_equal(appNameValue, "sparkSession test")
+ expect_equal(testValue, value)
+})
+
+test_that("enableHiveSupport on SparkSession", {
+ setHiveContext(sc)
+ unsetHiveContext()
+ # if we are still here, it must be built with hive
+ conf <- callJMethod(sparkSession, "conf")
+ value <- callJMethod(conf, "get", "spark.sql.catalogImplementation", "")
+ expect_equal(value, "hive")
+})
+
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R
index c2c724cdc7..daf5e41abe 100644
--- a/R/pkg/inst/tests/testthat/test_take.R
+++ b/R/pkg/inst/tests/testthat/test_take.R
@@ -30,10 +30,11 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ",
"raising me. But they're both dead now. I didn't kill them. Honest.")
# JavaSparkContext handle
-jsc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
test_that("take() gives back the original elements in correct count and order", {
- numVectorRDD <- parallelize(jsc, numVector, 10)
+ numVectorRDD <- parallelize(sc, numVector, 10)
# case: number of elements to take is less than the size of the first partition
expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1)))
# case: number of elements to take is the same as the size of the first partition
@@ -42,20 +43,20 @@ test_that("take() gives back the original elements in correct count and order",
expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector))
expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector))
- numListRDD <- parallelize(jsc, numList, 1)
- numListRDD2 <- parallelize(jsc, numList, 4)
+ numListRDD <- parallelize(sc, numList, 1)
+ numListRDD2 <- parallelize(sc, numList, 4)
expect_equal(take(numListRDD, 3), take(numListRDD2, 3))
expect_equal(take(numListRDD, 5), take(numListRDD2, 5))
expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1)))
expect_equal(take(numListRDD2, 999), numList)
- strVectorRDD <- parallelize(jsc, strVector, 2)
- strVectorRDD2 <- parallelize(jsc, strVector, 3)
+ strVectorRDD <- parallelize(sc, strVector, 2)
+ strVectorRDD2 <- parallelize(sc, strVector, 3)
expect_equal(take(strVectorRDD, 4), as.list(strVector))
expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2)))
- strListRDD <- parallelize(jsc, strList, 4)
- strListRDD2 <- parallelize(jsc, strList, 1)
+ strListRDD <- parallelize(sc, strList, 4)
+ strListRDD2 <- parallelize(sc, strList, 1)
expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3)))
expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1)))
diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R
index e64ef1bb31..7b2cc74753 100644
--- a/R/pkg/inst/tests/testthat/test_textFile.R
+++ b/R/pkg/inst/tests/testthat/test_textFile.R
@@ -18,7 +18,8 @@
context("the textFile() function")
# JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
mockFile <- c("Spark is pretty.", "Spark is awesome.")
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index 54d2eca50e..21a119a06b 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -18,7 +18,8 @@
context("functions in utils.R")
# JavaSparkContext handle
-sc <- sparkR.init()
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
test_that("convertJListToRList() gives back (deserializes) the original JLists
of strings and integers", {
@@ -168,3 +169,16 @@ test_that("convertToJSaveMode", {
test_that("hashCode", {
expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA)
})
+
+test_that("overrideEnvs", {
+ config <- new.env()
+ config[["spark.master"]] <- "foo"
+ config[["config_only"]] <- "ok"
+ param <- new.env()
+ param[["spark.master"]] <- "local"
+ param[["param_only"]] <- "blah"
+ overrideEnvs(config, param)
+ expect_equal(config[["spark.master"]], "local")
+ expect_equal(config[["param_only"]], "blah")
+ expect_equal(config[["config_only"]], "ok")
+})