From 6c83d938cc61bd5fabaf2157fcc3936364a83f02 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 4 Jan 2016 10:39:42 -0800 Subject: [SPARK-12579][SQL] Force user-specified JDBC driver to take precedence Spark SQL's JDBC data source allows users to specify an explicit JDBC driver to load (using the `driver` argument), but in the current code it's possible that the user-specified driver will not be used when it comes time to actually create a JDBC connection. In a nutshell, the problem is that you might have multiple JDBC drivers on the classpath that claim to be able to handle the same subprotocol, so simply registering the user-provided driver class with the our `DriverRegistry` and JDBC's `DriverManager` is not sufficient to ensure that it's actually used when creating the JDBC connection. This patch addresses this issue by first registering the user-specified driver with the DriverManager, then iterating over the driver manager's loaded drivers in order to obtain the correct driver and use it to create a connection (previously, we just called `DriverManager.getConnection()` directly). If a user did not specify a JDBC driver to use, then we call `DriverManager.getDriver` to figure out the class of the driver to use, then pass that class's name to executors; this guards against corner-case bugs in situations where the driver and executor JVMs might have different sets of JDBC drivers on their classpaths (previously, there was the (rare) potential for `DriverManager.getConnection()` to use different drivers on the driver and executors if the user had not explicitly specified a JDBC driver class and the classpaths were different). This patch is inspired by a similar patch that I made to the `spark-redshift` library (https://github.com/databricks/spark-redshift/pull/143), which contains its own modified fork of some of Spark's JDBC data source code (for cross-Spark-version compatibility reasons). Author: Josh Rosen Closes #10519 from JoshRosen/jdbc-driver-precedence. --- docs/sql-programming-guide.md | 4 +-- .../org/apache/spark/sql/DataFrameWriter.scala | 2 +- .../execution/datasources/jdbc/DefaultSource.scala | 3 -- .../datasources/jdbc/DriverRegistry.scala | 5 ---- .../sql/execution/datasources/jdbc/JDBCRDD.scala | 33 +++----------------- .../execution/datasources/jdbc/JDBCRelation.scala | 2 -- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 35 +++++++++++++++++----- 7 files changed, 34 insertions(+), 50 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3f9a831edd..b058833616 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1895,9 +1895,7 @@ the Data Sources API. The following options are supported: driver - The class name of the JDBC driver needed to connect to this URL. This class will be loaded - on the master and workers before running an JDBC commands to allow the driver to - register itself with the JDBC subsystem. + The class name of the JDBC driver to use to connect to this URL. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index ab362539e2..9f59c0f3a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -275,7 +275,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } // connectionProperties should override settings in extraOptions props.putAll(connectionProperties) - val conn = JdbcUtils.createConnection(url, props) + val conn = JdbcUtils.createConnectionFactory(url, props)() try { var tableExists = JdbcUtils.tableExists(conn, url, table) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala index f522303be9..5ae6cff9b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister { sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) val partitionColumn = parameters.getOrElse("partitionColumn", null) val lowerBound = parameters.getOrElse("lowerBound", null) val upperBound = parameters.getOrElse("upperBound", null) val numPartitions = parameters.getOrElse("numPartitions", null) - if (driver != null) DriverRegistry.register(driver) - if (partitionColumn != null && (lowerBound == null || upperBound == null || numPartitions == null)) { sys.error("Partitioning incompletely specified") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 7ccd61ed46..65af397451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -51,10 +51,5 @@ object DriverRegistry extends Logging { } } } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - case driver => driver.getClass.getCanonicalName - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 87d43addd3..cb8d9504af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp} +import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp} import java.util.Properties import scala.util.control.NonFatal @@ -41,7 +41,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Par override def index: Int = idx } - private[sql] object JDBCRDD extends Logging { /** @@ -120,7 +119,7 @@ private[sql] object JDBCRDD extends Logging { */ def resolveTable(url: String, table: String, properties: Properties): StructType = { val dialect = JdbcDialects.get(url) - val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)() + val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)() try { val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0") try { @@ -228,36 +227,13 @@ private[sql] object JDBCRDD extends Logging { }) } - /** - * Given a driver string and an url, return a function that loads the - * specified driver string then returns a connection to the JDBC url. - * getConnector is run on the driver code, while the function it returns - * is run on the executor. - * - * @param driver - The class name of the JDBC driver for the given url, or null if the class name - * is not necessary. - * @param url - The JDBC url to connect to. - * - * @return A function that loads the driver and connects to the url. - */ - def getConnector(driver: String, url: String, properties: Properties): () => Connection = { - () => { - try { - if (driver != null) DriverRegistry.register(driver) - } catch { - case e: ClassNotFoundException => - logWarning(s"Couldn't find class $driver", e) - } - DriverManager.getConnection(url, properties) - } - } + /** * Build and return JDBCRDD from the given information. * * @param sc - Your SparkContext. * @param schema - The Catalyst schema of the underlying database table. - * @param driver - The class name of the JDBC driver for the given url. * @param url - The JDBC url to connect to. * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use. * @param requiredColumns - The names of the columns to SELECT. @@ -270,7 +246,6 @@ private[sql] object JDBCRDD extends Logging { def scanTable( sc: SparkContext, schema: StructType, - driver: String, url: String, properties: Properties, fqTable: String, @@ -281,7 +256,7 @@ private[sql] object JDBCRDD extends Logging { val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) new JDBCRDD( sc, - getConnector(driver, url, properties), + JdbcUtils.createConnectionFactory(url, properties), pruneSchema(schema, requiredColumns), fqTable, quotedColumns, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index f9300dc2cb..375266f1be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -91,12 +91,10 @@ private[sql] case class JDBCRelation( override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { - val driver: String = DriverRegistry.getDriverClassName(url) // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sqlContext.sparkContext, schema, - driver, url, properties, table, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 46f2670eee..10f650693f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, PreparedStatement} +import java.sql.{Connection, Driver, DriverManager, PreparedStatement} import java.util.Properties +import scala.collection.JavaConverters._ import scala.util.Try import scala.util.control.NonFatal @@ -34,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row} object JdbcUtils extends Logging { /** - * Establishes a JDBC connection. + * Returns a factory for creating connections to the given JDBC URL. + * + * @param url the JDBC url to connect to. + * @param properties JDBC connection properties. */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)() + def createConnectionFactory(url: String, properties: Properties): () => Connection = { + val userSpecifiedDriverClass = Option(properties.getProperty("driver")) + userSpecifiedDriverClass.foreach(DriverRegistry.register) + // Performing this part of the logic on the driver guards against the corner-case where the + // driver returned for a URL is different on the driver and executors due to classpath + // differences. + val driverClass: String = userSpecifiedDriverClass.getOrElse { + DriverManager.getDriver(url).getClass.getCanonicalName + } + () => { + userSpecifiedDriverClass.foreach(DriverRegistry.register) + val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { + case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d + case d if d.getClass.getCanonicalName == driverClass => d + }.getOrElse { + throw new IllegalStateException( + s"Did not find registered driver with class $driverClass") + } + driver.connect(url, properties) + } } /** @@ -242,15 +264,14 @@ object JdbcUtils extends Logging { df: DataFrame, url: String, table: String, - properties: Properties = new Properties()) { + properties: Properties) { val dialect = JdbcDialects.get(url) val nullTypes: Array[Int] = df.schema.fields.map { field => getJdbcType(field.dataType, dialect).jdbcNullType } val rddSchema = df.schema - val driver: String = DriverRegistry.getDriverClassName(url) - val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + val getConnection: () => Connection = createConnectionFactory(url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt df.foreachPartition { iterator => savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) -- cgit v1.2.3