From ecd877e8335ff6bb06c96d3045ccade80676e714 Mon Sep 17 00:00:00 2001
From: felixcheung
Date: Tue, 19 Apr 2016 15:59:47 -0700
Subject: [SPARK-12224][SPARKR] R support for JDBC source

Add R API for `read.jdbc` and `write.jdbc`.

Tested this quite a bit manually with different combinations of parameters. It's not clear
whether we could have automated tests in R for this - the Scala `JDBCSuite` depends on the
Java H2 in-memory database.

Refactored some code into util so it could be tested.

Core's R SerDe code needs to be updated to allow access to a java.util.Properties object as a
`jobj` handle, which is required by the DataFrameReader/Writer `jdbc` method. It would be
possible, though it would take more code, to add a `sql/r/SQLUtils` helper function instead.

Tested:
```
# with postgresql
../bin/sparkR --driver-class-path /usr/share/java/postgresql-9.4.1207.jre7.jar

# read.jdbc
df <- read.jdbc(sqlContext, "jdbc:postgresql://localhost/db", "films2", user = "user", password = "12345")
df <- read.jdbc(sqlContext, "jdbc:postgresql://localhost/db", "films2", user = "user", password = 12345)

# partitionColumn and numPartitions test
df <- read.jdbc(sqlContext, "jdbc:postgresql://localhost/db", "films2", partitionColumn = "did", lowerBound = 0, upperBound = 200, numPartitions = 4, user = "user", password = 12345)
a <- SparkR:::toRDD(df)
SparkR:::getNumPartitions(a)
[1] 4
SparkR:::collectPartition(a, 2L)

# defaultParallelism test
df <- read.jdbc(sqlContext, "jdbc:postgresql://localhost/db", "films2", partitionColumn = "did", lowerBound = 0, upperBound = 200, user = "user", password = 12345)
SparkR:::getNumPartitions(a)
[1] 2

# predicates test
df <- read.jdbc(sqlContext, "jdbc:postgresql://localhost/db", "films2", predicates = list("did<=105"), user = "user", password = 12345)
count(df) == 1

# write.jdbc, default save mode "error"
irisDf <- as.DataFrame(sqlContext, iris)
write.jdbc(irisDf, "jdbc:postgresql://localhost/db", "films2", user = "user", password = "12345")
"error, already exists"

write.jdbc(irisDf, "jdbc:postgresql://localhost/db", "iris", user = "user", password = "12345")
```

Author: felixcheung

Closes #10480 from felixcheung/rreadjdbc.
---
 R/pkg/NAMESPACE                                   |  2 +
 R/pkg/R/DataFrame.R                               | 39 ++++++++++++++-
 R/pkg/R/SQLContext.R                              | 58 ++++++++++++++++++++++
 R/pkg/R/generics.R                                |  6 +++
 R/pkg/R/utils.R                                   | 11 ++++
 R/pkg/inst/tests/testthat/test_utils.R            | 24 +++++++++
 .../main/scala/org/apache/spark/api/r/SerDe.scala |  7 +++
 7 files changed, 146 insertions(+), 1 deletion(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 94ac7e7df7..10b9d16279 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -101,6 +101,7 @@ exportMethods("arrange",
               "withColumn",
               "withColumnRenamed",
               "write.df",
+              "write.jdbc",
               "write.json",
               "write.parquet",
               "write.text")
@@ -284,6 +285,7 @@ export("as.DataFrame",
        "loadDF",
        "parquetFile",
        "read.df",
+       "read.jdbc",
        "read.json",
        "read.parquet",
        "read.text",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index a64a013b65..ddb056fa71 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2363,7 +2363,7 @@ setMethod("with",
 #' @examples \dontrun{
 #' # Create a DataFrame from the Iris dataset
 #' irisDF <- createDataFrame(sqlContext, iris)
-#' 
+#'
 #' # Show the structure of the DataFrame
 #' str(irisDF)
 #' }
@@ -2468,3 +2468,40 @@ setMethod("drop",
           function(x) {
             base::drop(x)
           })
+
+#' Saves the content of the DataFrame to an external database table via JDBC
+#'
+#' Additional JDBC database connection properties can be set (...)
+#'
+#' Also, mode is used to specify the behavior of the save operation when
+#' data already exists in the data source. There are four modes: \cr
+#'  append: Contents of this DataFrame are expected to be appended to existing data. \cr
+#'  overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. \cr
+#'  error: An exception is expected to be thrown. \cr
+#'  ignore: The save operation is expected to not save the contents of the DataFrame
+#'     and to not change the existing data. \cr
+#'
+#' @param x A SparkSQL DataFrame
+#' @param url JDBC database url of the form `jdbc:subprotocol:subname`
+#' @param tableName The name of the table in the external database
+#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default)
+#' @family DataFrame functions
+#' @rdname write.jdbc
+#' @name write.jdbc
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' jdbcUrl <- "jdbc:mysql://localhost:3306/databasename"
+#' write.jdbc(df, jdbcUrl, "table", user = "username", password = "password")
+#' }
+setMethod("write.jdbc",
+          signature(x = "DataFrame", url = "character", tableName = "character"),
+          function(x, url, tableName, mode = "error", ...) {
+            jmode <- convertToJSaveMode(mode)
+            jprops <- varargsToJProperties(...)
+            write <- callJMethod(x@sdf, "write")
+            write <- callJMethod(write, "mode", jmode)
+            invisible(callJMethod(write, "jdbc", url, tableName, jprops))
+          })
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 16a2578678..b726c1e1b9 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -583,3 +583,61 @@ createExternalTable <- function(sqlContext, tableName, path = NULL, source = NUL
   sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options)
   dataFrame(sdf)
 }
+
+#' Create a DataFrame representing the database table accessible via JDBC URL
+#'
+#' Additional JDBC database connection properties can be set (...)
+#'
+#' Only one of partitionColumn or predicates should be set. Partitions of the table will be
+#' retrieved in parallel based on the `numPartitions` or by the predicates.
+#'
+#' Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
+#' your external database systems.
+#'
+#' @param sqlContext SQLContext to use
+#' @param url JDBC database url of the form `jdbc:subprotocol:subname`
+#' @param tableName the name of the table in the external database
+#' @param partitionColumn the name of a column of integral type that will be used for partitioning
+#' @param lowerBound the minimum value of `partitionColumn` used to decide partition stride
+#' @param upperBound the maximum value of `partitionColumn` used to decide partition stride
+#' @param numPartitions the number of partitions. This, along with `lowerBound` (inclusive) and
+#'                      `upperBound` (exclusive), forms partition strides for generated WHERE
+#'                      clause expressions used to split the column `partitionColumn` evenly.
+#'                      This defaults to SparkContext.defaultParallelism when unset.
+#' @param predicates a list of conditions in the where clause; each one defines one partition
+#' @return DataFrame
+#' @rdname read.jdbc
+#' @name read.jdbc
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' jdbcUrl <- "jdbc:mysql://localhost:3306/databasename"
+#' df <- read.jdbc(sqlContext, jdbcUrl, "table", predicates = list("field<=123"), user = "username")
+#' df2 <- read.jdbc(sqlContext, jdbcUrl, "table2", partitionColumn = "index", lowerBound = 0,
+#'                  upperBound = 10000, user = "username", password = "password")
+#' }
+
+read.jdbc <- function(sqlContext, url, tableName,
+                      partitionColumn = NULL, lowerBound = NULL, upperBound = NULL,
+                      numPartitions = 0L, predicates = list(), ...) {
+  jprops <- varargsToJProperties(...)
+
+  read <- callJMethod(sqlContext, "read")
+  if (!is.null(partitionColumn)) {
+    if (is.null(numPartitions) || numPartitions == 0) {
+      sc <- callJMethod(sqlContext, "sparkContext")
+      numPartitions <- callJMethod(sc, "defaultParallelism")
+    } else {
+      numPartitions <- numToInt(numPartitions)
+    }
+    sdf <- callJMethod(read, "jdbc", url, tableName, as.character(partitionColumn),
+                       numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops)
+  } else if (length(predicates) > 0) {
+    sdf <- callJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), jprops)
+  } else {
+    sdf <- callJMethod(read, "jdbc", url, tableName, jprops)
+  }
+  dataFrame(sdf)
+}
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index ecdeea5ec4..4ef05d56bf 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -577,6 +577,12 @@ setGeneric("saveDF", function(df, path, source = NULL, mode = "error", ...) {
   standardGeneric("saveDF")
 })
 
+#' @rdname write.jdbc
+#' @export
+setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) {
+  standardGeneric("write.jdbc")
+})
+
 #' @rdname write.json
 #' @export
 setGeneric("write.json", function(x, path) { standardGeneric("write.json") })
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index fb6575cb42..b425ccf6e7 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -650,3 +650,14 @@ convertToJSaveMode <- function(mode) {
   jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode)
   jmode
 }
+
+varargsToJProperties <- function(...) {
+  pairs <- list(...)
+  props <- newJObject("java.util.Properties")
+  if (length(pairs) > 0) {
+    lapply(ls(pairs), function(k) {
+      callJMethod(props, "setProperty", as.character(k), as.character(pairs[[k]]))
+    })
+  }
+  props
+}
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index 4218138f64..01694ab5c4 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -140,3 +140,27 @@ test_that("cleanClosure on R functions", {
   expect_equal(ls(env), "aBroadcast")
   expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast)
 })
+
+test_that("varargsToJProperties", {
+  jprops <- newJObject("java.util.Properties")
+  expect_true(class(jprops) == "jobj")
+
+  jprops <- varargsToJProperties(abc = "123")
+  expect_true(class(jprops) == "jobj")
+  expect_equal(callJMethod(jprops, "getProperty", "abc"), "123")
+
+  jprops <- varargsToJProperties(abc = "abc", b = 1)
+  expect_equal(callJMethod(jprops, "getProperty", "abc"), "abc")
+  expect_equal(callJMethod(jprops, "getProperty", "b"), "1")
+
+  jprops <- varargsToJProperties()
+  expect_equal(callJMethod(jprops, "size"), 0L)
+})
+
+test_that("convertToJSaveMode", {
+  s <- convertToJSaveMode("error")
+  expect_true(class(s) == "jobj")
+  expect_match(capture.output(print.jobj(s)), "Java ref type org.apache.spark.sql.SaveMode id ")
+  expect_error(convertToJSaveMode("foo"),
+               'mode should be one of "append", "overwrite", "error", "ignore"') #nolint
+})
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 48df5bedd6..8e4e80a24a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -356,6 +356,13 @@ private[spark] object SerDe {
           writeInt(dos, v.length)
           v.foreach(elem => writeObject(dos, elem))
 
+        // Handle Properties
+        // This must be above the case java.util.Map below.
+        // (Properties implements Map and will be serialized as map otherwise)
+        case v: java.util.Properties =>
+          writeType(dos, "jobj")
+          writeJObj(dos, value)
+
         // Handle map
         case v: java.util.Map[_, _] =>
           writeType(dos, "map")
-- 
cgit v1.2.3
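
Reviewer note, not part of the patch: the sketch below pulls the manual test steps from the commit message into one runnable SparkR session against the new API. It assumes a build containing this change, a reachable PostgreSQL server at `jdbc:postgresql://localhost/db` with a `films2` table, and the PostgreSQL JDBC driver jar on the driver classpath; the URL, table names, and credentials are placeholders, not values defined by the patch.

```
# Launch with the JDBC driver on the driver classpath, e.g.:
#   ../bin/sparkR --driver-class-path /usr/share/java/postgresql-9.4.1207.jre7.jar
library(SparkR)

sc <- sparkR.init()
sqlContext <- sparkRSQL.init(sc)
jdbcUrl <- "jdbc:postgresql://localhost/db"    # placeholder connection URL

# Plain read: extra named arguments become JDBC connection properties
df <- read.jdbc(sqlContext, jdbcUrl, "films2", user = "user", password = "12345")

# Column-based partitioning: 4 partitions over did in [0, 200)
df4 <- read.jdbc(sqlContext, jdbcUrl, "films2",
                 partitionColumn = "did", lowerBound = 0, upperBound = 200,
                 numPartitions = 4, user = "user", password = "12345")

# Predicate-based partitioning: one partition per WHERE condition
dfp <- read.jdbc(sqlContext, jdbcUrl, "films2",
                 predicates = list("did<=105"), user = "user", password = "12345")

# Write back out; mode defaults to "error", so pick "overwrite" or "append"
# if the target table (a placeholder name here) may already exist
write.jdbc(df, jdbcUrl, "films2_copy", mode = "overwrite",
           user = "user", password = "12345")

# Internally, the named arguments are collected into a java.util.Properties
# object by the new (non-exported) helper:
props <- SparkR:::varargsToJProperties(user = "user", password = "12345")
SparkR:::callJMethod(props, "getProperty", "user")    # returns "user"

sparkR.stop()
```

The connection properties forwarded through `...` become a `java.util.Properties` object via `varargsToJProperties()`, which is why the SerDe change above is needed to hand that object back to R as a `jobj` handle.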