aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorhyukjinkwon <gurwls223@gmail.com>2016-10-04 22:58:43 -0700
committerFelix Cheung <felixcheung@apache.org>2016-10-04 22:58:43 -0700
commitc9fe10d4ed8df5ac4bd0f1eb8c9cd19244e27736 (patch)
treeb70dbf3d5ea108198c6451b7c4692aa589895fe0
parenta99743d053e84f695dc3034550939555297b0a05 (diff)
downloadspark-c9fe10d4ed8df5ac4bd0f1eb8c9cd19244e27736.tar.gz
spark-c9fe10d4ed8df5ac4bd0f1eb8c9cd19244e27736.tar.bz2
spark-c9fe10d4ed8df5ac4bd0f1eb8c9cd19244e27736.zip
[SPARK-17658][SPARKR] read.df/write.df API taking path optionally in SparkR
## What changes were proposed in this pull request? `write.df`/`read.df` API require path which is not actually always necessary in Spark. Currently, it only affects the datasources implementing `CreatableRelationProvider`. Currently, Spark currently does not have internal data sources implementing this but it'd affect other external datasources. In addition we'd be able to use this way in Spark's JDBC datasource after https://github.com/apache/spark/pull/12601 is merged. **Before** - `read.df` ```r > read.df(source = "json") Error in dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", : argument "x" is missing with no default ``` ```r > read.df(path = c(1, 2)) Error in dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", : argument "x" is missing with no default ``` ```r > read.df(c(1, 2)) Error in invokeJava(isStatic = TRUE, className, methodName, ...) : java.lang.ClassCastException: java.lang.Double cannot be cast to java.lang.String at org.apache.spark.sql.execution.datasources.DataSource.hasMetadata(DataSource.scala:300) at ... In if (is.na(object)) { : ... ``` - `write.df` ```r > write.df(df, source = "json") Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘write.df’ for signature ‘"function", "missing"’ ``` ```r > write.df(df, source = c(1, 2)) Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘write.df’ for signature ‘"SparkDataFrame", "missing"’ ``` ```r > write.df(df, mode = TRUE) Error in (function (classes, fdef, mtable) : unable to find an inherited method for function ‘write.df’ for signature ‘"SparkDataFrame", "missing"’ ``` **After** - `read.df` ```r > read.df(source = "json") Error in loadDF : analysis error - Unable to infer schema for JSON at . It must be specified manually; ``` ```r > read.df(path = c(1, 2)) Error in f(x, ...) : path should be charactor, null or omitted. ``` ```r > read.df(c(1, 2)) Error in f(x, ...) : path should be charactor, null or omitted. ``` - `write.df` ```r > write.df(df, source = "json") Error in save : illegal argument - 'path' is not specified ``` ```r > write.df(df, source = c(1, 2)) Error in .local(df, path, ...) : source should be charactor, null or omitted. It is 'parquet' by default. ``` ```r > write.df(df, mode = TRUE) Error in .local(df, path, ...) : mode should be charactor or omitted. It is 'error' by default. ``` ## How was this patch tested? Unit tests in `test_sparkSQL.R` Author: hyukjinkwon <gurwls223@gmail.com> Closes #15231 from HyukjinKwon/write-default-r.
-rw-r--r--R/pkg/R/DataFrame.R20
-rw-r--r--R/pkg/R/SQLContext.R19
-rw-r--r--R/pkg/R/generics.R4
-rw-r--r--R/pkg/R/utils.R52
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R35
-rw-r--r--R/pkg/inst/tests/testthat/test_utils.R10
6 files changed, 127 insertions, 13 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 40f1f0f442..75861d5de7 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2608,7 +2608,7 @@ setMethod("except",
#' @param ... additional argument(s) passed to the method.
#'
#' @family SparkDataFrame functions
-#' @aliases write.df,SparkDataFrame,character-method
+#' @aliases write.df,SparkDataFrame-method
#' @rdname write.df
#' @name write.df
#' @export
@@ -2622,21 +2622,31 @@ setMethod("except",
#' }
#' @note write.df since 1.4.0
setMethod("write.df",
- signature(df = "SparkDataFrame", path = "character"),
- function(df, path, source = NULL, mode = "error", ...) {
+ signature(df = "SparkDataFrame"),
+ function(df, path = NULL, source = NULL, mode = "error", ...) {
+ if (!is.null(path) && !is.character(path)) {
+ stop("path should be charactor, NULL or omitted.")
+ }
+ if (!is.null(source) && !is.character(source)) {
+ stop("source should be character, NULL or omitted. It is the datasource specified ",
+ "in 'spark.sql.sources.default' configuration by default.")
+ }
+ if (!is.character(mode)) {
+ stop("mode should be charactor or omitted. It is 'error' by default.")
+ }
if (is.null(source)) {
source <- getDefaultSqlSource()
}
jmode <- convertToJSaveMode(mode)
options <- varargsToEnv(...)
if (!is.null(path)) {
- options[["path"]] <- path
+ options[["path"]] <- path
}
write <- callJMethod(df@sdf, "write")
write <- callJMethod(write, "format", source)
write <- callJMethod(write, "mode", jmode)
write <- callJMethod(write, "options", options)
- write <- callJMethod(write, "save", path)
+ write <- handledCallJMethod(write, "save")
})
#' @rdname write.df
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index ce531c3f88..baa87824be 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -771,6 +771,13 @@ dropTempView <- function(viewName) {
#' @method read.df default
#' @note read.df since 1.4.0
read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) {
+ if (!is.null(path) && !is.character(path)) {
+ stop("path should be charactor, NULL or omitted.")
+ }
+ if (!is.null(source) && !is.character(source)) {
+ stop("source should be character, NULL or omitted. It is the datasource specified ",
+ "in 'spark.sql.sources.default' configuration by default.")
+ }
sparkSession <- getSparkSession()
options <- varargsToEnv(...)
if (!is.null(path)) {
@@ -784,16 +791,16 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.string
}
if (!is.null(schema)) {
stopifnot(class(schema) == "structType")
- sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source,
- schema$jobj, options)
+ sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession,
+ source, schema$jobj, options)
} else {
- sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
- "loadDF", sparkSession, source, options)
+ sdf <- handledCallJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession,
+ source, options)
}
dataFrame(sdf)
}
-read.df <- function(x, ...) {
+read.df <- function(x = NULL, ...) {
dispatchFunc("read.df(path = NULL, source = NULL, schema = NULL, ...)", x, ...)
}
@@ -805,7 +812,7 @@ loadDF.default <- function(path = NULL, source = NULL, schema = NULL, ...) {
read.df(path, source, schema, ...)
}
-loadDF <- function(x, ...) {
+loadDF <- function(x = NULL, ...) {
dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...)
}
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 67a999da9b..90a02e2778 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -633,7 +633,7 @@ setGeneric("transform", function(`_data`, ...) {standardGeneric("transform") })
#' @rdname write.df
#' @export
-setGeneric("write.df", function(df, path, source = NULL, mode = "error", ...) {
+setGeneric("write.df", function(df, path = NULL, source = NULL, mode = "error", ...) {
standardGeneric("write.df")
})
@@ -732,7 +732,7 @@ setGeneric("withColumnRenamed",
#' @rdname write.df
#' @export
-setGeneric("write.df", function(df, path, ...) { standardGeneric("write.df") })
+setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") })
#' @rdname randomSplit
#' @export
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index 248c57532b..e696664534 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -698,6 +698,58 @@ isSparkRShell <- function() {
grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE)
}
+# Works identically with `callJStatic(...)` but throws a pretty formatted exception.
+handledCallJStatic <- function(cls, method, ...) {
+ result <- tryCatch(callJStatic(cls, method, ...),
+ error = function(e) {
+ captureJVMException(e, method)
+ })
+ result
+}
+
+# Works identically with `callJMethod(...)` but throws a pretty formatted exception.
+handledCallJMethod <- function(obj, method, ...) {
+ result <- tryCatch(callJMethod(obj, method, ...),
+ error = function(e) {
+ captureJVMException(e, method)
+ })
+ result
+}
+
+captureJVMException <- function(e, method) {
+ rawmsg <- as.character(e)
+ if (any(grep("^Error in .*?: ", rawmsg))) {
+ # If the exception message starts with "Error in ...", this is possibly
+ # "Error in invokeJava(...)". Here, it replaces the characters to
+ # `paste("Error in", method, ":")` in order to identify which function
+ # was called in JVM side.
+ stacktrace <- strsplit(rawmsg, "Error in .*?: ")[[1]]
+ rmsg <- paste("Error in", method, ":")
+ stacktrace <- paste(rmsg[1], stacktrace[2])
+ } else {
+ # Otherwise, do not convert the error message just in case.
+ stacktrace <- rawmsg
+ }
+
+ if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) {
+ msg <- strsplit(stacktrace, "java.lang.IllegalArgumentException: ", fixed = TRUE)[[1]]
+ # Extract "Error in ..." message.
+ rmsg <- msg[1]
+ # Extract the first message of JVM exception.
+ first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
+ stop(paste0(rmsg, "illegal argument - ", first), call. = FALSE)
+ } else if (any(grep("org.apache.spark.sql.AnalysisException: ", stacktrace))) {
+ msg <- strsplit(stacktrace, "org.apache.spark.sql.AnalysisException: ", fixed = TRUE)[[1]]
+ # Extract "Error in ..." message.
+ rmsg <- msg[1]
+ # Extract the first message of JVM exception.
+ first <- strsplit(msg[2], "\r?\n\tat")[[1]][1]
+ stop(paste0(rmsg, "analysis error - ", first), call. = FALSE)
+ } else {
+ stop(stacktrace, call. = FALSE)
+ }
+}
+
# rbind a list of rows with raw (binary) columns
#
# @param inputData a list of rows, with each row a list
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 9d874a0988..f5ab601f27 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -2544,6 +2544,41 @@ test_that("Spark version from SparkSession", {
expect_equal(ver, version)
})
+test_that("Call DataFrameWriter.save() API in Java without path and check argument types", {
+ df <- read.df(jsonPath, "json")
+ # This tests if the exception is thrown from JVM not from SparkR side.
+ # It makes sure that we can omit path argument in write.df API and then it calls
+ # DataFrameWriter.save() without path.
+ expect_error(write.df(df, source = "csv"),
+ "Error in save : illegal argument - 'path' is not specified")
+
+ # Arguments checking in R side.
+ expect_error(write.df(df, "data.tmp", source = c(1, 2)),
+ paste("source should be character, NULL or omitted. It is the datasource specified",
+ "in 'spark.sql.sources.default' configuration by default."))
+ expect_error(write.df(df, path = c(3)),
+ "path should be charactor, NULL or omitted.")
+ expect_error(write.df(df, mode = TRUE),
+ "mode should be charactor or omitted. It is 'error' by default.")
+})
+
+test_that("Call DataFrameWriter.load() API in Java without path and check argument types", {
+ # This tests if the exception is thrown from JVM not from SparkR side.
+ # It makes sure that we can omit path argument in read.df API and then it calls
+ # DataFrameWriter.load() without path.
+ expect_error(read.df(source = "json"),
+ paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .",
+ "It must be specified manually"))
+ expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist")
+
+ # Arguments checking in R side.
+ expect_error(read.df(path = c(3)),
+ "path should be charactor, NULL or omitted.")
+ expect_error(read.df(jsonPath, source = c(1, 2)),
+ paste("source should be character, NULL or omitted. It is the datasource specified",
+ "in 'spark.sql.sources.default' configuration by default."))
+})
+
unlink(parquetPath)
unlink(orcPath)
unlink(jsonPath)
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index 77f25292f3..69ed554916 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -166,6 +166,16 @@ test_that("convertToJSaveMode", {
'mode should be one of "append", "overwrite", "error", "ignore"') #nolint
})
+test_that("captureJVMException", {
+ method <- "getSQLDataType"
+ expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method,
+ "unknown"),
+ error = function(e) {
+ captureJVMException(e, method)
+ }),
+ "Error in getSQLDataType : illegal argument - Invalid type unknown")
+})
+
test_that("hashCode", {
expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA)
})