author     hyukjinkwon <gurwls223@gmail.com>    2016-10-10 22:22:41 -0700
committer  gatorsmile <gatorsmile@gmail.com>    2016-10-10 22:22:41 -0700
commit     0c0ad436ad909364915b910867d08262c62bc95d (patch)
tree       b8afcc1aa41d83596258fc541ab99e01231b851e /sql
parent     19a5bae47f69929d00d9de43387c7df37a05ee25 (diff)
[SPARK-17719][SPARK-17776][SQL] Unify and tie up options in a single place in JDBC datasource package
## What changes were proposed in this pull request?

This PR proposes to fix the arbitrary mixing of `Map[String, String]`, `Properties` and `JDBCOptions` instances for options in the `execution/jdbc` package, and to exclude Spark-only options from the connection properties. It includes the changes below:

- Unify `Map[String, String]`, `Properties` and `JDBCOptions` in the `execution/jdbc` package into `JDBCOptions`.
- Move the `batchsize`, `fetchsize`, `driver` and `isolationLevel` options into the `JDBCOptions` instance.
- Document `batchsize` and `isolationLevel`, marking which options are read-only and which are write-only. This also fixes minor typos and adds more detailed explanations for some statements, such as `url`.
- Throw exceptions early by validating arguments up front rather than at execution time (e.g. for `fetchsize`).
- Exclude Spark-only options from the connection properties.

## How was this patch tested?

Existing tests should cover this.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #15292 from HyukjinKwon/SPARK-17719.
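For orientation, a minimal usage sketch of these options through the public DataFrame reader and writer; the JDBC URL, table names and credentials below are placeholders.

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch; URL, tables and credentials are placeholders.
val spark = SparkSession.builder().master("local[*]").appName("jdbc-options").getOrCreate()

// Read path: `fetchsize` is a Spark-only, read-side option (validated to be >= 0).
val people = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/testdb")
  .option("dbtable", "TEST.PEOPLE")
  .option("user", "testUser")
  .option("password", "testPass")
  .option("fetchsize", "100")
  .load()

// Write path: `batchsize` (>= 1) and `isolationLevel` are write-side options.
people.write
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/testdb")
  .option("dbtable", "TEST.PEOPLE_COPY")
  .option("user", "testUser")
  .option("password", "testPass")
  .option("batchsize", "1000")
  .option("isolationLevel", "READ_COMMITTED")
  .mode("append")
  .save()
```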
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala  13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala  110
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala  45
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala  20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala  30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala  42
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala  11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala  8
9 files changed, 155 insertions, 128 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index b54e695db3..a716a916b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,7 +29,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions}
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
+import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.InferSchema
import org.apache.spark.sql.types.StructType
@@ -231,13 +231,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
table: String,
parts: Array[Partition],
connectionProperties: Properties): DataFrame = {
- val props = new Properties()
- extraOptions.foreach { case (key, value) =>
- props.put(key, value)
- }
- // connectionProperties should override settings in extraOptions
- props.putAll(connectionProperties)
- val relation = JDBCRelation(url, table, parts, props)(sparkSession)
+ // connectionProperties should override settings in extraOptions.
+ val params = extraOptions.toMap ++ connectionProperties.asScala.toMap
+ val options = new JDBCOptions(url, table, params)
+ val relation = JDBCRelation(parts, options)(sparkSession)
sparkSession.baseRelationToDataFrame(relation)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index bcf65e53af..fcd7409159 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -17,6 +17,11 @@
package org.apache.spark.sql.execution.datasources.jdbc
+import java.sql.{Connection, DriverManager}
+import java.util.Properties
+
+import scala.collection.mutable.ArrayBuffer
+
/**
* Options for the JDBC data source.
*/
@@ -24,40 +29,115 @@ class JDBCOptions(
@transient private val parameters: Map[String, String])
extends Serializable {
+ import JDBCOptions._
+
+ def this(url: String, table: String, parameters: Map[String, String]) = {
+ this(parameters ++ Map(
+ JDBCOptions.JDBC_URL -> url,
+ JDBCOptions.JDBC_TABLE_NAME -> table))
+ }
+
+ val asConnectionProperties: Properties = {
+ val properties = new Properties()
+ // We should avoid to pass the options into properties. See SPARK-17776.
+ parameters.filterKeys(!jdbcOptionNames.contains(_))
+ .foreach { case (k, v) => properties.setProperty(k, v) }
+ properties
+ }
+
// ------------------------------------------------------------
// Required parameters
// ------------------------------------------------------------
- require(parameters.isDefinedAt("url"), "Option 'url' is required.")
- require(parameters.isDefinedAt("dbtable"), "Option 'dbtable' is required.")
+ require(parameters.isDefinedAt(JDBC_URL), s"Option '$JDBC_URL' is required.")
+ require(parameters.isDefinedAt(JDBC_TABLE_NAME), s"Option '$JDBC_TABLE_NAME' is required.")
// a JDBC URL
- val url = parameters("url")
+ val url = parameters(JDBC_URL)
// name of table
- val table = parameters("dbtable")
+ val table = parameters(JDBC_TABLE_NAME)
+
+ // ------------------------------------------------------------
+ // Optional parameters
+ // ------------------------------------------------------------
+ val driverClass = {
+ val userSpecifiedDriverClass = parameters.get(JDBC_DRIVER_CLASS)
+ userSpecifiedDriverClass.foreach(DriverRegistry.register)
+
+ // Performing this part of the logic on the driver guards against the corner-case where the
+ // driver returned for a URL is different on the driver and executors due to classpath
+ // differences.
+ userSpecifiedDriverClass.getOrElse {
+ DriverManager.getDriver(url).getClass.getCanonicalName
+ }
+ }
// ------------------------------------------------------------
- // Optional parameter list
+ // Optional parameters only for reading
// ------------------------------------------------------------
// the column used to partition
- val partitionColumn = parameters.getOrElse("partitionColumn", null)
+ val partitionColumn = parameters.getOrElse(JDBC_PARTITION_COLUMN, null)
// the lower bound of partition column
- val lowerBound = parameters.getOrElse("lowerBound", null)
+ val lowerBound = parameters.getOrElse(JDBC_LOWER_BOUND, null)
// the upper bound of the partition column
- val upperBound = parameters.getOrElse("upperBound", null)
+ val upperBound = parameters.getOrElse(JDBC_UPPER_BOUND, null)
// the number of partitions
- val numPartitions = parameters.getOrElse("numPartitions", null)
-
+ val numPartitions = parameters.getOrElse(JDBC_NUM_PARTITIONS, null)
require(partitionColumn == null ||
(lowerBound != null && upperBound != null && numPartitions != null),
- "If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," +
- " and 'numPartitions' are required.")
+ s"If '$JDBC_PARTITION_COLUMN' is specified then '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND'," +
+ s" and '$JDBC_NUM_PARTITIONS' are required.")
+ val fetchSize = {
+ val size = parameters.getOrElse(JDBC_BATCH_FETCH_SIZE, "0").toInt
+ require(size >= 0,
+ s"Invalid value `${size.toString}` for parameter " +
+ s"`$JDBC_BATCH_FETCH_SIZE`. The minimum value is 0. When the value is 0, " +
+ "the JDBC driver ignores the value and does the estimates.")
+ size
+ }
// ------------------------------------------------------------
- // The options for DataFrameWriter
+ // Optional parameters only for writing
// ------------------------------------------------------------
// if to truncate the table from the JDBC database
- val isTruncate = parameters.getOrElse("truncate", "false").toBoolean
+ val isTruncate = parameters.getOrElse(JDBC_TRUNCATE, "false").toBoolean
// the create table option , which can be table_options or partition_options.
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
// TODO: to reuse the existing partition parameters for those partition specific options
- val createTableOptions = parameters.getOrElse("createTableOptions", "")
+ val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "")
+ val batchSize = {
+ val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt
+ require(size >= 1,
+ s"Invalid value `${size.toString}` for parameter " +
+ s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.")
+ size
+ }
+ val isolationLevel =
+ parameters.getOrElse(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match {
+ case "NONE" => Connection.TRANSACTION_NONE
+ case "READ_UNCOMMITTED" => Connection.TRANSACTION_READ_UNCOMMITTED
+ case "READ_COMMITTED" => Connection.TRANSACTION_READ_COMMITTED
+ case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ
+ case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE
+ }
+}
+
+object JDBCOptions {
+ private val jdbcOptionNames = ArrayBuffer.empty[String]
+
+ private def newOption(name: String): String = {
+ jdbcOptionNames += name
+ name
+ }
+
+ val JDBC_URL = newOption("url")
+ val JDBC_TABLE_NAME = newOption("dbtable")
+ val JDBC_DRIVER_CLASS = newOption("driver")
+ val JDBC_PARTITION_COLUMN = newOption("partitionColumn")
+ val JDBC_LOWER_BOUND = newOption("lowerBound")
+ val JDBC_UPPER_BOUND = newOption("upperBound")
+ val JDBC_NUM_PARTITIONS = newOption("numPartitions")
+ val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
+ val JDBC_TRUNCATE = newOption("truncate")
+ val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
+ val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
+ val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
}
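A rough sketch of the filtering behaviour added above: option names registered through `newOption` are excluded from `asConnectionProperties`, while everything else (such as `user` and `password`) is passed through to the driver. The URL and credentials are placeholders, `JDBCOptions` is an internal class rather than a stable public API, and an H2 driver is assumed on the classpath because the constructor resolves the driver class.

```scala
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions

// Sketch only; URL and credentials are placeholders.
val options = new JDBCOptions(
  "jdbc:h2:mem:testdb",
  "TEST.PEOPLE",
  Map(
    "user" -> "testUser",
    "password" -> "testPass",
    "fetchsize" -> "100",      // Spark-only option, filtered out below
    "numPartitions" -> "4"))   // Spark-only option, filtered out below

val props = options.asConnectionProperties
assert(props.getProperty("user") == "testUser")      // passed through to the driver
assert(props.getProperty("fetchsize") == null)       // excluded, see SPARK-17776
assert(props.getProperty("numPartitions") == null)   // excluded, see SPARK-17776
assert(options.fetchSize == 100)                     // still available to Spark itself
```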
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index f10615ebe4..c0fabc81e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.datasources.jdbc
import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp}
-import java.util.Properties
import scala.util.control.NonFatal
@@ -46,17 +45,18 @@ object JDBCRDD extends Logging {
* Takes a (schema, table) specification and returns the table's Catalyst
* schema.
*
- * @param url - The JDBC url to fetch information from.
- * @param table - The table name of the desired table. This may also be a
- * SQL query wrapped in parentheses.
+ * @param options - JDBC options that contains url, table and other information.
*
* @return A StructType giving the table's Catalyst schema.
* @throws SQLException if the table specification is garbage.
* @throws SQLException if the table contains an unsupported type.
*/
- def resolveTable(url: String, table: String, properties: Properties): StructType = {
+ def resolveTable(options: JDBCOptions): StructType = {
+ val url = options.url
+ val table = options.table
+ val properties = options.asConnectionProperties
val dialect = JdbcDialects.get(url)
- val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)()
+ val conn: Connection = JdbcUtils.createConnectionFactory(options)()
try {
val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
try {
@@ -143,43 +143,38 @@ object JDBCRDD extends Logging {
})
}
-
-
/**
* Build and return JDBCRDD from the given information.
*
* @param sc - Your SparkContext.
* @param schema - The Catalyst schema of the underlying database table.
- * @param url - The JDBC url to connect to.
- * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
* @param requiredColumns - The names of the columns to SELECT.
* @param filters - The filters to include in all WHERE clauses.
* @param parts - An array of JDBCPartitions specifying partition ids and
* per-partition WHERE clauses.
+ * @param options - JDBC options that contains url, table and other information.
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
def scanTable(
sc: SparkContext,
schema: StructType,
- url: String,
- properties: Properties,
- fqTable: String,
requiredColumns: Array[String],
filters: Array[Filter],
- parts: Array[Partition]): RDD[InternalRow] = {
+ parts: Array[Partition],
+ options: JDBCOptions): RDD[InternalRow] = {
+ val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName))
new JDBCRDD(
sc,
- JdbcUtils.createConnectionFactory(url, properties),
+ JdbcUtils.createConnectionFactory(options),
pruneSchema(schema, requiredColumns),
- fqTable,
quotedColumns,
filters,
parts,
url,
- properties)
+ options)
}
}
@@ -192,12 +187,11 @@ private[jdbc] class JDBCRDD(
sc: SparkContext,
getConnection: () => Connection,
schema: StructType,
- fqTable: String,
columns: Array[String],
filters: Array[Filter],
partitions: Array[Partition],
url: String,
- properties: Properties)
+ options: JDBCOptions)
extends RDD[InternalRow](sc, Nil) {
/**
@@ -211,7 +205,7 @@ private[jdbc] class JDBCRDD(
private val columnList: String = {
val sb = new StringBuilder()
columns.foreach(x => sb.append(",").append(x))
- if (sb.length == 0) "1" else sb.substring(1)
+ if (sb.isEmpty) "1" else sb.substring(1)
}
/**
@@ -286,7 +280,7 @@ private[jdbc] class JDBCRDD(
conn = getConnection()
val dialect = JdbcDialects.get(url)
import scala.collection.JavaConverters._
- dialect.beforeFetch(conn, properties.asScala.toMap)
+ dialect.beforeFetch(conn, options.asConnectionProperties.asScala.toMap)
// H2's JDBC driver does not support the setSchema() method. We pass a
// fully-qualified table name in the SELECT statement. I don't know how to
@@ -294,15 +288,10 @@ private[jdbc] class JDBCRDD(
val myWhereClause = getWhereClause(part)
- val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
+ val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
- val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
- require(fetchSize >= 0,
- s"Invalid value `${fetchSize.toString}` for parameter " +
- s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
- "the JDBC driver ignores the value and does the estimates.")
- stmt.setFetchSize(fetchSize)
+ stmt.setFetchSize(options.fetchSize)
rs = stmt.executeQuery()
val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
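With the `fetchsize` check now living in `JDBCOptions`, an illegal value fails while the options are parsed on the driver instead of inside an executor task. A hedged sketch of the expected behaviour; the URL and table are placeholders and an H2 driver is assumed on the classpath.

```scala
import java.util.Properties
import org.apache.spark.sql.SparkSession

// Sketch only; URL and table are placeholders.
val spark = SparkSession.builder().master("local[*]").appName("fail-fast-fetchsize").getOrCreate()

val props = new Properties()
props.setProperty("fetchsize", "-1")

try {
  spark.read.jdbc("jdbc:h2:mem:testdb", "TEST.PEOPLE", props).collect()
} catch {
  // The require() in JDBCOptions fires up front, so this surfaces as an
  // IllegalArgumentException rather than a SparkException from a task.
  case e: IllegalArgumentException =>
    println(e.getMessage) // mentions: Invalid value `-1` for parameter `fetchsize`
}
```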
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 11613dd912..672c21c6ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution.datasources.jdbc
-import java.util.Properties
-
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.internal.Logging
@@ -102,10 +100,7 @@ private[sql] object JDBCRelation extends Logging {
}
private[sql] case class JDBCRelation(
- url: String,
- table: String,
- parts: Array[Partition],
- properties: Properties = new Properties())(@transient val sparkSession: SparkSession)
+ parts: Array[Partition], jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
extends BaseRelation
with PrunedFilteredScan
with InsertableRelation {
@@ -114,7 +109,7 @@ private[sql] case class JDBCRelation(
override val needConversion: Boolean = false
- override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
+ override val schema: StructType = JDBCRDD.resolveTable(jdbcOptions)
// Check if JDBCRDD.compileFilter can accept input filters
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
@@ -126,15 +121,16 @@ private[sql] case class JDBCRelation(
JDBCRDD.scanTable(
sparkSession.sparkContext,
schema,
- url,
- properties,
- table,
requiredColumns,
filters,
- parts).asInstanceOf[RDD[Row]]
+ parts,
+ jdbcOptions).asInstanceOf[RDD[Row]]
}
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+ val url = jdbcOptions.url
+ val table = jdbcOptions.table
+ val properties = jdbcOptions.asConnectionProperties
data.write
.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append)
.jdbc(url, table, properties)
@@ -142,6 +138,6 @@ private[sql] case class JDBCRelation(
override def toString: String = {
// credentials should not be included in the plan output, table information is sufficient.
- s"JDBCRelation(${table})"
+ s"JDBCRelation(${jdbcOptions.table})"
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index b1a061b6f7..4420b3b18a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql.execution.datasources.jdbc
-import java.util.Properties
-
-import scala.collection.JavaConverters.mapAsJavaMapConverter
-
import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}
@@ -46,9 +42,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt)
}
val parts = JDBCRelation.columnPartition(partitionInfo)
- val properties = new Properties() // Additional properties that we will pass to getConnection
- parameters.foreach(kv => properties.setProperty(kv._1, kv._2))
- JDBCRelation(jdbcOptions.url, jdbcOptions.table, parts, properties)(sqlContext.sparkSession)
+ JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
}
override def createRelation(
@@ -56,15 +50,13 @@ class JdbcRelationProvider extends CreatableRelationProvider
mode: SaveMode,
parameters: Map[String, String],
df: DataFrame): BaseRelation = {
- val options = new JDBCOptions(parameters)
- val url = options.url
- val table = options.table
- val createTableOptions = options.createTableOptions
- val isTruncate = options.isTruncate
- val props = new Properties()
- props.putAll(parameters.asJava)
+ val jdbcOptions = new JDBCOptions(parameters)
+ val url = jdbcOptions.url
+ val table = jdbcOptions.table
+ val createTableOptions = jdbcOptions.createTableOptions
+ val isTruncate = jdbcOptions.isTruncate
- val conn = JdbcUtils.createConnectionFactory(url, props)()
+ val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
try {
val tableExists = JdbcUtils.tableExists(conn, url, table)
if (tableExists) {
@@ -73,16 +65,16 @@ class JdbcRelationProvider extends CreatableRelationProvider
if (isTruncate && isCascadingTruncateTable(url) == Some(false)) {
// In this case, we should truncate table and then load.
truncateTable(conn, table)
- saveTable(df, url, table, props)
+ saveTable(df, url, table, jdbcOptions)
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
dropTable(conn, table)
createTable(df.schema, url, table, createTableOptions, conn)
- saveTable(df, url, table, props)
+ saveTable(df, url, table, jdbcOptions)
}
case SaveMode.Append =>
- saveTable(df, url, table, props)
+ saveTable(df, url, table, jdbcOptions)
case SaveMode.ErrorIfExists =>
throw new AnalysisException(
@@ -95,7 +87,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
}
} else {
createTable(df.schema, url, table, createTableOptions, conn)
- saveTable(df, url, table, props)
+ saveTable(df, url, table, jdbcOptions)
}
} finally {
conn.close()
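The write path above branches on `SaveMode` and the new write-only options. Roughly, something like the following exercises the Overwrite-with-truncate branch; the URL, table and credentials are placeholders, and whether the table is truncated rather than dropped also depends on the dialect's `isCascadingTruncateTable` answer.

```scala
import org.apache.spark.sql.{SaveMode, SparkSession}

// Sketch only; URL, table and credentials are placeholders.
val spark = SparkSession.builder().master("local[*]").appName("jdbc-write").getOrCreate()
import spark.implicits._

val df = Seq((1, "alice"), (2, "bob")).toDF("id", "name")

df.write
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/testdb")
  .option("dbtable", "people")
  .option("user", "testUser")
  .option("password", "testPass")
  // With Overwrite and truncate=true (and a non-cascading dialect), the
  // existing table is truncated and reloaded instead of dropped and recreated.
  .option("truncate", "true")
  .mode(SaveMode.Overwrite)
  .save()
```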
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 47549637b5..e32db73bd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.datasources.jdbc
import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
-import java.util.Properties
import scala.collection.JavaConverters._
import scala.util.Try
@@ -41,27 +40,13 @@ import org.apache.spark.util.NextIterator
* Util functions for JDBC tables.
*/
object JdbcUtils extends Logging {
-
- // the property names are case sensitive
- val JDBC_BATCH_FETCH_SIZE = "fetchsize"
- val JDBC_BATCH_INSERT_SIZE = "batchsize"
- val JDBC_TXN_ISOLATION_LEVEL = "isolationLevel"
-
/**
* Returns a factory for creating connections to the given JDBC URL.
*
- * @param url the JDBC url to connect to.
- * @param properties JDBC connection properties.
+ * @param options - JDBC options that contains url, table and other information.
*/
- def createConnectionFactory(url: String, properties: Properties): () => Connection = {
- val userSpecifiedDriverClass = Option(properties.getProperty("driver"))
- userSpecifiedDriverClass.foreach(DriverRegistry.register)
- // Performing this part of the logic on the driver guards against the corner-case where the
- // driver returned for a URL is different on the driver and executors due to classpath
- // differences.
- val driverClass: String = userSpecifiedDriverClass.getOrElse {
- DriverManager.getDriver(url).getClass.getCanonicalName
- }
+ def createConnectionFactory(options: JDBCOptions): () => Connection = {
+ val driverClass: String = options.driverClass
() => {
DriverRegistry.register(driverClass)
val driver: Driver = DriverManager.getDrivers.asScala.collectFirst {
@@ -71,7 +56,7 @@ object JdbcUtils extends Logging {
throw new IllegalStateException(
s"Did not find registered driver with class $driverClass")
}
- driver.connect(url, properties)
+ driver.connect(options.url, options.asConnectionProperties)
}
}
@@ -550,10 +535,6 @@ object JdbcUtils extends Logging {
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int): Iterator[Byte] = {
- require(batchSize >= 1,
- s"Invalid value `${batchSize.toString}` for parameter " +
- s"`$JDBC_BATCH_INSERT_SIZE`. The minimum value is 1.")
-
val conn = getConnection()
var committed = false
@@ -676,23 +657,16 @@ object JdbcUtils extends Logging {
df: DataFrame,
url: String,
table: String,
- properties: Properties) {
+ options: JDBCOptions) {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
getJdbcType(field.dataType, dialect).jdbcNullType
}
val rddSchema = df.schema
- val getConnection: () => Connection = createConnectionFactory(url, properties)
- val batchSize = properties.getProperty(JDBC_BATCH_INSERT_SIZE, "1000").toInt
- val isolationLevel =
- properties.getProperty(JDBC_TXN_ISOLATION_LEVEL, "READ_UNCOMMITTED") match {
- case "NONE" => Connection.TRANSACTION_NONE
- case "READ_UNCOMMITTED" => Connection.TRANSACTION_READ_UNCOMMITTED
- case "READ_COMMITTED" => Connection.TRANSACTION_READ_COMMITTED
- case "REPEATABLE_READ" => Connection.TRANSACTION_REPEATABLE_READ
- case "SERIALIZABLE" => Connection.TRANSACTION_SERIALIZABLE
- }
+ val getConnection: () => Connection = createConnectionFactory(options)
+ val batchSize = options.batchSize
+ val isolationLevel = options.isolationLevel
df.foreachPartition(iterator => savePartition(
getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
)
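Internally, the connection factory is now derived from the same `JDBCOptions` instance, so the driver class is resolved once on the driver side and Spark-only options never reach `Driver.connect`. A hedged sketch of this plumbing, using internal classes that are not a stable public API; the URL and credentials are placeholders and an H2 driver is assumed on the classpath.

```scala
import java.sql.Connection
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}

// Sketch only; URL and credentials are placeholders.
val options = new JDBCOptions(
  "jdbc:h2:mem:testdb",
  "TEST.PEOPLE",
  Map("user" -> "testUser", "password" -> "testPass", "fetchsize" -> "100"))

// The factory captures options.driverClass and options.asConnectionProperties,
// so "fetchsize" above is used by Spark but never handed to the JDBC driver.
val connect: () => Connection = JdbcUtils.createConnectionFactory(options)
val conn = connect()
try {
  // ... use the connection ...
} finally {
  conn.close()
}
```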
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 3f540d6258..4f61a328f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc
import java.sql.{Connection, Types}
-import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types._
@@ -94,7 +94,7 @@ private object PostgresDialect extends JdbcDialect {
//
// See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
//
- if (properties.getOrElse(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
+ if (properties.getOrElse(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
connection.setAutoCommit(false)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 7cc3989b79..71cf5e6a22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -29,8 +29,7 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
-import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JdbcUtils}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -84,7 +83,7 @@ class JDBCSuite extends SparkFunSuite
|CREATE TEMPORARY TABLE fetchtwo
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass',
- | ${JdbcUtils.JDBC_BATCH_FETCH_SIZE} '2')
+ | ${JDBCOptions.JDBC_BATCH_FETCH_SIZE} '2')
""".stripMargin.replaceAll("\n", " "))
sql(
@@ -354,8 +353,8 @@ class JDBCSuite extends SparkFunSuite
test("Basic API with illegal fetchsize") {
val properties = new Properties()
- properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "-1")
- val e = intercept[SparkException] {
+ properties.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, "-1")
+ val e = intercept[IllegalArgumentException] {
spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect()
}.getMessage
assert(e.contains("Invalid value `-1` for parameter `fetchsize`"))
@@ -364,7 +363,7 @@ class JDBCSuite extends SparkFunSuite
test("Basic API with FetchSize") {
(0 to 4).foreach { size =>
val properties = new Properties()
- properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, size.toString)
+ properties.setProperty(JDBCOptions.JDBC_BATCH_FETCH_SIZE, size.toString)
assert(spark.read.jdbc(
urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 62b29db4d5..96540ec92d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkException
import org.apache.spark.sql.{Row, SaveMode}
-import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -113,8 +113,8 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
(-1 to 0).foreach { size =>
val properties = new Properties()
- properties.setProperty(JdbcUtils.JDBC_BATCH_INSERT_SIZE, size.toString)
- val e = intercept[SparkException] {
+ properties.setProperty(JDBCOptions.JDBC_BATCH_INSERT_SIZE, size.toString)
+ val e = intercept[IllegalArgumentException] {
df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties)
}.getMessage
assert(e.contains(s"Invalid value `$size` for parameter `batchsize`"))
@@ -126,7 +126,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
(1 to 3).foreach { size =>
val properties = new Properties()
- properties.setProperty(JdbcUtils.JDBC_BATCH_INSERT_SIZE, size.toString)
+ properties.setProperty(JDBCOptions.JDBC_BATCH_INSERT_SIZE, size.toString)
df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties)
assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count())
}