diff options
Diffstat (limited to 'R/pkg/inst/tests/testthat/test_sparkSQL.R')
-rw-r--r-- | R/pkg/inst/tests/testthat/test_sparkSQL.R | 86 |
1 files changed, 69 insertions, 17 deletions
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 607bd9c12f..fcc2ab3ed6 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -33,26 +33,35 @@ markUtf8 <- function(s) { } setHiveContext <- function(sc) { - ssc <- callJMethod(sc, "sc") - hiveCtx <- tryCatch({ - newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) - }, - error = function(err) { - skip("Hive is not build with SparkSQL, skipped") - }) - assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv) - hiveCtx + if (exists(".testHiveSession", envir = .sparkREnv)) { + hiveSession <- get(".testHiveSession", envir = .sparkREnv) + } else { + # initialize once and reuse + ssc <- callJMethod(sc, "sc") + hiveCtx <- tryCatch({ + newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc) + }, + error = function(err) { + skip("Hive is not build with SparkSQL, skipped") + }) + hiveSession <- callJMethod(hiveCtx, "sparkSession") + } + previousSession <- get(".sparkRsession", envir = .sparkREnv) + assign(".sparkRsession", hiveSession, envir = .sparkREnv) + assign(".prevSparkRsession", previousSession, envir = .sparkREnv) + hiveSession } unsetHiveContext <- function() { - remove(".sparkRHivesc", envir = .sparkREnv) + previousSession <- get(".prevSparkRsession", envir = .sparkREnv) + assign(".sparkRsession", previousSession, envir = .sparkREnv) + remove(".prevSparkRsession", envir = .sparkREnv) } # Tests for SparkSQL functions in SparkR -sc <- sparkR.init() - -sqlContext <- sparkRSQL.init(sc) +sparkSession <- sparkR.session() +sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", "{\"name\":\"Andy\", \"age\":30}", @@ -79,7 +88,16 @@ complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) test_that("calling sparkRSQL.init returns existing SQL context", { - expect_equal(sparkRSQL.init(sc), sqlContext) + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) +}) + +test_that("calling sparkRSQL.init returns existing SparkSession", { + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) +}) + +test_that("calling sparkR.session returns existing SparkSession", { + expect_equal(sparkR.session(), sparkSession) }) test_that("infer types and check types", { @@ -431,6 +449,7 @@ test_that("read/write json files", { }) test_that("jsonRDD() on a RDD with json string", { + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(count(rdd), 3) df <- suppressWarnings(jsonRDD(sqlContext, rdd)) @@ -2228,7 +2247,6 @@ test_that("gapply() on a DataFrame", { }) test_that("Window functions on a DataFrame", { - setHiveContext(sc) df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), schema = c("key", "value")) ws <- orderBy(window.partitionBy("key"), "value") @@ -2253,10 +2271,10 @@ test_that("Window functions on a DataFrame", { result <- collect(select(df, over(lead("key", 1), ws), over(lead("value", 1), ws))) names(result) <- c("key", "value") expect_equal(result, expected) - unsetHiveContext() }) test_that("createDataFrame sqlContext parameter backward compatibility", { + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") ldf <- data.frame(a, b) @@ -2283,7 +2301,6 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { test_that("randomSplit", { num <- 4000 df <- createDataFrame(data.frame(id = 1:num)) - weights <- c(2, 3, 5) df_list <- randomSplit(df, weights) expect_equal(length(weights), length(df_list)) @@ -2298,6 +2315,41 @@ test_that("randomSplit", { expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) }) +test_that("Change config on SparkSession", { + # first, set it to a random but known value + conf <- callJMethod(sparkSession, "conf") + property <- paste0("spark.testing.", as.character(runif(1))) + value1 <- as.character(runif(1)) + callJMethod(conf, "set", property, value1) + + # next, change the same property to the new value + value2 <- as.character(runif(1)) + l <- list(value2) + names(l) <- property + sparkR.session(sparkConfig = l) + + conf <- callJMethod(sparkSession, "conf") + newValue <- callJMethod(conf, "get", property, "") + expect_equal(value2, newValue) + + value <- as.character(runif(1)) + sparkR.session(spark.app.name = "sparkSession test", spark.testing.r.session.r = value) + conf <- callJMethod(sparkSession, "conf") + appNameValue <- callJMethod(conf, "get", "spark.app.name", "") + testValue <- callJMethod(conf, "get", "spark.testing.r.session.r", "") + expect_equal(appNameValue, "sparkSession test") + expect_equal(testValue, value) +}) + +test_that("enableHiveSupport on SparkSession", { + setHiveContext(sc) + unsetHiveContext() + # if we are still here, it must be built with hive + conf <- callJMethod(sparkSession, "conf") + value <- callJMethod(conf, "get", "spark.sql.catalogImplementation", "") + expect_equal(value, "hive") +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) |