author      felixcheung <felixcheung_m@hotmail.com>            2016-04-19 15:59:47 -0700
committer   Shivaram Venkataraman <shivaram@cs.berkeley.edu>   2016-04-19 15:59:47 -0700
commit      ecd877e8335ff6bb06c96d3045ccade80676e714 (patch)
tree        95af2af9dc9d84807f1f8b4386fa91b681c5b2d0
parent      008a8bbef0d3475610c13fff778a425900912650 (diff)
download    spark-ecd877e8335ff6bb06c96d3045ccade80676e714.tar.gz
            spark-ecd877e8335ff6bb06c96d3045ccade80676e714.tar.bz2
            spark-ecd877e8335ff6bb06c96d3045ccade80676e714.zip
[SPARK-12224][SPARKR] R support for JDBC source
Add an R API for `read.jdbc` and `write.jdbc`.

Tested this quite a bit manually with different combinations of parameters. It's not clear if we could have automated tests in R for this, since the Scala `JDBCSuite` depends on the Java H2 in-memory database. Refactored some code into utils so it could be tested.

Core's R SerDe code needs to be updated to allow access to java.util.Properties as a `jobj` handle, which is required by DataFrameReader/Writer's `jdbc` method. It would be possible, though it would take more code, to add a `sql/r/SQLUtils` helper function instead.

Tested:
```
# with postgresql
../bin/sparkR --driver-class-path /usr/share/java/postgresql-9.4.1207.jre7.jar

# read.jdbc
df <- read.jdbc(sqlContext, "jdbc:postgresql://localhost/db", "films2", user = "user", password = "12345")
df <- read.jdbc(sqlContext, "jdbc:postgresql://localhost/db", "films2", user = "user", password = 12345)

# partitionColumn and numPartitions test
df <- read.jdbc(sqlContext, "jdbc:postgresql://localhost/db", "films2", partitionColumn = "did",
                lowerBound = 0, upperBound = 200, numPartitions = 4, user = "user", password = 12345)
a <- SparkR:::toRDD(df)
SparkR:::getNumPartitions(a)
[1] 4
SparkR:::collectPartition(a, 2L)

# defaultParallelism test
df <- read.jdbc(sqlContext, "jdbc:postgresql://localhost/db", "films2", partitionColumn = "did",
                lowerBound = 0, upperBound = 200, user = "user", password = 12345)
SparkR:::getNumPartitions(a)
[1] 2

# predicates test
df <- read.jdbc(sqlContext, "jdbc:postgresql://localhost/db", "films2", predicates = list("did<=105"),
                user = "user", password = 12345)
count(df) == 1

# write.jdbc, default save mode "error"
irisDf <- as.DataFrame(sqlContext, iris)
write.jdbc(irisDf, "jdbc:postgresql://localhost/db", "films2", user = "user", password = "12345")
"error, already exists"
write.jdbc(irisDf, "jdbc:postgresql://localhost/db", "iris", user = "user", password = "12345")
```

Author: felixcheung <felixcheung_m@hotmail.com>

Closes #10480 from felixcheung/rreadjdbc.
-rw-r--r--  R/pkg/NAMESPACE                                            2
-rw-r--r--  R/pkg/R/DataFrame.R                                       39
-rw-r--r--  R/pkg/R/SQLContext.R                                      58
-rw-r--r--  R/pkg/R/generics.R                                         6
-rw-r--r--  R/pkg/R/utils.R                                           11
-rw-r--r--  R/pkg/inst/tests/testthat/test_utils.R                    24
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SerDe.scala     7
7 files changed, 146 insertions(+), 1 deletion(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 94ac7e7df7..10b9d16279 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -101,6 +101,7 @@ exportMethods("arrange",
"withColumn",
"withColumnRenamed",
"write.df",
+ "write.jdbc",
"write.json",
"write.parquet",
"write.text")
@@ -284,6 +285,7 @@ export("as.DataFrame",
"loadDF",
"parquetFile",
"read.df",
+ "read.jdbc",
"read.json",
"read.parquet",
"read.text",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index a64a013b65..ddb056fa71 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2363,7 +2363,7 @@ setMethod("with",
#' @examples \dontrun{
#' # Create a DataFrame from the Iris dataset
#' irisDF <- createDataFrame(sqlContext, iris)
-#'
+#'
#' # Show the structure of the DataFrame
#' str(irisDF)
#' }
@@ -2468,3 +2468,40 @@ setMethod("drop",
function(x) {
base::drop(x)
})
+
+#' Saves the content of the DataFrame to an external database table via JDBC
+#'
+#' Additional JDBC database connection properties can be set (...)
+#'
+#' Also, mode is used to specify the behavior of the save operation when
+#' data already exists in the data source. There are four modes: \cr
+#' append: Contents of this DataFrame are expected to be appended to existing data. \cr
+#' overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. \cr
+#' error: An exception is expected to be thrown. \cr
+#' ignore: The save operation is expected to not save the contents of the DataFrame
+#' and to not change the existing data. \cr
+#'
+#' @param x A SparkSQL DataFrame
+#' @param url JDBC database url of the form `jdbc:subprotocol:subname`
+#' @param tableName The name of the table in the external database
+#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default)
+#' @family DataFrame functions
+#' @rdname write.jdbc
+#' @name write.jdbc
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' jdbcUrl <- "jdbc:mysql://localhost:3306/databasename"
+#' write.jdbc(df, jdbcUrl, "table", user = "username", password = "password")
+#' }
+setMethod("write.jdbc",
+ signature(x = "DataFrame", url = "character", tableName = "character"),
+ function(x, url, tableName, mode = "error", ...){
+ jmode <- convertToJSaveMode(mode)
+ jprops <- varargsToJProperties(...)
+ write <- callJMethod(x@sdf, "write")
+ write <- callJMethod(write, "mode", jmode)
+ invisible(callJMethod(write, "jdbc", url, tableName, jprops))
+ })
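A minimal usage sketch (not part of the patch) of the new `write.jdbc` method. The URL, table name, and credentials are placeholders, and it assumes a sparkR shell started with the matching JDBC driver on the driver classpath:

```r
irisDf <- as.DataFrame(sqlContext, iris)

# Default save mode is "error": the call fails if the target table already exists.
write.jdbc(irisDf, "jdbc:postgresql://localhost/db", "iris_copy",
           user = "user", password = "12345")

# "append" adds rows to an existing table; any extra named arguments (here a
# hypothetical "ssl" flag) are forwarded to the driver as JDBC connection properties.
write.jdbc(irisDf, "jdbc:postgresql://localhost/db", "iris_copy",
           mode = "append", user = "user", password = "12345", ssl = "false")
```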
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 16a2578678..b726c1e1b9 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -583,3 +583,61 @@ createExternalTable <- function(sqlContext, tableName, path = NULL, source = NUL
sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options)
dataFrame(sdf)
}
+
+#' Create a DataFrame representing the database table accessible via JDBC URL
+#'
+#' Additional JDBC database connection properties can be set (...)
+#'
+#' Only one of partitionColumn or predicates should be set. Partitions of the table will be
+#' retrieved in parallel based on the `numPartitions` or by the predicates.
+#'
+#' Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash
+#' your external database systems.
+#'
+#' @param sqlContext SQLContext to use
+#' @param url JDBC database url of the form `jdbc:subprotocol:subname`
+#' @param tableName the name of the table in the external database
+#' @param partitionColumn the name of a column of integral type that will be used for partitioning
+#' @param lowerBound the minimum value of `partitionColumn` used to decide partition stride
+#' @param upperBound the maximum value of `partitionColumn` used to decide partition stride
+#' @param numPartitions the number of partitions. This, along with `lowerBound` (inclusive),
+#' `upperBound` (exclusive), form partition strides for generated WHERE
+#' clause expressions used to split the column `partitionColumn` evenly.
+#' This defaults to SparkContext.defaultParallelism when unset.
+#' @param predicates a list of conditions in the where clause; each one defines one partition
+#' @return DataFrame
+#' @rdname read.jdbc
+#' @name read.jdbc
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' jdbcUrl <- "jdbc:mysql://localhost:3306/databasename"
+#' df <- read.jdbc(sqlContext, jdbcUrl, "table", predicates = list("field<=123"), user = "username")
+#' df2 <- read.jdbc(sqlContext, jdbcUrl, "table2", partitionColumn = "index", lowerBound = 0,
+#' upperBound = 10000, user = "username", password = "password")
+#' }
+
+read.jdbc <- function(sqlContext, url, tableName,
+ partitionColumn = NULL, lowerBound = NULL, upperBound = NULL,
+ numPartitions = 0L, predicates = list(), ...) {
+ jprops <- varargsToJProperties(...)
+
+ read <- callJMethod(sqlContext, "read")
+ if (!is.null(partitionColumn)) {
+ if (is.null(numPartitions) || numPartitions == 0) {
+ sc <- callJMethod(sqlContext, "sparkContext")
+ numPartitions <- callJMethod(sc, "defaultParallelism")
+ } else {
+ numPartitions <- numToInt(numPartitions)
+ }
+ sdf <- callJMethod(read, "jdbc", url, tableName, as.character(partitionColumn),
+ numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops)
+ } else if (length(predicates) > 0) {
+ sdf <- callJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), jprops)
+ } else {
+ sdf <- callJMethod(read, "jdbc", url, tableName, jprops)
+ }
+ dataFrame(sdf)
+}
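For reference, a hedged sketch of the three call shapes `read.jdbc` dispatches on above (plain, partition-bounded, and predicate-based); the table and column names are illustrative only:

```r
jdbcUrl <- "jdbc:postgresql://localhost/db"

# 1. Plain read: the table is fetched through a single JDBC connection/partition.
df1 <- read.jdbc(sqlContext, jdbcUrl, "films2", user = "user", password = "12345")

# 2. Partitioned read on an integral column; if numPartitions is left at 0,
#    the function above falls back to SparkContext.defaultParallelism.
df2 <- read.jdbc(sqlContext, jdbcUrl, "films2", partitionColumn = "did",
                 lowerBound = 0, upperBound = 200, numPartitions = 4,
                 user = "user", password = "12345")

# 3. Predicate-based read: one partition per WHERE-clause condition.
df3 <- read.jdbc(sqlContext, jdbcUrl, "films2",
                 predicates = list("did <= 105", "did > 105"),
                 user = "user", password = "12345")
```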
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index ecdeea5ec4..4ef05d56bf 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -577,6 +577,12 @@ setGeneric("saveDF", function(df, path, source = NULL, mode = "error", ...) {
standardGeneric("saveDF")
})
+#' @rdname write.jdbc
+#' @export
+setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) {
+ standardGeneric("write.jdbc")
+})
+
#' @rdname write.json
#' @export
setGeneric("write.json", function(x, path) { standardGeneric("write.json") })
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index fb6575cb42..b425ccf6e7 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -650,3 +650,14 @@ convertToJSaveMode <- function(mode) {
jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode)
jmode
}
+
+varargsToJProperties <- function(...) {
+ pairs <- list(...)
+ props <- newJObject("java.util.Properties")
+ if (length(pairs) > 0) {
+ lapply(ls(pairs), function(k) {
+ callJMethod(props, "setProperty", as.character(k), as.character(pairs[[k]]))
+ })
+ }
+ props
+}
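A sketch of how the new helper behaves (mirroring the tests added below); keys and values are coerced to character before being stored as `java.util.Properties` entries:

```r
jprops <- varargsToJProperties(user = "user", password = "12345", ssl = "true")
callJMethod(jprops, "getProperty", "user")  # "user"
callJMethod(jprops, "getProperty", "ssl")   # "true"
callJMethod(jprops, "size")                 # 3
```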
diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R
index 4218138f64..01694ab5c4 100644
--- a/R/pkg/inst/tests/testthat/test_utils.R
+++ b/R/pkg/inst/tests/testthat/test_utils.R
@@ -140,3 +140,27 @@ test_that("cleanClosure on R functions", {
expect_equal(ls(env), "aBroadcast")
expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast)
})
+
+test_that("varargsToJProperties", {
+ jprops <- newJObject("java.util.Properties")
+ expect_true(class(jprops) == "jobj")
+
+ jprops <- varargsToJProperties(abc = "123")
+ expect_true(class(jprops) == "jobj")
+ expect_equal(callJMethod(jprops, "getProperty", "abc"), "123")
+
+ jprops <- varargsToJProperties(abc = "abc", b = 1)
+ expect_equal(callJMethod(jprops, "getProperty", "abc"), "abc")
+ expect_equal(callJMethod(jprops, "getProperty", "b"), "1")
+
+ jprops <- varargsToJProperties()
+ expect_equal(callJMethod(jprops, "size"), 0L)
+})
+
+test_that("convertToJSaveMode", {
+ s <- convertToJSaveMode("error")
+ expect_true(class(s) == "jobj")
+ expect_match(capture.output(print.jobj(s)), "Java ref type org.apache.spark.sql.SaveMode id ")
+ expect_error(convertToJSaveMode("foo"),
+ 'mode should be one of "append", "overwrite", "error", "ignore"') #nolint
+})
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 48df5bedd6..8e4e80a24a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -356,6 +356,13 @@ private[spark] object SerDe {
writeInt(dos, v.length)
v.foreach(elem => writeObject(dos, elem))
+ // Handle Properties
+ // This must be above the case java.util.Map below.
+ // (Properties implements Map<Object,Object> and will be serialized as map otherwise)
+ case v: java.util.Properties =>
+ writeType(dos, "jobj")
+ writeJObj(dos, value)
+
// Handle map
case v: java.util.Map[_, _] =>
writeType(dos, "map")