-rw-r--r--  project/MimaExcludes.scala                                              8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala   99
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala        11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala  211
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala           43
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala      49
6 files changed, 295 insertions, 126 deletions
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 487062a31f..513bbaf98d 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -137,6 +137,14 @@ object MimaExcludes {
// implementing this interface in Java. Note that ShuffleWriter is private[spark].
ProblemFilters.exclude[IncompatibleTemplateDefProblem](
"org.apache.spark.shuffle.ShuffleWriter")
+ ) ++ Seq(
+ // SPARK-6888 make jdbc driver handling user definable
+ // This patch renames some classes to API friendly names.
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks$"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.DriverQuirks"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.PostgresQuirks"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.NoQuirks"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.jdbc.MySQLQuirks")
)
case v if v.startsWith("1.3") =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
deleted file mode 100644
index 0feabc4282..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.jdbc
-
-import org.apache.spark.sql.types._
-
-import java.sql.Types
-
-
-/**
- * Encapsulates workarounds for the extensions, quirks, and bugs in various
- * databases. Lots of databases define types that aren't explicitly supported
- * by the JDBC spec. Some JDBC drivers also report inaccurate
- * information---for instance, BIT(n>1) being reported as a BIT type is quite
- * common, even though BIT in JDBC is meant for single-bit values. Also, there
- * does not appear to be a standard name for an unbounded string or binary
- * type; we use BLOB and CLOB by default but override with database-specific
- * alternatives when these are absent or do not behave correctly.
- *
- * Currently, the only thing DriverQuirks does is handle type mapping.
- * `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
- * is used when writing to a JDBC table. If `getCatalystType` returns `null`,
- * the default type handling is used for the given JDBC type. Similarly,
- * if `getJDBCType` returns `(null, None)`, the default type handling is used
- * for the given Catalyst type.
- */
-private[sql] abstract class DriverQuirks {
- def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType
- def getJDBCType(dt: DataType): (String, Option[Int])
-}
-
-private[sql] object DriverQuirks {
- /**
- * Fetch the DriverQuirks class corresponding to a given database url.
- */
- def get(url: String): DriverQuirks = {
- if (url.startsWith("jdbc:mysql")) {
- new MySQLQuirks()
- } else if (url.startsWith("jdbc:postgresql")) {
- new PostgresQuirks()
- } else {
- new NoQuirks()
- }
- }
-}
-
-private[sql] class NoQuirks extends DriverQuirks {
- def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
- null
- def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
-}
-
-private[sql] class PostgresQuirks extends DriverQuirks {
- def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
- if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
- BinaryType
- } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
- StringType
- } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
- StringType
- } else null
- }
-
- def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
- case StringType => ("TEXT", Some(java.sql.Types.CHAR))
- case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY))
- case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN))
- case _ => (null, None)
- }
-}
-
-private[sql] class MySQLQuirks extends DriverQuirks {
- def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
- if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
- // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
- // byte arrays instead of longs.
- md.putLong("binarylong", 1)
- LongType
- } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
- BooleanType
- } else null
- }
- def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 4189dfcf95..f7b19096ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -41,7 +41,7 @@ private[sql] object JDBCRDD extends Logging {
/**
* Maps a JDBC type to a Catalyst type. This function is called only when
- * the DriverQuirks class corresponding to your database driver returns null.
+ * the JdbcDialect class corresponding to your database driver returns None.
*
* @param sqlType - A field of java.sql.Types
* @return The Catalyst type corresponding to sqlType.
@@ -51,7 +51,7 @@ private[sql] object JDBCRDD extends Logging {
case java.sql.Types.ARRAY => null
case java.sql.Types.BIGINT => LongType
case java.sql.Types.BINARY => BinaryType
- case java.sql.Types.BIT => BooleanType // Per JDBC; Quirks handles quirky drivers.
+ case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
case java.sql.Types.BLOB => BinaryType
case java.sql.Types.BOOLEAN => BooleanType
case java.sql.Types.CHAR => StringType
@@ -108,7 +108,7 @@ private[sql] object JDBCRDD extends Logging {
* @throws SQLException if the table contains an unsupported type.
*/
def resolveTable(url: String, table: String, properties: Properties): StructType = {
- val quirks = DriverQuirks.get(url)
+ val dialect = JdbcDialects.get(url)
val conn: Connection = DriverManager.getConnection(url, properties)
try {
val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
@@ -125,8 +125,9 @@ private[sql] object JDBCRDD extends Logging {
val fieldScale = rsmd.getScale(i + 1)
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
val metadata = new MetadataBuilder().putString("name", columnName)
- var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata)
- if (columnType == null) columnType = getCatalystType(dataType, fieldSize, fieldScale)
+ val columnType =
+ dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
+ getCatalystType(dataType, fieldSize, fieldScale))
fields(i) = StructField(columnName, columnType, nullable, metadata.build())
i = i + 1
}
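
As a standalone sketch (not part of the patch) of the new read-path mapping: the dialect's override is consulted first, and the default mapping in JDBCRDD applies only when it returns None. For example, with the MySQLDialect introduced below, a BIT(8) column reported as VARBINARY maps to LongType and leaves a hint in the column metadata (the column name "flags" is hypothetical):

import org.apache.spark.sql.jdbc.MySQLDialect
import org.apache.spark.sql.types._

// MySQL reports BIT(n>1) as VARBINARY with type name "BIT"; the dialect maps it
// to LongType and records a "binarylong" hint in the column metadata, instead of
// the BinaryType that the default VARBINARY mapping would produce.
val md = new MetadataBuilder().putString("name", "flags")
val mapped = MySQLDialect
  .getCatalystType(java.sql.Types.VARBINARY, "BIT", 8, md)
  .getOrElse(BinaryType) // stand-in for JDBCRDD's default mapping
// mapped == LongType, and md now also carries binarylong -> 1
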
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
new file mode 100644
index 0000000000..6a169e106b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.sql.types._
+import org.apache.spark.annotation.DeveloperApi
+
+import java.sql.Types
+
+/**
+ * :: DeveloperApi ::
+ * A database type definition coupled with the jdbc type needed to send null
+ * values to the database.
+ * @param databaseTypeDefinition The database type definition
+ * @param jdbcNullType The jdbc type (as defined in java.sql.Types) used to
+ * send a null value to the database.
+ */
+@DeveloperApi
+case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
+
+/**
+ * :: DeveloperApi ::
+ * Encapsulates everything (extensions, workarounds, quirks) to handle the
+ * SQL dialect of a certain database or jdbc driver.
+ * Lots of databases define types that aren't explicitly supported
+ * by the JDBC spec. Some JDBC drivers also report inaccurate
+ * information---for instance, BIT(n>1) being reported as a BIT type is quite
+ * common, even though BIT in JDBC is meant for single-bit values. Also, there
+ * does not appear to be a standard name for an unbounded string or binary
+ * type; we use BLOB and CLOB by default but override with database-specific
+ * alternatives when these are absent or do not behave correctly.
+ *
+ * Currently, the only thing the dialect handles is type mapping.
+ * `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
+ * is used when writing to a JDBC table. If `getCatalystType` returns `None`,
+ * the default type handling is used for the given JDBC type. Similarly,
+ * if `getJDBCType` returns `None`, the default type handling is used
+ * for the given Catalyst type.
+ */
+@DeveloperApi
+abstract class JdbcDialect {
+ /**
+ * Check if this dialect instance can handle a certain jdbc url.
+ * @param url the jdbc url.
+ * @return True if the dialect can be applied on the given jdbc url.
+ * @throws NullPointerException if the url is null.
+ */
+ def canHandle(url : String): Boolean
+
+ /**
+ * Get the custom datatype mapping for the given jdbc meta information.
+ * @param sqlType The sql type (see java.sql.Types)
+ * @param typeName The sql type name (e.g. "BIGINT UNSIGNED")
+ * @param size The size of the type.
+ * @param md Result metadata associated with this type.
+ * @return The actual DataType (subclasses of [[org.apache.spark.sql.types.DataType]])
+ * or None if the default type mapping should be used.
+ */
+ def getCatalystType(
+ sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = None
+
+ /**
+ * Retrieve the jdbc / sql type for a given datatype.
+ * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+ * @return The new JdbcType if there is an override for this DataType, or None to use the default mapping
+ */
+ def getJDBCType(dt: DataType): Option[JdbcType] = None
+}
+
+/**
+ * :: DeveloperApi ::
+ * Registry of dialects that apply to every new jdbc [[org.apache.spark.sql.DataFrame]].
+ *
+ * If multiple matching dialects are registered then they are tried in reverse
+ * registration order, most recently registered first. A user-added dialect
+ * is thus applied first, overriding the defaults.
+ *
+ * Note that a newly registered dialect only applies to jdbc DataFrames created
+ * afterwards, so make sure to register your dialects before creating them.
+ */
+@DeveloperApi
+object JdbcDialects {
+
+ private var dialects = List[JdbcDialect]()
+
+ /**
+ * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]].
+ * Re-adding an existing dialect moves it to the front of the list.
+ * @param dialect The new dialect.
+ */
+ def registerDialect(dialect: JdbcDialect) : Unit = {
+ dialects = dialect :: dialects.filterNot(_ == dialect)
+ }
+
+ /**
+ * Unregister a dialect. Does nothing if the dialect is not registered.
+ * @param dialect The jdbc dialect.
+ */
+ def unregisterDialect(dialect : JdbcDialect) : Unit = {
+ dialects = dialects.filterNot(_ == dialect)
+ }
+
+ registerDialect(MySQLDialect)
+ registerDialect(PostgresDialect)
+
+ /**
+ * Fetch the JdbcDialect class corresponding to a given database url.
+ */
+ private[sql] def get(url: String): JdbcDialect = {
+ val matchingDialects = dialects.filter(_.canHandle(url))
+ matchingDialects.length match {
+ case 0 => NoopDialect
+ case 1 => matchingDialects.head
+ case _ => new AggregatedDialect(matchingDialects)
+ }
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * AggregatedDialect can unify multiple dialects into one virtual Dialect.
+ * Dialects are tried in order, and the first dialect that does not return a
+ * neutral element wins.
+ * @param dialects List of dialects.
+ */
+@DeveloperApi
+class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect {
+
+ require(!dialects.isEmpty)
+
+ def canHandle(url : String): Boolean =
+ dialects.map(_.canHandle(url)).reduce(_ && _)
+
+ override def getCatalystType(
+ sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
+ dialects.map(_.getCatalystType(sqlType, typeName, size, md)).flatten.headOption
+
+ override def getJDBCType(dt: DataType): Option[JdbcType] =
+ dialects.map(_.getJDBCType(dt)).flatten.headOption
+
+}
+
+/**
+ * :: DeveloperApi ::
+ * NOOP dialect object, always returning the neutral element.
+ */
+@DeveloperApi
+case object NoopDialect extends JdbcDialect {
+ def canHandle(url : String): Boolean = true
+}
+
+/**
+ * :: DeveloperApi ::
+ * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write.
+ */
+@DeveloperApi
+case object PostgresDialect extends JdbcDialect {
+ def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql")
+ override def getCatalystType(
+ sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+ if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
+ Some(BinaryType)
+ } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
+ Some(StringType)
+ } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
+ Some(StringType)
+ } else None
+ }
+
+ override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
+ case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
+ case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
+ case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
+ case _ => None
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Default mysql dialect to read bit/bitsets correctly.
+ */
+@DeveloperApi
+case object MySQLDialect extends JdbcDialect {
+ def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
+ override def getCatalystType(
+ sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
+ if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
+ // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
+ // byte arrays instead of longs.
+ md.putLong("binarylong", 1)
+ Some(LongType)
+ } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
+ Some(BooleanType)
+ } else None
+ }
+}
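
As a usage sketch (not part of the patch), a user-defined dialect for a hypothetical "exampledb" driver could plug into this API as follows; the object name, URL prefix, and JSONB type name are illustrative only:

import java.sql.Types
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._

// Hypothetical dialect for an imaginary "jdbc:exampledb" driver.
object ExampleDialect extends JdbcDialect {
  // Claim URLs of the form jdbc:exampledb:...
  def canHandle(url: String): Boolean = url.startsWith("jdbc:exampledb")

  // Read the driver's non-standard JSONB type as a Catalyst StringType;
  // returning None falls back to the default mapping in JDBCRDD.
  override def getCatalystType(
      sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
    if (sqlType == Types.OTHER && typeName == "JSONB") Some(StringType) else None

  // Write StringType columns as CLOB instead of the generic TEXT default.
  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("CLOB", Types.CLOB))
    case _ => None
  }
}

// Register before creating any jdbc DataFrame; the most recently registered dialect wins.
JdbcDialects.registerDialect(ExampleDialect)
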
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index a61790b847..f21dd29aca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -129,25 +129,26 @@ package object jdbc {
*/
def schemaString(df: DataFrame, url: String): String = {
val sb = new StringBuilder()
- val quirks = DriverQuirks.get(url)
+ val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field => {
val name = field.name
- var typ: String = quirks.getJDBCType(field.dataType)._1
- if (typ == null) typ = field.dataType match {
- case IntegerType => "INTEGER"
- case LongType => "BIGINT"
- case DoubleType => "DOUBLE PRECISION"
- case FloatType => "REAL"
- case ShortType => "INTEGER"
- case ByteType => "BYTE"
- case BooleanType => "BIT(1)"
- case StringType => "TEXT"
- case BinaryType => "BLOB"
- case TimestampType => "TIMESTAMP"
- case DateType => "DATE"
- case DecimalType.Unlimited => "DECIMAL(40,20)"
- case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
- }
+ val typ: String =
+ dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
+ field.dataType match {
+ case IntegerType => "INTEGER"
+ case LongType => "BIGINT"
+ case DoubleType => "DOUBLE PRECISION"
+ case FloatType => "REAL"
+ case ShortType => "INTEGER"
+ case ByteType => "BYTE"
+ case BooleanType => "BIT(1)"
+ case StringType => "TEXT"
+ case BinaryType => "BLOB"
+ case TimestampType => "TIMESTAMP"
+ case DateType => "DATE"
+ case DecimalType.Unlimited => "DECIMAL(40,20)"
+ case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
+ })
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}}
@@ -162,10 +163,9 @@ package object jdbc {
url: String,
table: String,
properties: Properties = new Properties()) {
- val quirks = DriverQuirks.get(url)
+ val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
- val nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2
- if (nullType.isEmpty) {
+ dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
field.dataType match {
case IntegerType => java.sql.Types.INTEGER
case LongType => java.sql.Types.BIGINT
@@ -181,8 +181,7 @@ package object jdbc {
case DecimalType.Unlimited => java.sql.Types.DECIMAL
case _ => throw new IllegalArgumentException(
s"Can't translate null value for field $field")
- }
- } else nullType.get
+ })
}
val rddSchema = df.schema
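
To illustrate the new write-path fallback (a standalone sketch, not code from the patch): the dialect override is consulted first, and the generic case match applies only when it returns None, so with PostgresDialect a binary column gets BYTEA rather than the generic BLOB:

import org.apache.spark.sql.jdbc.PostgresDialect
import org.apache.spark.sql.types._

// Same getOrElse pattern as schemaString, reduced to three default cases.
def ddlType(dt: DataType): String =
  PostgresDialect.getJDBCType(dt).map(_.databaseTypeDefinition).getOrElse(dt match {
    case IntegerType => "INTEGER"
    case StringType => "TEXT"
    case BinaryType => "BLOB"
    case other => throw new IllegalArgumentException(s"Don't know how to save $other to JDBC")
  })

ddlType(BinaryType)  // "BYTEA" from the dialect override
ddlType(IntegerType) // "INTEGER" via the default branch (PostgresDialect returns None)
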
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 5a7b6f0aac..a8dddfb9b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -35,6 +35,13 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
+ val testH2Dialect = new JdbcDialect {
+ def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2")
+ override def getCatalystType(
+ sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
+ Some(StringType)
+ }
+
before {
Class.forName("org.h2.Driver")
// Extra properties that will be specified for our database. We need these to test
@@ -353,4 +360,46 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
""".stripMargin.replaceAll("\n", " "))
}
}
+
+ test("Remap types via JdbcDialects") {
+ JdbcDialects.registerDialect(testH2Dialect)
+ val df = TestSQLContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
+ assert(df.schema.filter(
+ _.dataType != org.apache.spark.sql.types.StringType
+ ).isEmpty)
+ val rows = df.collect()
+ assert(rows(0).get(0).isInstanceOf[String])
+ assert(rows(0).get(1).isInstanceOf[String])
+ JdbcDialects.unregisterDialect(testH2Dialect)
+ }
+
+ test("Default jdbc dialect registration") {
+ assert(JdbcDialects.get("jdbc:mysql://127.0.0.1/db") == MySQLDialect)
+ assert(JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") == PostgresDialect)
+ assert(JdbcDialects.get("test.invalid") == NoopDialect)
+ }
+
+ test("Dialect unregister") {
+ JdbcDialects.registerDialect(testH2Dialect)
+ JdbcDialects.unregisterDialect(testH2Dialect)
+ assert(JdbcDialects.get(urlWithUserAndPass) == NoopDialect)
+ }
+
+ test("Aggregated dialects") {
+ val agg = new AggregatedDialect(List(new JdbcDialect {
+ def canHandle(url: String) : Boolean = url.startsWith("jdbc:h2:")
+ override def getCatalystType(
+ sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] =
+ if (sqlType % 2 == 0) {
+ Some(LongType)
+ } else {
+ None
+ }
+ }, testH2Dialect))
+ assert(agg.canHandle("jdbc:h2:xxx"))
+ assert(!agg.canHandle("jdbc:h2"))
+ assert(agg.getCatalystType(0,"",1,null) == Some(LongType))
+ assert(agg.getCatalystType(1,"",1,null) == Some(StringType))
+ }
+
}