author    Justin Pihony <justin.pihony@gmail.com>  2016-09-26 09:54:22 +0100
committer Sean Owen <sowen@cloudera.com>           2016-09-26 09:54:22 +0100
commit    50b89d05b7bffc212cc9b9ae6e0bca7cb90b9c77 (patch)
tree      004018c95e9fedc204d683c210af79ac43bd4212
parent    ac65139be96dbf87402b9a85729a93afd3c6ff17 (diff)
download  spark-50b89d05b7bffc212cc9b9ae6e0bca7cb90b9c77.tar.gz
          spark-50b89d05b7bffc212cc9b9ae6e0bca7cb90b9c77.tar.bz2
          spark-50b89d05b7bffc212cc9b9ae6e0bca7cb90b9c77.zip
[SPARK-14525][SQL] Make DataFrameWriter.save work for jdbc
## What changes were proposed in this pull request?

This change modifies the implementation of DataFrameWriter.save so that it works with jdbc, and the call to jdbc merely delegates to save.

## How was this patch tested?

This was tested via unit tests in JDBCWriteSuite, to which I added one new test to cover this scenario.

## Additional details

rxin This seems to have been most recently touched by you and was also commented on in the JIRA.

This contribution is my original work and I license the work to the project under the project's open source license.

Author: Justin Pihony <justin.pihony@gmail.com>
Author: Justin Pihony <justin.pihony@typesafe.com>

Closes #12601 from JustinPihony/jdbc_reconciliation.
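For readers who just want the end-user effect of this change, here is a minimal, self-contained sketch that mirrors the examples added below. It is only an illustration: it assumes a reachable PostgreSQL instance with its JDBC driver on the classpath, and the URL, table name, and credentials are placeholders.

```scala
import java.util.Properties

import org.apache.spark.sql.SparkSession

object JdbcSaveSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("JdbcSaveSketch").getOrCreate()
    val df = spark.range(10).toDF("id")

    // With this patch, the generic load/save path also works for the jdbc source,
    // as long as the required "url" and "dbtable" options are supplied.
    df.write
      .format("jdbc")
      .option("url", "jdbc:postgresql:dbserver")   // placeholder URL
      .option("dbtable", "schema.tablename")       // placeholder table
      .option("user", "username")
      .option("password", "password")
      .save()

    // DataFrameWriter.jdbc now simply delegates to the same save() path.
    val connectionProperties = new Properties()
    connectionProperties.put("user", "username")
    connectionProperties.put("password", "password")
    df.write
      .mode("append")   // the table was already created by the save() call above
      .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)

    spark.stop()
  }
}
```

Either call ends up in the new JdbcRelationProvider.createRelation shown further down in the diff.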
-rw-r--r--  docs/sql-programming-guide.md                                                               |  6
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java          | 21
-rw-r--r--  examples/src/main/python/sql/datasource.py                                                  | 19
-rw-r--r--  examples/src/main/r/RSparkSQLExample.R                                                      |  4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala            | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala                          | 59
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala   | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala | 95
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala                      | 82
9 files changed, 246 insertions, 73 deletions
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 4ac5fae566..71bdd19c16 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1100,9 +1100,13 @@ CREATE TEMPORARY VIEW jdbcTable
USING org.apache.spark.sql.jdbc
OPTIONS (
url "jdbc:postgresql:dbserver",
- dbtable "schema.tablename"
+ dbtable "schema.tablename",
+ user 'username',
+ password 'password'
)
+INSERT INTO TABLE jdbcTable
+SELECT * FROM resultTable
{% endhighlight %}
</div>
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java
index f9087e0593..1860594e8e 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java
@@ -22,6 +22,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
// $example off:schema_merging$
+import java.util.Properties;
// $example on:basic_parquet_example$
import org.apache.spark.api.java.JavaRDD;
@@ -235,6 +236,8 @@ public class JavaSQLDataSourceExample {
private static void runJdbcDatasetExample(SparkSession spark) {
// $example on:jdbc_dataset$
+ // Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods
+ // Loading data from a JDBC source
Dataset<Row> jdbcDF = spark.read()
.format("jdbc")
.option("url", "jdbc:postgresql:dbserver")
@@ -242,6 +245,24 @@ public class JavaSQLDataSourceExample {
.option("user", "username")
.option("password", "password")
.load();
+
+ Properties connectionProperties = new Properties();
+ connectionProperties.put("user", "username");
+ connectionProperties.put("password", "password");
+ Dataset<Row> jdbcDF2 = spark.read()
+ .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties);
+
+ // Saving data to a JDBC source
+ jdbcDF.write()
+ .format("jdbc")
+ .option("url", "jdbc:postgresql:dbserver")
+ .option("dbtable", "schema.tablename")
+ .option("user", "username")
+ .option("password", "password")
+ .save();
+
+ jdbcDF2.write()
+ .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties);
// $example off:jdbc_dataset$
}
}
diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py
index b36c901d2b..e9aa9d9ac2 100644
--- a/examples/src/main/python/sql/datasource.py
+++ b/examples/src/main/python/sql/datasource.py
@@ -143,6 +143,8 @@ def json_dataset_example(spark):
def jdbc_dataset_example(spark):
# $example on:jdbc_dataset$
+ # Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods
+ # Loading data from a JDBC source
jdbcDF = spark.read \
.format("jdbc") \
.option("url", "jdbc:postgresql:dbserver") \
@@ -150,6 +152,23 @@ def jdbc_dataset_example(spark):
.option("user", "username") \
.option("password", "password") \
.load()
+
+ jdbcDF2 = spark.read \
+ .jdbc("jdbc:postgresql:dbserver", "schema.tablename",
+ properties={"user": "username", "password": "password"})
+
+ # Saving data to a JDBC source
+ jdbcDF.write \
+ .format("jdbc") \
+ .option("url", "jdbc:postgresql:dbserver") \
+ .option("dbtable", "schema.tablename") \
+ .option("user", "username") \
+ .option("password", "password") \
+ .save()
+
+ jdbcDF2.write \
+ .jdbc("jdbc:postgresql:dbserver", "schema.tablename",
+ properties={"user": "username", "password": "password"})
# $example off:jdbc_dataset$
diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R
index 4e0267a038..373a36dba1 100644
--- a/examples/src/main/r/RSparkSQLExample.R
+++ b/examples/src/main/r/RSparkSQLExample.R
@@ -204,7 +204,11 @@ results <- collect(sql("FROM src SELECT key, value"))
# $example on:jdbc_dataset$
+# Loading data from a JDBC source
df <- read.jdbc("jdbc:postgresql:dbserver", "schema.tablename", user = "username", password = "password")
+
+# Saving data to a JDBC source
+write.jdbc(df, "jdbc:postgresql:dbserver", "schema.tablename", user = "username", password = "password")
# $example off:jdbc_dataset$
# Stop the SparkSession now
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
index dc3915a488..66f7cb1b53 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
@@ -16,6 +16,8 @@
*/
package org.apache.spark.examples.sql
+import java.util.Properties
+
import org.apache.spark.sql.SparkSession
object SQLDataSourceExample {
@@ -148,6 +150,8 @@ object SQLDataSourceExample {
private def runJdbcDatasetExample(spark: SparkSession): Unit = {
// $example on:jdbc_dataset$
+ // Note: JDBC loading and saving can be achieved via either the load/save or jdbc methods
+ // Loading data from a JDBC source
val jdbcDF = spark.read
.format("jdbc")
.option("url", "jdbc:postgresql:dbserver")
@@ -155,6 +159,24 @@ object SQLDataSourceExample {
.option("user", "username")
.option("password", "password")
.load()
+
+ val connectionProperties = new Properties()
+ connectionProperties.put("user", "username")
+ connectionProperties.put("password", "password")
+ val jdbcDF2 = spark.read
+ .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
+
+ // Saving data to a JDBC source
+ jdbcDF.write
+ .format("jdbc")
+ .option("url", "jdbc:postgresql:dbserver")
+ .option("dbtable", "schema.tablename")
+ .option("user", "username")
+ .option("password", "password")
+ .save()
+
+ jdbcDF2.write
+ .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
// $example off:jdbc_dataset$
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 64d3422cb4..7374a8e045 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -425,62 +425,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
assertNotPartitioned("jdbc")
assertNotBucketed("jdbc")
-
- // to add required options like URL and dbtable
- val params = extraOptions.toMap ++ Map("url" -> url, "dbtable" -> table)
- val jdbcOptions = new JDBCOptions(params)
- val jdbcUrl = jdbcOptions.url
- val jdbcTable = jdbcOptions.table
-
- val props = new Properties()
- extraOptions.foreach { case (key, value) =>
- props.put(key, value)
- }
// connectionProperties should override settings in extraOptions
- props.putAll(connectionProperties)
- val conn = JdbcUtils.createConnectionFactory(jdbcUrl, props)()
-
- try {
- var tableExists = JdbcUtils.tableExists(conn, jdbcUrl, jdbcTable)
-
- if (mode == SaveMode.Ignore && tableExists) {
- return
- }
-
- if (mode == SaveMode.ErrorIfExists && tableExists) {
- sys.error(s"Table $jdbcTable already exists.")
- }
-
- if (mode == SaveMode.Overwrite && tableExists) {
- if (jdbcOptions.isTruncate &&
- JdbcUtils.isCascadingTruncateTable(jdbcUrl) == Some(false)) {
- JdbcUtils.truncateTable(conn, jdbcTable)
- } else {
- JdbcUtils.dropTable(conn, jdbcTable)
- tableExists = false
- }
- }
-
- // Create the table if the table didn't exist.
- if (!tableExists) {
- val schema = JdbcUtils.schemaString(df, jdbcUrl)
- // To allow certain options to append when create a new table, which can be
- // table_options or partition_options.
- // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
- val createtblOptions = jdbcOptions.createTableOptions
- val sql = s"CREATE TABLE $jdbcTable ($schema) $createtblOptions"
- val statement = conn.createStatement
- try {
- statement.executeUpdate(sql)
- } finally {
- statement.close()
- }
- }
- } finally {
- conn.close()
- }
-
- JdbcUtils.saveTable(df, jdbcUrl, jdbcTable, props)
+ this.extraOptions = this.extraOptions ++ (connectionProperties.asScala)
+ // explicit url and dbtable should override all
+ this.extraOptions += ("url" -> url, "dbtable" -> table)
+ format("jdbc").save()
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 1db090eaf9..bcf65e53af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -27,10 +27,12 @@ class JDBCOptions(
// ------------------------------------------------------------
// Required parameters
// ------------------------------------------------------------
+ require(parameters.isDefinedAt("url"), "Option 'url' is required.")
+ require(parameters.isDefinedAt("dbtable"), "Option 'dbtable' is required.")
// a JDBC URL
- val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
+ val url = parameters("url")
// name of table
- val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
+ val table = parameters("dbtable")
// ------------------------------------------------------------
// Optional parameter list
@@ -44,6 +46,11 @@ class JDBCOptions(
// the number of partitions
val numPartitions = parameters.getOrElse("numPartitions", null)
+ require(partitionColumn == null ||
+ (lowerBound != null && upperBound != null && numPartitions != null),
+ "If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
+ " and 'numPartitions' are required.")
+
// ------------------------------------------------------------
// The options for DataFrameWriter
// ------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 106ed1d440..ae04af2479 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -19,37 +19,102 @@ package org.apache.spark.sql.execution.datasources.jdbc
import java.util.Properties
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
+import scala.collection.JavaConverters.mapAsJavaMapConverter
-class JdbcRelationProvider extends RelationProvider with DataSourceRegister {
+import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}
+
+class JdbcRelationProvider extends CreatableRelationProvider
+ with RelationProvider with DataSourceRegister {
override def shortName(): String = "jdbc"
- /** Returns a new base relation with the given parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
val jdbcOptions = new JDBCOptions(parameters)
- if (jdbcOptions.partitionColumn != null
- && (jdbcOptions.lowerBound == null
- || jdbcOptions.upperBound == null
- || jdbcOptions.numPartitions == null)) {
- sys.error("Partitioning incompletely specified")
- }
+ val partitionColumn = jdbcOptions.partitionColumn
+ val lowerBound = jdbcOptions.lowerBound
+ val upperBound = jdbcOptions.upperBound
+ val numPartitions = jdbcOptions.numPartitions
- val partitionInfo = if (jdbcOptions.partitionColumn == null) {
+ val partitionInfo = if (partitionColumn == null) {
null
} else {
JDBCPartitioningInfo(
- jdbcOptions.partitionColumn,
- jdbcOptions.lowerBound.toLong,
- jdbcOptions.upperBound.toLong,
- jdbcOptions.numPartitions.toInt)
+ partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt)
}
val parts = JDBCRelation.columnPartition(partitionInfo)
val properties = new Properties() // Additional properties that we will pass to getConnection
parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession)
}
+
+ /*
+ * The following structure applies to this code:
+ *               | tableExists               | !tableExists
+ * -----------------------------------------------------------------------------------
+ * Ignore        | BaseRelation              | CreateTable, saveTable, BaseRelation
+ * ErrorIfExists | ERROR                     | CreateTable, saveTable, BaseRelation
+ * Overwrite*    | (DropTable, CreateTable,) | CreateTable, saveTable, BaseRelation
+ *               | saveTable, BaseRelation   |
+ * Append        | saveTable, BaseRelation   | CreateTable, saveTable, BaseRelation
+ *
+ * *Overwrite & tableExists with truncate, will not drop & create, but instead truncate
+ */
+ override def createRelation(
+ sqlContext: SQLContext,
+ mode: SaveMode,
+ parameters: Map[String, String],
+ data: DataFrame): BaseRelation = {
+ val jdbcOptions = new JDBCOptions(parameters)
+ val url = jdbcOptions.url
+ val table = jdbcOptions.table
+
+ val props = new Properties()
+ props.putAll(parameters.asJava)
+ val conn = JdbcUtils.createConnectionFactory(url, props)()
+
+ try {
+ val tableExists = JdbcUtils.tableExists(conn, url, table)
+
+ val (doCreate, doSave) = (mode, tableExists) match {
+ case (SaveMode.Ignore, true) => (false, false)
+ case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(
+ s"Table or view '$table' already exists, and SaveMode is set to ErrorIfExists.")
+ case (SaveMode.Overwrite, true) =>
+ if (jdbcOptions.isTruncate && JdbcUtils.isCascadingTruncateTable(url) == Some(false)) {
+ JdbcUtils.truncateTable(conn, table)
+ (false, true)
+ } else {
+ JdbcUtils.dropTable(conn, table)
+ (true, true)
+ }
+ case (SaveMode.Append, true) => (false, true)
+ case (_, true) => throw new IllegalArgumentException(s"Unexpected SaveMode, '$mode'," +
+ " for handling existing tables.")
+ case (_, false) => (true, true)
+ }
+
+ if (doCreate) {
+ val schema = JdbcUtils.schemaString(data, url)
+ // To allow certain options to append when create a new table, which can be
+ // table_options or partition_options.
+ // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
+ val createtblOptions = jdbcOptions.createTableOptions
+ val sql = s"CREATE TABLE $table ($schema) $createtblOptions"
+ val statement = conn.createStatement
+ try {
+ statement.executeUpdate(sql)
+ } finally {
+ statement.close()
+ }
+ }
+ if (doSave) JdbcUtils.saveTable(data, url, table, props)
+ } finally {
+ conn.close()
+ }
+
+ createRelation(sqlContext, parameters)
+ }
}
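As an aside, the save-mode decision table in the comment above can be restated as a small pure function. This is only an illustrative sketch of the (doCreate, doSave) logic added in JdbcRelationProvider.createRelation: the names SaveModeMatrixSketch, createAndSaveFlags, and truncateInstead are hypothetical, truncateInstead stands in for the isTruncate/cascading-truncate check, and the sketch omits the side effects (drop, truncate, table creation, connection handling) performed by the real code.

```scala
import org.apache.spark.sql.SaveMode

object SaveModeMatrixSketch {
  // Returns (doCreate, doSave): whether to CREATE TABLE and whether to write the data.
  def createAndSaveFlags(
      mode: SaveMode,
      tableExists: Boolean,
      truncateInstead: Boolean): (Boolean, Boolean) = (mode, tableExists) match {
    case (SaveMode.Ignore, true) => (false, false)            // leave the existing table untouched
    case (SaveMode.ErrorIfExists, true) =>
      throw new IllegalStateException("Table already exists")  // AnalysisException in the real code
    case (SaveMode.Overwrite, true) =>
      if (truncateInstead) (false, true)                       // truncate, then save into the table
      else (true, true)                                        // drop and recreate, then save
    case (SaveMode.Append, true) => (false, true)              // save into the existing table
    case (_, false) => (true, true)                            // no table yet: create it, then save
  }

  def main(args: Array[String]): Unit =
    println(createAndSaveFlags(SaveMode.Overwrite, tableExists = true, truncateInstead = false))
}
```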
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index ff3309874f..506971362f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc
import java.sql.DriverManager
import java.util.Properties
+import scala.collection.JavaConverters.propertiesAsScalaMapConverter
+
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkException
@@ -208,4 +210,84 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
+
+ test("save works for format(\"jdbc\") if url and dbtable are set") {
+ val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ df.write.format("jdbc")
+ .options(Map("url" -> url, "dbtable" -> "TEST.SAVETEST"))
+ .save()
+
+ assert(2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).count)
+ assert(
+ 2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).collect()(0).length)
+ }
+
+ test("save API with SaveMode.Overwrite") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
+
+ df.write.format("jdbc")
+ .option("url", url1)
+ .option("dbtable", "TEST.SAVETEST")
+ .options(properties.asScala)
+ .save()
+ df2.write.mode(SaveMode.Overwrite).format("jdbc")
+ .option("url", url1)
+ .option("dbtable", "TEST.SAVETEST")
+ .options(properties.asScala)
+ .save()
+ assert(1 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).count())
+ assert(2 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).collect()(0).length)
+ }
+
+ test("save errors if url is not specified") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ val e = intercept[RuntimeException] {
+ df.write.format("jdbc")
+ .option("dbtable", "TEST.SAVETEST")
+ .options(properties.asScala)
+ .save()
+ }.getMessage
+ assert(e.contains("Option 'url' is required"))
+ }
+
+ test("save errors if dbtable is not specified") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ val e = intercept[RuntimeException] {
+ df.write.format("jdbc")
+ .option("url", url1)
+ .options(properties.asScala)
+ .save()
+ }.getMessage
+ assert(e.contains("Option 'dbtable' is required"))
+ }
+
+ test("save errors if wrong user/password combination") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ val e = intercept[org.h2.jdbc.JdbcSQLException] {
+ df.write.format("jdbc")
+ .option("dbtable", "TEST.SAVETEST")
+ .option("url", url1)
+ .save()
+ }.getMessage
+ assert(e.contains("Wrong user name or password"))
+ }
+
+ test("save errors if partitionColumn and numPartitions and bounds not set") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ val e = intercept[java.lang.IllegalArgumentException] {
+ df.write.format("jdbc")
+ .option("dbtable", "TEST.SAVETEST")
+ .option("url", url1)
+ .option("partitionColumn", "foo")
+ .save()
+ }.getMessage
+ assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
+ " and 'numPartitions' are required."))
+ }
}