Diffstat (limited to 'R/pkg/inst/tests/testthat/test_sparkSQL.R')
-rw-r--r--  R/pkg/inst/tests/testthat/test_sparkSQL.R  86
1 file changed, 69 insertions(+), 17 deletions(-)
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 607bd9c12f..fcc2ab3ed6 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -33,26 +33,35 @@ markUtf8 <- function(s) {
}
setHiveContext <- function(sc) {
- ssc <- callJMethod(sc, "sc")
- hiveCtx <- tryCatch({
- newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
- },
- error = function(err) {
- skip("Hive is not build with SparkSQL, skipped")
- })
- assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv)
- hiveCtx
+ if (exists(".testHiveSession", envir = .sparkREnv)) {
+ hiveSession <- get(".testHiveSession", envir = .sparkREnv)
+ } else {
+ # initialize once and reuse
+ ssc <- callJMethod(sc, "sc")
+ hiveCtx <- tryCatch({
+ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
+ },
+ error = function(err) {
+ skip("Hive is not build with SparkSQL, skipped")
+ })
+ hiveSession <- callJMethod(hiveCtx, "sparkSession")
+ }
+ previousSession <- get(".sparkRsession", envir = .sparkREnv)
+ assign(".sparkRsession", hiveSession, envir = .sparkREnv)
+ assign(".prevSparkRsession", previousSession, envir = .sparkREnv)
+ hiveSession
}
unsetHiveContext <- function() {
- remove(".sparkRHivesc", envir = .sparkREnv)
+ previousSession <- get(".prevSparkRsession", envir = .sparkREnv)
+ assign(".sparkRsession", previousSession, envir = .sparkREnv)
+ remove(".prevSparkRsession", envir = .sparkREnv)
}
# Tests for SparkSQL functions in SparkR
-sc <- sparkR.init()
-
-sqlContext <- sparkRSQL.init(sc)
+sparkSession <- sparkR.session()
+sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
mockLines <- c("{\"name\":\"Michael\"}",
"{\"name\":\"Andy\", \"age\":30}",
@@ -79,7 +88,16 @@ complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
writeLines(mockLinesComplexType, complexTypeJsonPath)
test_that("calling sparkRSQL.init returns existing SQL context", {
- expect_equal(sparkRSQL.init(sc), sqlContext)
+ sqlContext <- suppressWarnings(sparkRSQL.init(sc))
+ expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext)
+})
+
+test_that("calling sparkRSQL.init returns existing SparkSession", {
+ expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession)
+})
+
+test_that("calling sparkR.session returns existing SparkSession", {
+ expect_equal(sparkR.session(), sparkSession)
})
test_that("infer types and check types", {
@@ -431,6 +449,7 @@ test_that("read/write json files", {
})
test_that("jsonRDD() on a RDD with json string", {
+ sqlContext <- suppressWarnings(sparkRSQL.init(sc))
rdd <- parallelize(sc, mockLines)
expect_equal(count(rdd), 3)
df <- suppressWarnings(jsonRDD(sqlContext, rdd))
@@ -2228,7 +2247,6 @@ test_that("gapply() on a DataFrame", {
})
test_that("Window functions on a DataFrame", {
- setHiveContext(sc)
df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")),
schema = c("key", "value"))
ws <- orderBy(window.partitionBy("key"), "value")
@@ -2253,10 +2271,10 @@ test_that("Window functions on a DataFrame", {
result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws)))
names(result) <- c("key", "value")
expect_equal(result, expected)
- unsetHiveContext()
})
test_that("createDataFrame sqlContext parameter backward compatibility", {
+ sqlContext <- suppressWarnings(sparkRSQL.init(sc))
a <- 1:3
b <- c("a", "b", "c")
ldf <- data.frame(a, b)
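The backward-compatibility test above exercises the old calling convention in which the SQL context is passed as the first argument to createDataFrame(). A hedged sketch of the two styles, using a throwaway local data.frame:

    ldf <- data.frame(a = 1:3, b = c("a", "b", "c"))

    # deprecated style: explicit sqlContext as the first argument (warns)
    df1 <- suppressWarnings(createDataFrame(sqlContext, ldf))

    # current style: the active session is picked up implicitly
    df2 <- createDataFrame(ldf)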
@@ -2283,7 +2301,6 @@ test_that("createDataFrame sqlContext parameter backward compatibility", {
test_that("randomSplit", {
num <- 4000
df <- createDataFrame(data.frame(id = 1:num))
-
weights <- c(2, 3, 5)
df_list <- randomSplit(df, weights)
expect_equal(length(weights), length(df_list))
@@ -2298,6 +2315,41 @@ test_that("randomSplit", {
expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 })))
})
+test_that("Change config on SparkSession", {
+ # first, set it to a random but known value
+ conf <- callJMethod(sparkSession, "conf")
+ property <- paste0("spark.testing.", as.character(runif(1)))
+ value1 <- as.character(runif(1))
+ callJMethod(conf, "set", property, value1)
+
+ # next, change the same property to the new value
+ value2 <- as.character(runif(1))
+ l <- list(value2)
+ names(l) <- property
+ sparkR.session(sparkConfig = l)
+
+ conf <- callJMethod(sparkSession, "conf")
+ newValue <- callJMethod(conf, "get", property, "")
+ expect_equal(value2, newValue)
+
+ value <- as.character(runif(1))
+ sparkR.session(spark.app.name = "sparkSession test", spark.testing.r.session.r = value)
+ conf <- callJMethod(sparkSession, "conf")
+ appNameValue <- callJMethod(conf, "get", "spark.app.name", "")
+ testValue <- callJMethod(conf, "get", "spark.testing.r.session.r", "")
+ expect_equal(appNameValue, "sparkSession test")
+ expect_equal(testValue, value)
+})
+
+test_that("enableHiveSupport on SparkSession", {
+ setHiveContext(sc)
+ unsetHiveContext()
+ # if we are still here, it must be built with hive
+ conf <- callJMethod(sparkSession, "conf")
+ value <- callJMethod(conf, "get", "spark.sql.catalogImplementation", "")
+ expect_equal(value, "hive")
+})
+
unlink(parquetPath)
unlink(jsonPath)
unlink(jsonPathNa)
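The "Change config on SparkSession" test added above demonstrates that calling sparkR.session() against a live session does not rebuild it: the existing session is returned, and any sparkConfig entries or named arguments are applied to its runtime conf. A minimal usage sketch (property names here are arbitrary):

    # update config on the existing session via sparkConfig
    sparkR.session(sparkConfig = list(spark.app.name = "sparkSession test"))

    # or via named arguments, which are folded into the config the same way
    sparkR.session(spark.testing.some.property = "someValue")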