-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala      |  2 +-
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala |  4 ++--
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala         | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
3 files changed, 62 insertions(+), 4 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index cef92abbdc..2f6ba48dbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -159,7 +159,7 @@ private[sql] object JDBCRDD extends Logging {
   def getConnector(driver: String, url: String, properties: Properties): () => Connection = {
     () => {
       try {
-        if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver)
+        if (driver != null) DriverRegistry.register(driver)
       } catch {
         case e: ClassNotFoundException => {
           logWarning(s"Couldn't find class $driver", e);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 5f480083d5..d6b3fb3291 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -100,7 +100,7 @@ private[sql] class DefaultSource extends RelationProvider {
     val upperBound = parameters.getOrElse("upperBound", null)
     val numPartitions = parameters.getOrElse("numPartitions", null)
 
-    if (driver != null) Utils.getContextOrSparkClassLoader.loadClass(driver)
+    if (driver != null) DriverRegistry.register(driver)
 
     if (partitionColumn != null
       && (lowerBound == null || upperBound == null || numPartitions == null)) {
@@ -136,7 +136,7 @@ private[sql] case class JDBCRelation(
   override val schema: StructType = JDBCRDD.resolveTable(url, table, properties)
 
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
-    val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName
+    val driver: String = DriverRegistry.getDriverClassName(url)
     JDBCRDD.scanTable(
       sqlContext.sparkContext,
       schema,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index d4e0abc040..ae9af1eabe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -17,10 +17,14 @@
 
 package org.apache.spark.sql
 
-import java.sql.{Connection, DriverManager, PreparedStatement}
+import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, PreparedStatement}
+import java.util.Properties
+
+import scala.collection.mutable
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 package object jdbc {
   private[sql] object JDBCWriteDetails extends Logging {
@@ -179,4 +183,58 @@ package object jdbc {
       }
     }
   }
+
+  private[sql] class DriverWrapper(val wrapped: Driver) extends Driver {
+    override def acceptsURL(url: String): Boolean = wrapped.acceptsURL(url)
+
+    override def jdbcCompliant(): Boolean = wrapped.jdbcCompliant()
+
+    override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] = {
+      wrapped.getPropertyInfo(url, info)
+    }
+
+    override def getMinorVersion: Int = wrapped.getMinorVersion
+
+    override def getParentLogger: java.util.logging.Logger = wrapped.getParentLogger
+
+    override def connect(url: String, info: Properties): Connection = wrapped.connect(url, info)
+
+    override def getMajorVersion: Int = wrapped.getMajorVersion
+  }
+
+  /**
+   * java.sql.DriverManager is always loaded by the bootstrap classloader,
+   * so it cannot load JDBC drivers that are visible only to Spark's classloaders.
+   *
+   * To work around this, each driver from a user-supplied jar is wrapped in a
+   * thin wrapper, and the wrapper is registered with DriverManager instead.
+   */
+  private[sql] object DriverRegistry extends Logging {
+
+    private val wrapperMap: mutable.Map[String, DriverWrapper] = mutable.Map.empty
+
+    def register(className: String): Unit = {
+      val cls = Utils.getContextOrSparkClassLoader.loadClass(className)
+      if (cls.getClassLoader == null) {
+        logTrace(s"$className has been loaded with bootstrap ClassLoader, wrapper is not required")
+      } else if (wrapperMap.get(className).isDefined) {
+        logTrace(s"Wrapper for $className already exists")
+      } else {
+        synchronized {
+          if (wrapperMap.get(className).isEmpty) {
+            val wrapper = new DriverWrapper(cls.newInstance().asInstanceOf[Driver])
+            DriverManager.registerDriver(wrapper)
+            wrapperMap(className) = wrapper
+            logTrace(s"Wrapper for $className registered")
+          }
+        }
+      }
+    }
+
+    def getDriverClassName(url: String): String = DriverManager.getDriver(url) match {
+      case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName
+      case driver => driver.getClass.getCanonicalName
+    }
+  }
+
 } // package object jdbc
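
To tie the pieces together, here is a minimal end-to-end sketch of how the registry is meant to be used. It is not part of the patch: DriverRegistryExample is an invented name, org.h2.Driver and the jdbc:h2:mem URL are stand-ins for a driver shipped in a user-supplied jar, and the object sits in org.apache.spark.sql only because DriverRegistry is private[sql].

    package org.apache.spark.sql

    import java.sql.DriverManager

    import org.apache.spark.sql.jdbc.DriverRegistry

    // Hypothetical caller demonstrating the registry's two entry points.
    object DriverRegistryExample {
      def main(args: Array[String]): Unit = {
        // Load the driver class through Spark's context classloader. If the
        // class did not come from the bootstrap classloader, a DriverWrapper
        // is registered with java.sql.DriverManager on its behalf.
        DriverRegistry.register("org.h2.Driver")

        // Resolve the URL back to the underlying driver's class name,
        // unwrapping the DriverWrapper if one was installed. This is what
        // JDBCRelation.buildScan now uses to name the driver for tasks.
        val driverClassName = DriverRegistry.getDriverClassName("jdbc:h2:mem:example")
        println(s"Driver class: $driverClassName")

        // DriverManager can now hand out connections even though the driver
        // class itself is invisible to the bootstrap classloader.
        val conn = DriverManager.getConnection("jdbc:h2:mem:example")
        try {
          println(s"Connected: ${!conn.isClosed}")
        } finally {
          conn.close()
        }
      }
    }

Note the double-checked pattern in register: the unsynchronized lookup keeps the common already-registered path cheap, while the second check inside the synchronized block ensures two racing threads cannot register duplicate wrappers for the same class.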