Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala          |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala            |  66
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala                                  |   2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala                             | 150
5 files changed, 212 insertions(+), 12 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index d4d3464654..89fe86c038 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -119,6 +119,7 @@ class JDBCOptions(
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
// TODO: to reuse the existing partition parameters for those partition specific options
val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "")
+ val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES)
val batchSize = {
val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt
require(size >= 1,
@@ -154,6 +155,7 @@ object JDBCOptions {
val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
+ val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
}
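
From the user's side, the new createTableColumnTypes option is passed like any other JDBC writer option and only takes effect when Spark creates (or drops and recreates) the target table. A minimal usage sketch, assuming a placeholder H2 url, table name, and connectionProperties object that are not part of this patch:

    // Override the default JDBC mapping for two columns; columns not listed in
    // createTableColumnTypes keep the dialect's default database type.
    df.write
      .mode(SaveMode.Overwrite)
      .option("createTableColumnTypes", "name VARCHAR(128), age SMALLINT")
      .jdbc("jdbc:h2:mem:testdb", "TEST.PEOPLE", connectionProperties)
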
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 88f6cb0021..74dcfb06f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -69,7 +69,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
dropTable(conn, options.table)
- createTable(conn, df.schema, options)
+ createTable(conn, df, options)
saveTable(df, Some(df.schema), isCaseSensitive, options)
}
@@ -87,7 +87,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
// Therefore, it is okay to do nothing here and then just return the relation below.
}
} else {
- createTable(conn, df.schema, options)
+ createTable(conn, df, options)
saveTable(df, Some(df.schema), isCaseSensitive, options)
}
} finally {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index d89f600874..774d1ba194 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -30,7 +30,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -680,12 +681,19 @@ object JdbcUtils extends Logging {
/**
* Compute the schema string for this RDD.
*/
- def schemaString(schema: StructType, url: String): String = {
+ def schemaString(
+ df: DataFrame,
+ url: String,
+ createTableColumnTypes: Option[String] = None): String = {
val sb = new StringBuilder()
val dialect = JdbcDialects.get(url)
- schema.fields foreach { field =>
+ val userSpecifiedColTypesMap = createTableColumnTypes
+ .map(parseUserSpecifiedCreateTableColumnTypes(df, _))
+ .getOrElse(Map.empty[String, String])
+ df.schema.fields.foreach { field =>
val name = dialect.quoteIdentifier(field.name)
- val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
+ val typ = userSpecifiedColTypesMap
+ .getOrElse(field.name, getJdbcType(field.dataType, dialect).databaseTypeDefinition)
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}
@@ -693,6 +701,51 @@ object JdbcUtils extends Logging {
}
/**
+ * Parses the user-specified createTableColumnTypes option value, which uses the same format
+ * as CREATE TABLE DDL column types, and returns a map from field name to the data type to
+ * use in place of the default data type.
+ */
+ private def parseUserSpecifiedCreateTableColumnTypes(
+ df: DataFrame,
+ createTableColumnTypes: String): Map[String, String] = {
+ def typeName(f: StructField): String = {
+ // char/varchar gets translated to string type. Real data type specified by the user
+ // is available in the field metadata as HIVE_TYPE_STRING
+ if (f.metadata.contains(HIVE_TYPE_STRING)) {
+ f.metadata.getString(HIVE_TYPE_STRING)
+ } else {
+ f.dataType.catalogString
+ }
+ }
+
+ val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes)
+ val nameEquality = df.sparkSession.sessionState.conf.resolver
+
+ // checks duplicate columns in the user specified column types.
+ userSchema.fieldNames.foreach { col =>
+ val duplicatesCols = userSchema.fieldNames.filter(nameEquality(_, col))
+ if (duplicatesCols.size >= 2) {
+ throw new AnalysisException(
+ "Found duplicate column(s) in createTableColumnTypes option value: " +
+ duplicatesCols.mkString(", "))
+ }
+ }
+
+ // checks if user specified column names exist in the DataFrame schema
+ userSchema.fieldNames.foreach { col =>
+ df.schema.find(f => nameEquality(f.name, col)).getOrElse {
+ throw new AnalysisException(
+ s"createTableColumnTypes option column $col not found in schema " +
+ df.schema.catalogString)
+ }
+ }
+
+ val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap
+ val isCaseSensitive = df.sparkSession.sessionState.conf.caseSensitiveAnalysis
+ if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap)
+ }
+
+ /**
* Saves the RDD to the database in a single transaction.
*/
def saveTable(
@@ -726,9 +779,10 @@ object JdbcUtils extends Logging {
*/
def createTable(
conn: Connection,
- schema: StructType,
+ df: DataFrame,
options: JDBCOptions): Unit = {
- val strSchema = schemaString(schema, options.url)
+ val strSchema = schemaString(
+ df, options.url, options.createTableColumnTypes)
val table = options.table
val createTableOptions = options.createTableOptions
// Create the table if the table does not exist.
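
The option value is parsed with Catalyst's DDL column-list parser and matched against the DataFrame schema using the session's resolver, as the hunk above shows. A rough sketch of the building blocks it relies on (the literal column list below is illustrative only):

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
    import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

    // "id BIGINT, price DECIMAL(10,2)" parses into a two-field StructType.
    val userSchema = CatalystSqlParser.parseTableSchema("id BIGINT, price DECIMAL(10,2)")

    // Field name -> type string to splice into the CREATE TABLE statement.
    // (CHAR/VARCHAR both parse to StringType, which is why the patch reads the
    // original type text back from the field metadata under HIVE_TYPE_STRING.)
    val colTypes = CaseInsensitiveMap(
      userSchema.fields.map(f => f.name -> f.dataType.catalogString).toMap)

    colTypes("ID")  // "bigint": lookups ignore case, mirroring the non-case-sensitive branch
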
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 5463728ca0..4a02277631 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -869,7 +869,7 @@ class JDBCSuite extends SparkFunSuite
test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") {
val df = spark.createDataset(Seq("a", "b", "c")).toDF("order")
- val schema = JdbcUtils.schemaString(df.schema, "jdbc:mysql://localhost:3306/temp")
+ val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp")
assert(schema.contains("`order` TEXT"))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index ec7b19e666..bf1fd16070 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -17,15 +17,16 @@
package org.apache.spark.sql.jdbc
-import java.sql.DriverManager
+import java.sql.{Date, DriverManager, Timestamp}
import java.util.Properties
import scala.collection.JavaConverters.propertiesAsScalaMapConverter
import org.scalatest.BeforeAndAfter
-import org.apache.spark.sql.{AnalysisException, Row, SaveMode}
-import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -362,4 +363,147 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(sql("select * from people_view").count() == 2)
}
}
+
+ test("SPARK-10849: test schemaString - from createTableColumnTypes option values") {
+ def testCreateTableColDataTypes(types: Seq[String]): Unit = {
+ val colTypes = types.zipWithIndex.map { case (t, i) => (s"col$i", t) }
+ val schema = colTypes
+ .foldLeft(new StructType())((schema, colType) => schema.add(colType._1, colType._2))
+ val createTableColTypes =
+ colTypes.map { case (col, dataType) => s"$col $dataType" }.mkString(", ")
+ val df = spark.createDataFrame(sparkContext.parallelize(Seq(Row.empty)), schema)
+
+ val expectedSchemaStr =
+ colTypes.map { case (col, dataType) => s""""$col" $dataType """ }.mkString(", ")
+
+ assert(JdbcUtils.schemaString(df, url1, Option(createTableColTypes)) == expectedSchemaStr)
+ }
+
+ testCreateTableColDataTypes(Seq("boolean"))
+ testCreateTableColDataTypes(Seq("tinyint", "smallint", "int", "bigint"))
+ testCreateTableColDataTypes(Seq("float", "double"))
+ testCreateTableColDataTypes(Seq("string", "char(10)", "varchar(20)"))
+ testCreateTableColDataTypes(Seq("decimal(10,0)", "decimal(10,5)"))
+ testCreateTableColDataTypes(Seq("date", "timestamp"))
+ testCreateTableColDataTypes(Seq("binary"))
+ }
+
+ test("SPARK-10849: create table using user specified column type and verify on target table") {
+ def testUserSpecifiedColTypes(
+ df: DataFrame,
+ createTableColTypes: String,
+ expectedTypes: Map[String, String]): Unit = {
+ df.write
+ .mode(SaveMode.Overwrite)
+ .option("createTableColumnTypes", createTableColTypes)
+ .jdbc(url1, "TEST.DBCOLTYPETEST", properties)
+
+ // verify the data types of the created table by reading the database catalog of H2
+ val query =
+ """
+ |(SELECT column_name, type_name, character_maximum_length
+ | FROM information_schema.columns WHERE table_name = 'DBCOLTYPETEST')
+ """.stripMargin
+ val rows = spark.read.jdbc(url1, query, properties).collect()
+
+ rows.foreach { row =>
+ val typeName = row.getString(1)
+ // For CHAR and VARCHAR, we also compare the max length
+ if (typeName.contains("CHAR")) {
+ val charMaxLength = row.getInt(2)
+ assert(expectedTypes(row.getString(0)) == s"$typeName($charMaxLength)")
+ } else {
+ assert(expectedTypes(row.getString(0)) == typeName)
+ }
+ }
+ }
+
+ val data = Seq[Row](Row(1, "dave", "Boston"))
+ val schema = StructType(
+ StructField("id", IntegerType) ::
+ StructField("first#name", StringType) ::
+ StructField("city", StringType) :: Nil)
+ val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
+
+ // out-of-order
+ val expected1 = Map("id" -> "BIGINT", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)")
+ testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), id BIGINT, city CHAR(20)", expected1)
+ // partial schema
+ val expected2 = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)")
+ testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), city CHAR(20)", expected2)
+
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ // should still respect the original column names
+ val expected = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CLOB")
+ testUserSpecifiedColTypes(df, "`FiRsT#NaMe` VARCHAR(123)", expected)
+ }
+
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ val schema = StructType(
+ StructField("id", IntegerType) ::
+ StructField("First#Name", StringType) ::
+ StructField("city", StringType) :: Nil)
+ val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
+ val expected = Map("id" -> "INTEGER", "First#Name" -> "VARCHAR(123)", "city" -> "CLOB")
+ testUserSpecifiedColTypes(df, "`First#Name` VARCHAR(123)", expected)
+ }
+ }
+
+ test("SPARK-10849: jdbc CreateTableColumnTypes option with invalid data type") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val msg = intercept[ParseException] {
+ df.write.mode(SaveMode.Overwrite)
+ .option("createTableColumnTypes", "name CLOB(2000)")
+ .jdbc(url1, "TEST.USERDBTYPETEST", properties)
+ }.getMessage()
+ assert(msg.contains("DataType clob(2000) is not supported."))
+ }
+
+ test("SPARK-10849: jdbc CreateTableColumnTypes option with invalid syntax") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val msg = intercept[ParseException] {
+ df.write.mode(SaveMode.Overwrite)
+ .option("createTableColumnTypes", "`name char(20)") // incorrectly quoted column
+ .jdbc(url1, "TEST.USERDBTYPETEST", properties)
+ }.getMessage()
+ assert(msg.contains("no viable alternative at input"))
+ }
+
+ test("SPARK-10849: jdbc CreateTableColumnTypes duplicate columns") {
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val msg = intercept[AnalysisException] {
+ df.write.mode(SaveMode.Overwrite)
+ .option("createTableColumnTypes", "name CHAR(20), id int, NaMe VARCHAR(100)")
+ .jdbc(url1, "TEST.USERDBTYPETEST", properties)
+ }.getMessage()
+ assert(msg.contains(
+ "Found duplicate column(s) in createTableColumnTypes option value: name, NaMe"))
+ }
+ }
+
+ test("SPARK-10849: jdbc CreateTableColumnTypes invalid columns") {
+ // schema2 has the column "id" and "name"
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ val msg = intercept[AnalysisException] {
+ df.write.mode(SaveMode.Overwrite)
+ .option("createTableColumnTypes", "firstName CHAR(20), id int")
+ .jdbc(url1, "TEST.USERDBTYPETEST", properties)
+ }.getMessage()
+ assert(msg.contains("createTableColumnTypes option column firstName not found in " +
+ "schema struct<name:string,id:int>"))
+ }
+
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ val msg = intercept[AnalysisException] {
+ df.write.mode(SaveMode.Overwrite)
+ .option("createTableColumnTypes", "id int, Name VARCHAR(100)")
+ .jdbc(url1, "TEST.USERDBTYPETEST", properties)
+ }.getMessage()
+ assert(msg.contains("createTableColumnTypes option column Name not found in " +
+ "schema struct<name:string,id:int>"))
+ }
+ }
}
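
The write-then-verify pattern used by the suite also works for spot-checking what the option actually produced in the database. A sketch against the same H2 catalog query, with a placeholder url, table, and properties:

    // Read H2's catalog to confirm the column types created for TEST.PEOPLE,
    // including CHAR/VARCHAR maximum lengths.
    val catalogQuery =
      """(SELECT column_name, type_name, character_maximum_length
        |   FROM information_schema.columns WHERE table_name = 'PEOPLE')""".stripMargin
    spark.read.jdbc("jdbc:h2:mem:testdb", catalogQuery, connectionProperties).show()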