author     Justin Pihony <justin.pihony@gmail.com>   2016-09-26 09:54:22 +0100
committer  Sean Owen <sowen@cloudera.com>            2016-09-26 09:54:22 +0100
commit     50b89d05b7bffc212cc9b9ae6e0bca7cb90b9c77 (patch)
tree       004018c95e9fedc204d683c210af79ac43bd4212 /sql
parent     ac65139be96dbf87402b9a85729a93afd3c6ff17 (diff)
[SPARK-14525][SQL] Make DataFrameWrite.save work for jdbc
## What changes were proposed in this pull request?

This change modifies the implementation of DataFrameWriter.save so that it works with jdbc, and the call to jdbc merely delegates to save.

## How was this patch tested?

This was tested via unit tests in JDBCWriteSuite, to which I added one new test to cover this scenario.

## Additional details

rxin This seems to have been most recently touched by you and was also commented on in the JIRA.

This contribution is my original work and I license the work to the project under the project's open source license.

Author: Justin Pihony <justin.pihony@gmail.com>
Author: Justin Pihony <justin.pihony@typesafe.com>

Closes #12601 from JustinPihony/jdbc_reconciliation.
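For illustration, a minimal usage sketch of the behavior this patch enables. The SparkSession setup, H2-style URL, credentials, and table names below are assumptions for the example, not taken from the patch; an H2 driver would need to be on the classpath.

```scala
// Sketch only: after SPARK-14525 the generic save() path works for the jdbc source
// whenever 'url' and 'dbtable' are supplied as options, and jdbc() delegates to it.
import java.util.Properties

import org.apache.spark.sql.SparkSession

object JdbcSaveSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("jdbc-save-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(("alice", 1), ("bob", 2)).toDF("NAME", "ID")
    val url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1" // hypothetical URL

    // Generic data source path, now supported for jdbc.
    df.write.format("jdbc")
      .option("url", url)
      .option("dbtable", "TEST.PEOPLE") // hypothetical table
      .option("user", "sa")
      .option("password", "")
      .save()

    // The existing jdbc() API is unchanged for callers; it now merely
    // delegates to the same save() code path.
    val connectionProperties = new Properties()
    connectionProperties.setProperty("user", "sa")
    connectionProperties.setProperty("password", "")
    df.write.jdbc(url, "TEST.PEOPLE2", connectionProperties)

    spark.stop()
  }
}
```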
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala                           | 59
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala    | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala | 95
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala                        | 82
4 files changed, 175 insertions, 72 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 64d3422cb4..7374a8e045 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -425,62 +425,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
assertNotPartitioned("jdbc")
assertNotBucketed("jdbc")
-
- // to add required options like URL and dbtable
- val params = extraOptions.toMap ++ Map("url" -> url, "dbtable" -> table)
- val jdbcOptions = new JDBCOptions(params)
- val jdbcUrl = jdbcOptions.url
- val jdbcTable = jdbcOptions.table
-
- val props = new Properties()
- extraOptions.foreach { case (key, value) =>
- props.put(key, value)
- }
// connectionProperties should override settings in extraOptions
- props.putAll(connectionProperties)
- val conn = JdbcUtils.createConnectionFactory(jdbcUrl, props)()
-
- try {
- var tableExists = JdbcUtils.tableExists(conn, jdbcUrl, jdbcTable)
-
- if (mode == SaveMode.Ignore && tableExists) {
- return
- }
-
- if (mode == SaveMode.ErrorIfExists && tableExists) {
- sys.error(s"Table $jdbcTable already exists.")
- }
-
- if (mode == SaveMode.Overwrite && tableExists) {
- if (jdbcOptions.isTruncate &&
- JdbcUtils.isCascadingTruncateTable(jdbcUrl) == Some(false)) {
- JdbcUtils.truncateTable(conn, jdbcTable)
- } else {
- JdbcUtils.dropTable(conn, jdbcTable)
- tableExists = false
- }
- }
-
- // Create the table if the table didn't exist.
- if (!tableExists) {
- val schema = JdbcUtils.schemaString(df, jdbcUrl)
- // To allow certain options to append when create a new table, which can be
- // table_options or partition_options.
- // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
- val createtblOptions = jdbcOptions.createTableOptions
- val sql = s"CREATE TABLE $jdbcTable ($schema) $createtblOptions"
- val statement = conn.createStatement
- try {
- statement.executeUpdate(sql)
- } finally {
- statement.close()
- }
- }
- } finally {
- conn.close()
- }
-
- JdbcUtils.saveTable(df, jdbcUrl, jdbcTable, props)
+ this.extraOptions = this.extraOptions ++ (connectionProperties.asScala)
+ // explicit url and dbtable should override all
+ this.extraOptions += ("url" -> url, "dbtable" -> table)
+ format("jdbc").save()
}
/**
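The new jdbc() body above encodes a simple precedence rule: connectionProperties override settings in extraOptions, and the explicit url/dbtable arguments override everything else. A standalone sketch of that merge, with names local to the sketch:

```scala
// Sketch of the option-precedence rule applied by the rewritten jdbc() body.
import java.util.Properties

import scala.collection.JavaConverters._

object JdbcOptionPrecedenceSketch {
  def merge(
      extraOptions: Map[String, String],
      connectionProperties: Properties,
      url: String,
      table: String): Map[String, String] = {
    extraOptions ++
      connectionProperties.asScala ++         // connectionProperties win over extraOptions
      Map("url" -> url, "dbtable" -> table)   // explicit arguments win over both
  }

  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.setProperty("user", "sa")
    val merged = merge(
      Map("user" -> "overridden", "fetchsize" -> "100", "url" -> "jdbc:ignored"),
      props,
      "jdbc:h2:mem:testdb",
      "TEST.PEOPLE")
    assert(merged("user") == "sa")
    assert(merged("url") == "jdbc:h2:mem:testdb")
    println(merged)
  }
}
```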
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 1db090eaf9..bcf65e53af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -27,10 +27,12 @@ class JDBCOptions(
// ------------------------------------------------------------
// Required parameters
// ------------------------------------------------------------
+ require(parameters.isDefinedAt("url"), "Option 'url' is required.")
+ require(parameters.isDefinedAt("dbtable"), "Option 'dbtable' is required.")
// a JDBC URL
- val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
+ val url = parameters("url")
// name of table
- val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
+ val table = parameters("dbtable")
// ------------------------------------------------------------
// Optional parameter list
@@ -44,6 +46,11 @@ class JDBCOptions(
// the number of partitions
val numPartitions = parameters.getOrElse("numPartitions", null)
+ require(partitionColumn == null ||
+ (lowerBound != null && upperBound != null && numPartitions != null),
+ "If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
+ " and 'numPartitions' are required.")
+
// ------------------------------------------------------------
// The options for DataFrameWriter
// ------------------------------------------------------------
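With the require-based checks above, a missing required option or an incomplete partitioning spec now surfaces as an IllegalArgumentException (a RuntimeException, which is what the new tests intercept) instead of the generic error from sys.error. A minimal standalone sketch of the same validation style; the class and object names are hypothetical:

```scala
// Sketch of the fail-fast validation JDBCOptions now performs.
class JdbcOptionsSketch(parameters: Map[String, String]) {
  require(parameters.isDefinedAt("url"), "Option 'url' is required.")
  require(parameters.isDefinedAt("dbtable"), "Option 'dbtable' is required.")

  val url: String = parameters("url")
  val table: String = parameters("dbtable")

  private val partitionColumn = parameters.getOrElse("partitionColumn", null)
  private val lowerBound = parameters.getOrElse("lowerBound", null)
  private val upperBound = parameters.getOrElse("upperBound", null)
  private val numPartitions = parameters.getOrElse("numPartitions", null)

  require(partitionColumn == null ||
    (lowerBound != null && upperBound != null && numPartitions != null),
    "If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
      " and 'numPartitions' are required.")
}

object JdbcOptionsSketchDemo {
  def main(args: Array[String]): Unit = {
    // Prints: requirement failed: Option 'dbtable' is required.
    try new JdbcOptionsSketch(Map("url" -> "jdbc:h2:mem:testdb"))
    catch { case e: IllegalArgumentException => println(e.getMessage) }
  }
}
```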
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 106ed1d440..ae04af2479 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -19,37 +19,102 @@ package org.apache.spark.sql.execution.datasources.jdbc
import java.util.Properties
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
+import scala.collection.JavaConverters.mapAsJavaMapConverter
-class JdbcRelationProvider extends RelationProvider with DataSourceRegister {
+import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}
+
+class JdbcRelationProvider extends CreatableRelationProvider
+ with RelationProvider with DataSourceRegister {
override def shortName(): String = "jdbc"
- /** Returns a new base relation with the given parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
val jdbcOptions = new JDBCOptions(parameters)
- if (jdbcOptions.partitionColumn != null
- && (jdbcOptions.lowerBound == null
- || jdbcOptions.upperBound == null
- || jdbcOptions.numPartitions == null)) {
- sys.error("Partitioning incompletely specified")
- }
+ val partitionColumn = jdbcOptions.partitionColumn
+ val lowerBound = jdbcOptions.lowerBound
+ val upperBound = jdbcOptions.upperBound
+ val numPartitions = jdbcOptions.numPartitions
- val partitionInfo = if (jdbcOptions.partitionColumn == null) {
+ val partitionInfo = if (partitionColumn == null) {
null
} else {
JDBCPartitioningInfo(
- jdbcOptions.partitionColumn,
- jdbcOptions.lowerBound.toLong,
- jdbcOptions.upperBound.toLong,
- jdbcOptions.numPartitions.toInt)
+ partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt)
}
val parts = JDBCRelation.columnPartition(partitionInfo)
val properties = new Properties() // Additional properties that we will pass to getConnection
parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession)
}
+
+ /*
+ * The following structure applies to this code:
+ * | tableExists | !tableExists
+ *------------------------------------------------------------------------------------
+ * Ignore | BaseRelation | CreateTable, saveTable, BaseRelation
+ * ErrorIfExists | ERROR | CreateTable, saveTable, BaseRelation
+ * Overwrite* | (DropTable, CreateTable,) | CreateTable, saveTable, BaseRelation
+ * | saveTable, BaseRelation |
+ * Append | saveTable, BaseRelation | CreateTable, saveTable, BaseRelation
+ *
+ * *Overwrite & tableExists with truncate, will not drop & create, but instead truncate
+ */
+ override def createRelation(
+ sqlContext: SQLContext,
+ mode: SaveMode,
+ parameters: Map[String, String],
+ data: DataFrame): BaseRelation = {
+ val jdbcOptions = new JDBCOptions(parameters)
+ val url = jdbcOptions.url
+ val table = jdbcOptions.table
+
+ val props = new Properties()
+ props.putAll(parameters.asJava)
+ val conn = JdbcUtils.createConnectionFactory(url, props)()
+
+ try {
+ val tableExists = JdbcUtils.tableExists(conn, url, table)
+
+ val (doCreate, doSave) = (mode, tableExists) match {
+ case (SaveMode.Ignore, true) => (false, false)
+ case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(
+ s"Table or view '$table' already exists, and SaveMode is set to ErrorIfExists.")
+ case (SaveMode.Overwrite, true) =>
+ if (jdbcOptions.isTruncate && JdbcUtils.isCascadingTruncateTable(url) == Some(false)) {
+ JdbcUtils.truncateTable(conn, table)
+ (false, true)
+ } else {
+ JdbcUtils.dropTable(conn, table)
+ (true, true)
+ }
+ case (SaveMode.Append, true) => (false, true)
+ case (_, true) => throw new IllegalArgumentException(s"Unexpected SaveMode, '$mode'," +
+ " for handling existing tables.")
+ case (_, false) => (true, true)
+ }
+
+ if (doCreate) {
+ val schema = JdbcUtils.schemaString(data, url)
+ // To allow certain options to append when create a new table, which can be
+ // table_options or partition_options.
+ // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
+ val createtblOptions = jdbcOptions.createTableOptions
+ val sql = s"CREATE TABLE $table ($schema) $createtblOptions"
+ val statement = conn.createStatement
+ try {
+ statement.executeUpdate(sql)
+ } finally {
+ statement.close()
+ }
+ }
+ if (doSave) JdbcUtils.saveTable(data, url, table, props)
+ } finally {
+ conn.close()
+ }
+
+ createRelation(sqlContext, parameters)
+ }
}
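The (mode, tableExists) match above is the core of the new write path. A standalone sketch of that decision table follows; the Overwrite-with-truncate branch is folded into drop-and-recreate here, and a plain IllegalStateException stands in for the AnalysisException the provider throws:

```scala
// Sketch: given the SaveMode and whether the target table exists, decide whether to
// create the table and whether to write the data.
import org.apache.spark.sql.SaveMode

object SaveModeDecisionSketch {
  def decide(mode: SaveMode, tableExists: Boolean, table: String): (Boolean, Boolean) =
    (mode, tableExists) match {
      case (SaveMode.Ignore, true)        => (false, false) // leave the existing table alone
      case (SaveMode.ErrorIfExists, true) => throw new IllegalStateException(
        s"Table or view '$table' already exists, and SaveMode is set to ErrorIfExists.")
      case (SaveMode.Overwrite, true)     => (true, true)   // drop, recreate, then save
      case (SaveMode.Append, true)        => (false, true)  // save into the existing table
      case (_, false)                     => (true, true)   // create the table, then save
    }

  def main(args: Array[String]): Unit = {
    assert(decide(SaveMode.Append, tableExists = true, "TEST.PEOPLE") == ((false, true)))
    assert(decide(SaveMode.Ignore, tableExists = false, "TEST.PEOPLE") == ((true, true)))
  }
}
```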
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index ff3309874f..506971362f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc
import java.sql.DriverManager
import java.util.Properties
+import scala.collection.JavaConverters.propertiesAsScalaMapConverter
+
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkException
@@ -208,4 +210,84 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
+
+ test("save works for format(\"jdbc\") if url and dbtable are set") {
+ val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ df.write.format("jdbc")
+ .options(Map("url" -> url, "dbtable" -> "TEST.SAVETEST"))
+ .save()
+
+ assert(2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).count)
+ assert(
+ 2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).collect()(0).length)
+ }
+
+ test("save API with SaveMode.Overwrite") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
+
+ df.write.format("jdbc")
+ .option("url", url1)
+ .option("dbtable", "TEST.SAVETEST")
+ .options(properties.asScala)
+ .save()
+ df2.write.mode(SaveMode.Overwrite).format("jdbc")
+ .option("url", url1)
+ .option("dbtable", "TEST.SAVETEST")
+ .options(properties.asScala)
+ .save()
+ assert(1 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).count())
+ assert(2 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).collect()(0).length)
+ }
+
+ test("save errors if url is not specified") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ val e = intercept[RuntimeException] {
+ df.write.format("jdbc")
+ .option("dbtable", "TEST.SAVETEST")
+ .options(properties.asScala)
+ .save()
+ }.getMessage
+ assert(e.contains("Option 'url' is required"))
+ }
+
+ test("save errors if dbtable is not specified") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ val e = intercept[RuntimeException] {
+ df.write.format("jdbc")
+ .option("url", url1)
+ .options(properties.asScala)
+ .save()
+ }.getMessage
+ assert(e.contains("Option 'dbtable' is required"))
+ }
+
+ test("save errors if wrong user/password combination") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ val e = intercept[org.h2.jdbc.JdbcSQLException] {
+ df.write.format("jdbc")
+ .option("dbtable", "TEST.SAVETEST")
+ .option("url", url1)
+ .save()
+ }.getMessage
+ assert(e.contains("Wrong user name or password"))
+ }
+
+ test("save errors if partitionColumn and numPartitions and bounds not set") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ val e = intercept[java.lang.IllegalArgumentException] {
+ df.write.format("jdbc")
+ .option("dbtable", "TEST.SAVETEST")
+ .option("url", url1)
+ .option("partitionColumn", "foo")
+ .save()
+ }.getMessage
+ assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
+ " and 'numPartitions' are required."))
+ }
}
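The tests above write through the generic save() path; for symmetry, a round-trip sketch that reads the same table back through the generic load() path. As before, the H2 URL, credentials, and table name are illustrative assumptions, and an H2 driver is assumed to be on the classpath.

```scala
// Sketch: write via format("jdbc").save(), read back via format("jdbc").load().
import org.apache.spark.sql.{SaveMode, SparkSession}

object JdbcRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("jdbc-roundtrip-sketch")
      .getOrCreate()
    import spark.implicits._

    val url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1" // hypothetical URL
    val df = Seq(("alice", 1), ("bob", 2)).toDF("NAME", "ID")

    df.write.format("jdbc")
      .mode(SaveMode.Overwrite)
      .option("url", url)
      .option("dbtable", "TEST.SAVETEST")
      .option("user", "sa")
      .option("password", "")
      .save()

    val readBack = spark.read.format("jdbc")
      .option("url", url)
      .option("dbtable", "TEST.SAVETEST")
      .option("user", "sa")
      .option("password", "")
      .load()

    assert(readBack.count() == 2)
    spark.stop()
  }
}
```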