| field | value | date |
|---|---|---|
| author | Justin Pihony <justin.pihony@gmail.com> | 2016-09-26 09:54:22 +0100 |
| committer | Sean Owen <sowen@cloudera.com> | 2016-09-26 09:54:22 +0100 |
| commit | 50b89d05b7bffc212cc9b9ae6e0bca7cb90b9c77 (patch) | |
| tree | 004018c95e9fedc204d683c210af79ac43bd4212 /sql | |
| parent | ac65139be96dbf87402b9a85729a93afd3c6ff17 (diff) | |
[SPARK-14525][SQL] Make DataFrameWriter.save work for jdbc
## What changes were proposed in this pull request?
This change modifies the implementation of DataFrameWriter.save so that it works with the jdbc data source, and the jdbc method now merely delegates to save.
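For illustration only (not part of the patch), here is a minimal sketch of the two call paths this change unifies; the H2 URL, table name, credentials, and sample data below are placeholders and assume the JDBC driver is on the classpath:

```scala
import java.util.Properties
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("jdbc-save-sketch").getOrCreate()
import spark.implicits._
val df = Seq((1, "fred"), (2, "mary")).toDF("id", "name")

// Existing API: DataFrameWriter.jdbc(url, table, properties)
val props = new Properties()
props.setProperty("user", "sa")      // placeholder credentials
props.setProperty("password", "")
df.write.mode("overwrite").jdbc("jdbc:h2:mem:testdb", "TEST.PEOPLE", props)

// After this patch the generic save path works too, and jdbc() above
// simply fills in "url"/"dbtable" as options and delegates to it:
df.write.mode("overwrite").format("jdbc")
  .option("url", "jdbc:h2:mem:testdb")
  .option("dbtable", "TEST.PEOPLE")
  .option("user", "sa")
  .option("password", "")
  .save()
```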
## How was this patch tested?
This was tested via the unit tests in JDBCWriteSuite, to which I added a new test covering this scenario.
## Additional details
@rxin This seems to have been most recently touched by you, and it was also commented on in the JIRA.
This contribution is my original work and I license the work to the project under the project's open source license.
Author: Justin Pihony <justin.pihony@gmail.com>
Author: Justin Pihony <justin.pihony@typesafe.com>
Closes #12601 from JustinPihony/jdbc_reconciliation.
Diffstat (limited to 'sql')
4 files changed, 175 insertions, 72 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 64d3422cb4..7374a8e045 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -425,62 +425,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
   def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
     assertNotPartitioned("jdbc")
     assertNotBucketed("jdbc")
-
-    // to add required options like URL and dbtable
-    val params = extraOptions.toMap ++ Map("url" -> url, "dbtable" -> table)
-    val jdbcOptions = new JDBCOptions(params)
-    val jdbcUrl = jdbcOptions.url
-    val jdbcTable = jdbcOptions.table
-
-    val props = new Properties()
-    extraOptions.foreach { case (key, value) =>
-      props.put(key, value)
-    }
-    // connectionProperties should override settings in extraOptions
-    props.putAll(connectionProperties)
-    val conn = JdbcUtils.createConnectionFactory(jdbcUrl, props)()
-
-    try {
-      var tableExists = JdbcUtils.tableExists(conn, jdbcUrl, jdbcTable)
-
-      if (mode == SaveMode.Ignore && tableExists) {
-        return
-      }
-
-      if (mode == SaveMode.ErrorIfExists && tableExists) {
-        sys.error(s"Table $jdbcTable already exists.")
-      }
-
-      if (mode == SaveMode.Overwrite && tableExists) {
-        if (jdbcOptions.isTruncate &&
-            JdbcUtils.isCascadingTruncateTable(jdbcUrl) == Some(false)) {
-          JdbcUtils.truncateTable(conn, jdbcTable)
-        } else {
-          JdbcUtils.dropTable(conn, jdbcTable)
-          tableExists = false
-        }
-      }
-
-      // Create the table if the table didn't exist.
-      if (!tableExists) {
-        val schema = JdbcUtils.schemaString(df, jdbcUrl)
-        // To allow certain options to append when create a new table, which can be
-        // table_options or partition_options.
-        // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
-        val createtblOptions = jdbcOptions.createTableOptions
-        val sql = s"CREATE TABLE $jdbcTable ($schema) $createtblOptions"
-        val statement = conn.createStatement
-        try {
-          statement.executeUpdate(sql)
-        } finally {
-          statement.close()
-        }
-      }
-    } finally {
-      conn.close()
-    }
-
-    JdbcUtils.saveTable(df, jdbcUrl, jdbcTable, props)
+    this.extraOptions = this.extraOptions ++ (connectionProperties.asScala)
+    // explicit url and dbtable should override all
+    this.extraOptions += ("url" -> url, "dbtable" -> table)
+    format("jdbc").save()
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 1db090eaf9..bcf65e53af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -27,10 +27,12 @@ class JDBCOptions(
   // ------------------------------------------------------------
   // Required parameters
   // ------------------------------------------------------------
+  require(parameters.isDefinedAt("url"), "Option 'url' is required.")
+  require(parameters.isDefinedAt("dbtable"), "Option 'dbtable' is required.")
   // a JDBC URL
-  val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
+  val url = parameters("url")
   // name of table
-  val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
+  val table = parameters("dbtable")
 
   // ------------------------------------------------------------
   // Optional parameter list
@@ -44,6 +46,11 @@ class JDBCOptions(
   // the number of partitions
   val numPartitions = parameters.getOrElse("numPartitions", null)
 
+  require(partitionColumn == null ||
+    (lowerBound != null && upperBound != null && numPartitions != null),
+    "If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
+      " and 'numPartitions' are required.")
+
   // ------------------------------------------------------------
   // The options for DataFrameWriter
   // ------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 106ed1d440..ae04af2479 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -19,37 +19,102 @@ package org.apache.spark.sql.execution.datasources.jdbc
 
 import java.util.Properties
 
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
+import scala.collection.JavaConverters.mapAsJavaMapConverter
 
-class JdbcRelationProvider extends RelationProvider with DataSourceRegister {
+import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}
+
+class JdbcRelationProvider extends CreatableRelationProvider
+  with RelationProvider with DataSourceRegister {
 
   override def shortName(): String = "jdbc"
 
-  /** Returns a new base relation with the given parameters. */
   override def createRelation(
       sqlContext: SQLContext,
       parameters: Map[String, String]): BaseRelation = {
     val jdbcOptions = new JDBCOptions(parameters)
-    if (jdbcOptions.partitionColumn != null
-      && (jdbcOptions.lowerBound == null
-        || jdbcOptions.upperBound == null
-        || jdbcOptions.numPartitions == null)) {
-      sys.error("Partitioning incompletely specified")
-    }
+    val partitionColumn = jdbcOptions.partitionColumn
+    val lowerBound = jdbcOptions.lowerBound
+    val upperBound = jdbcOptions.upperBound
+    val numPartitions = jdbcOptions.numPartitions
 
-    val partitionInfo = if (jdbcOptions.partitionColumn == null) {
+    val partitionInfo = if (partitionColumn == null) {
       null
     } else {
       JDBCPartitioningInfo(
-        jdbcOptions.partitionColumn,
-        jdbcOptions.lowerBound.toLong,
-        jdbcOptions.upperBound.toLong,
-        jdbcOptions.numPartitions.toInt)
+        partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt)
     }
     val parts = JDBCRelation.columnPartition(partitionInfo)
     val properties = new Properties() // Additional properties that we will pass to getConnection
     parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
     JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession)
   }
+
+  /*
+   * The following structure applies to this code:
+   *                 | tableExists               | !tableExists
+   *------------------------------------------------------------------------------------
+   * Ignore          | BaseRelation              | CreateTable, saveTable, BaseRelation
+   * ErrorIfExists   | ERROR                     | CreateTable, saveTable, BaseRelation
+   * Overwrite*      | (DropTable, CreateTable,) | CreateTable, saveTable, BaseRelation
+   *                 | saveTable, BaseRelation   |
+   * Append          | saveTable, BaseRelation   | CreateTable, saveTable, BaseRelation
+   *
+   * *Overwrite & tableExists with truncate, will not drop & create, but instead truncate
+   */
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    val jdbcOptions = new JDBCOptions(parameters)
+    val url = jdbcOptions.url
+    val table = jdbcOptions.table
+
+    val props = new Properties()
+    props.putAll(parameters.asJava)
+    val conn = JdbcUtils.createConnectionFactory(url, props)()
+
+    try {
+      val tableExists = JdbcUtils.tableExists(conn, url, table)
+
+      val (doCreate, doSave) = (mode, tableExists) match {
+        case (SaveMode.Ignore, true) => (false, false)
+        case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(
+          s"Table or view '$table' already exists, and SaveMode is set to ErrorIfExists.")
+        case (SaveMode.Overwrite, true) =>
+          if (jdbcOptions.isTruncate && JdbcUtils.isCascadingTruncateTable(url) == Some(false)) {
+            JdbcUtils.truncateTable(conn, table)
+            (false, true)
+          } else {
+            JdbcUtils.dropTable(conn, table)
+            (true, true)
+          }
+        case (SaveMode.Append, true) => (false, true)
+        case (_, true) => throw new IllegalArgumentException(s"Unexpected SaveMode, '$mode'," +
+          " for handling existing tables.")
+        case (_, false) => (true, true)
+      }
+
+      if (doCreate) {
+        val schema = JdbcUtils.schemaString(data, url)
+        // To allow certain options to append when create a new table, which can be
+        // table_options or partition_options.
+        // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
+        val createtblOptions = jdbcOptions.createTableOptions
+        val sql = s"CREATE TABLE $table ($schema) $createtblOptions"
+        val statement = conn.createStatement
+        try {
+          statement.executeUpdate(sql)
+        } finally {
+          statement.close()
+        }
+      }
+      if (doSave) JdbcUtils.saveTable(data, url, table, props)
+    } finally {
+      conn.close()
+    }
+
+    createRelation(sqlContext, parameters)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index ff3309874f..506971362f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc
 import java.sql.DriverManager
 import java.util.Properties
 
+import scala.collection.JavaConverters.propertiesAsScalaMapConverter
+
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.SparkException
@@ -208,4 +210,84 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count())
     assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
   }
+
+  test("save works for format(\"jdbc\") if url and dbtable are set") {
+    val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    df.write.format("jdbc")
+      .options(Map("url" -> url, "dbtable" -> "TEST.SAVETEST"))
+      .save()
+
+    assert(2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).count)
+    assert(
+      2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).collect()(0).length)
+  }
+
+  test("save API with SaveMode.Overwrite") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+    val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
+
+    df.write.format("jdbc")
+      .option("url", url1)
+      .option("dbtable", "TEST.SAVETEST")
+      .options(properties.asScala)
+      .save()
+    df2.write.mode(SaveMode.Overwrite).format("jdbc")
+      .option("url", url1)
+      .option("dbtable", "TEST.SAVETEST")
+      .options(properties.asScala)
+      .save()
+    assert(1 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).count())
+    assert(2 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).collect()(0).length)
+  }
+
+  test("save errors if url is not specified") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    val e = intercept[RuntimeException] {
+      df.write.format("jdbc")
+        .option("dbtable", "TEST.SAVETEST")
+        .options(properties.asScala)
+        .save()
+    }.getMessage
+    assert(e.contains("Option 'url' is required"))
+  }
+
+  test("save errors if dbtable is not specified") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    val e = intercept[RuntimeException] {
+      df.write.format("jdbc")
+        .option("url", url1)
+        .options(properties.asScala)
+        .save()
+    }.getMessage
+    assert(e.contains("Option 'dbtable' is required"))
+  }
+
+  test("save errors if wrong user/password combination") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    val e = intercept[org.h2.jdbc.JdbcSQLException] {
+      df.write.format("jdbc")
+        .option("dbtable", "TEST.SAVETEST")
+        .option("url", url1)
+        .save()
+    }.getMessage
+    assert(e.contains("Wrong user name or password"))
+  }
+
+  test("save errors if partitionColumn and numPartitions and bounds not set") {
+    val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    val e = intercept[java.lang.IllegalArgumentException] {
+      df.write.format("jdbc")
+        .option("dbtable", "TEST.SAVETEST")
+        .option("url", url1)
+        .option("partitionColumn", "foo")
+        .save()
+    }.getMessage
+    assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
+      " and 'numPartitions' are required."))
+  }
 }