 R/pkg/R/DataFrame.R                        | 20
 R/pkg/R/SQLContext.R                       | 19
 R/pkg/R/generics.R                         |  4
 R/pkg/R/utils.R                            | 52
 R/pkg/inst/tests/testthat/test_sparkSQL.R  | 35
 R/pkg/inst/tests/testthat/test_utils.R     | 10
 6 files changed, 127 insertions(+), 13 deletions(-)
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 40f1f0f442..75861d5de7 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2608,7 +2608,7 @@ setMethod("except",
#' @param ... additional argument(s) passed to the method.
#'
#' @family SparkDataFrame functions
-#' @aliases write.df,SparkDataFrame,character-method
+#' @aliases write.df,SparkDataFrame-method
#' @rdname write.df
#' @name write.df
#' @export
@@ -2622,21 +2622,31 @@ setMethod("except",
#' }
#' @note write.df since 1.4.0
setMethod("write.df",
- signature(df = "SparkDataFrame", path = "character"),
- function(df, path, source = NULL, mode = "error", ...) {
+ signature(df = "SparkDataFrame"),
+ function(df, path = NULL, source = NULL, mode = "error", ...) {
+ if (!is.null(path) && !is.character(path)) {
+ stop("path should be charactor, NULL or omitted.")
+ }
+ if (!is.null(source) && !is.character(source)) {
+ stop("source should be character, NULL or omitted. It is the datasource specified ",
+ "in 'spark.sql.sources.default' configuration by default.")
+ }
+ if (!is.character(mode)) {
+ stop("mode should be charactor or omitted. It is 'error' by default.")
+ }
if (is.null(source)) {
source <- getDefaultSqlSource()
}
jmode <- convertToJSaveMode(mode)
options <- varargsToEnv(...)
if (!is.null(path)) {
- options[["path"]] <- path
+ options[["path"]] <- path
}
write <- callJMethod(df@sdf, "write")
write <- callJMethod(write, "format", source)
write <- callJMethod(write, "mode", jmode)
write <- callJMethod(write, "options", options)
- write <- callJMethod(write, "save", path)
+ write <- handledCallJMethod(write, "save")
})
#' @rdname write.df
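
With this change, write.df() no longer requires a path: the R side only checks that path, source, and mode are character (or NULL/omitted), and the actual save is delegated to DataFrameWriter.save() in the JVM through handledCallJMethod(), so JVM exceptions come back as readable R errors. A minimal sketch of the resulting behaviour, assuming an active SparkR session; the sample data frame and the "/tmp/people" output directory are only illustrative:

    df <- createDataFrame(data.frame(name = c("a", "b"), age = c(1L, 2L)))

    # An explicit path still works as before.
    write.df(df, path = "/tmp/people", source = "parquet", mode = "overwrite")

    # path may now be omitted; a source that requires one fails in the JVM and is
    # reported as, e.g., "Error in save : illegal argument - 'path' is not specified".
    write.df(df, source = "csv")

    # Non-character arguments are rejected on the R side before reaching the JVM.
    write.df(df, path = c(3))   # "path should be character, NULL or omitted."
    write.df(df, mode = TRUE)   # "mode should be character or omitted. ..."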
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index ce531c3f88..baa87824be 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -771,6 +771,13 @@ dropTempView <- function(viewName) {
#' @method read.df default
#' @note read.df since 1.4.0
read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) {
+ if (!is.null(path) && !is.character(path)) {
+ stop("path should be charactor, NULL or omitted.")
+ }
+ if (!is.null(source) && !is.character(source)) {
+ stop("source should be character, NULL or omitted. It is the datasource specified ",
+ "in 'spark.sql.sources.default' configuration by default.")
+ }
sparkSession <- getSparkSession()
options <- varargsToEnv(...)
if (!is.null(path)) {
@@ -784,16 +791,16 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string
}
if (!is.null(schema)) {
stopifnot(class(schema) == "structType")
- sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source,
- schema$jobj, options)
+ sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession,
+ source, schema$jobj, options)
} else {
- sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
- "loadDF", sparkSession, source, options)
+ sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession,
+ source, options)
}
dataFrame(sdf)
}
-read.df <- function(x, ...) {
+read.df <- function(x = NULL, ...) {
dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...)
}
@@ -805,7 +812,7 @@ loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) {
read.df(path, source, schema, ...)
}
-loadDF <- function(x, ...) {
+loadDF <- function(x = NULL, ...) {
dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...)
}
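
read.df() gets the same treatment: path becomes optional, argument types are checked in R, and loadDF is invoked through handledCallJStatic() so that JVM-side analysis errors surface as readable R errors. A small sketch, assuming an active SparkR session; `jsonPath` stands for any existing JSON file, as in the tests further below:

    people <- read.df(jsonPath, "json")   # unchanged usage with an explicit path

    # Omitting path now reaches DataFrameReader.load() in the JVM; the resulting
    # AnalysisException is reported as, e.g.,
    # "Error in loadDF : analysis error - Unable to infer schema for JSON at ...".
    read.df(source = "json")

    # Invalid argument types are caught on the R side.
    read.df(path = c(3))                  # "path should be character, NULL or omitted."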
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 67a999da9b..90a02e2778 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -633,7 +633,7 @@ setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") })
#' @rdname write.df
#' @export
-setGeneric("write.df", function(df, path, source = NULL, mode = "error", ...) {
+setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) {
standardGeneric("write.df")
})
@@ -732,7 +732,7 @@ setGeneric("withColumnRenamed",
#' @rdname write.df
#' @export
-setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
+setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") })
#' @rdname randomSplit
#' @export
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 248c57532b..e696664534 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -698,6 +698,58 @@ isSparkRShell <- function() {
grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE)
}
+# Works identically to `callJStatic(...)`, but throws a nicely formatted exception.
+handledCallJStatic <- function(cls, method, ...) {
+ result <- tryCatch(callJStatic(cls, method, ...),
+ error = function(e) {
+ captureJVMException(e, method)
+ })
+ result
+}
+
+# Works identically to `callJMethod(...)`, but throws a nicely formatted exception.
+handledCallJMethod <- function(obj, method, ...) {
+ result <- tryCatch(callJMethod(obj, method, ...),
+ error = function(e) {
+ captureJVMException(e, method)
+ })
+ result
+}
+
+captureJVMException <- function(e, method) {
+ rawmsg <- as.character(e)
+ if (any(grep("^Error in .*?: ", rawmsg))) {
+ # If the exception message starts with "Error in ...", this is possibly
+ # "Error in invokeJava(...)". Here, it replaces the characters to
+ # `paste("Error in", method, ":")` in order to identify which function
+ # was called in JVM side.
+ stacktrace <- strsplit(rawmsg, "Error in .*?: ")[[1]]
+ rmsg <- paste("Error in", method, ":")
+ stacktrace <- paste(rmsg[1], stacktrace[2])
+ } else {
+ # Otherwise, do not convert the error message just in case.
+ stacktrace <- rawmsg
+ }
+
+ if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) {
+ msg <- strsplit(stacktrace, "java.lang.IllegalArgumentException: ", fixed = TRUE)[[1]]
+ # Extract "Error in ..." message.
+ rmsg <- msg[1]
+ # Extract the first line of the JVM exception message.
+ first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
+ stop(paste0(rmsg, "illegal argument - ", first), call. = FALSE)
+ } else if (any(grep("org.apache.spark.sql.AnalysisException: ", stacktrace))) {
+ msg <- strsplit(stacktrace, "org.apache.spark.sql.AnalysisException: ", fixed = TRUE)[[1]]
+ # Extract "Error in ..." message.
+ rmsg <- msg[1]
+ # Extract the first line of the JVM exception message.
+ first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
+ stop(paste0(rmsg, "analysis error - ", first), call. = FALSE)
+ } else {
+ stop(stacktrace, call. = FALSE)
+ }
+}
+
# rbind a list of rows with raw (binary) columns
#
# @param inputData a list of rows, with each row a list
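
The two wrappers above behave like callJStatic()/callJMethod(), except that any error raised by the underlying JVM call is routed through captureJVMException(), which rewrites it as "Error in <method> : illegal argument - ..." or "Error in <method> : analysis error - ..." depending on the exception type. A minimal sketch of the effect, assuming a running SparkR backend; the method and argument mirror the test_utils.R case below:

    result <- tryCatch(
      handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "getSQLDataType", "unknown"),
      error = function(e) conditionMessage(e))
    # `result` contains "Error in getSQLDataType : illegal argument - Invalid type unknown"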
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 9d874a0988..f5ab601f27 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2544,6 +2544,41 @@ test_that("Spark version from SparkSession", {
expect_equal(ver, version)
})
+test_that("Call DataFrameWriter.save() API in Java without path and check argument types", {
+ df <- read.df(jsonPath, "json")
+ # This tests that the exception is thrown from the JVM, not from the SparkR side.
+ # It makes sure that the path argument can be omitted in the write.df API, in which
+ # case DataFrameWriter.save() is called without a path.
+ expect_error(write.df(df, source = "csv"),
+ "Error in save : illegal argument - 'path' is not specified")
+
+ # Argument checking on the R side.
+ expect_error(write.df(df, "data.tmp", source = c(1, 2)),
+ paste("source should be character, NULL or omitted. It is the datasource specified",
+ "in 'spark.sql.sources.default' configuration by default."))
+ expect_error(write.df(df, path = c(3)),
+ "path should be charactor, NULL or omitted.")
+ expect_error(write.df(df, mode = TRUE),
+ "mode should be charactor or omitted. It is 'error' by default.")
+})
+
+test_that("Call DataFrameWriter.load() API in Java without path and check argument types", {
+ # This tests if the exception is thrown from JVM not from SparkR side.
+ # It makes sure that we can omit path argument in read.df API and then it calls
+ # DataFrameWriter.load() without path.
+ expect_error(read.df(source = "json"),
+ paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .",
+ "It must be specified manually"))
+ expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist")
+
+ # Argument checking on the R side.
+ expect_error(read.df(path = c(3)),
+ "path should be charactor, NULL or omitted.")
+ expect_error(read.df(jsonPath, source = c(1, 2)),
+ paste("source should be character, NULL or omitted. It is the datasource specified",
+ "in 'spark.sql.sources.default' configuration by default."))
+})
+
unlink(parquetPath)
unlink(orcPath)
unlink(jsonPath)
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index 77f25292f3..69ed554916 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -166,6 +166,16 @@ test_that("convertToJSaveMode", {
'mode should be one of "append", "overwrite", "error", "ignore"') #nolint
})
+test_that("captureJVMException", {
+ method <- "getSQLDataType"
+ expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method,
+ "unknown"),
+ error = function(e) {
+ captureJVMException(e, method)
+ }),
+ "Error in getSQLDataType : illegal argument - Invalid type unknown")
+})
+
test_that("hashCode", {
expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA)
})