Diffstat (limited to 'sql/core/src/main')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala          |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala            | 66
3 files changed, 64 insertions(+), 8 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index d4d3464654..89fe86c038 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -119,6 +119,7 @@ class JDBCOptions(
// E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8"
// TODO: to reuse the existing partition parameters for those partition specific options
val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "")
+ val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES)
val batchSize = {
val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt
require(size >= 1,
@@ -154,6 +155,7 @@ object JDBCOptions {
val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize")
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
+ val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
}
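
For context, a minimal usage sketch of the new option from the DataFrameWriter API. The connection URL, table name, and credentials below are illustrative assumptions, and the value of the option must use Spark SQL DDL column types:

// Hedged usage sketch: ask Spark to create the JDBC table with explicit
// column types instead of the dialect defaults. Names and URL are made up.
import org.apache.spark.sql.{SaveMode, SparkSession}

val spark = SparkSession.builder().master("local[*]").appName("jdbc-types").getOrCreate()
import spark.implicits._

val df = Seq((1L, "alice", "first user")).toDF("id", "name", "comments")

df.write
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/test")   // assumed URL
  .option("dbtable", "people")                          // assumed table name
  .option("user", "test")
  .option("password", "secret")
  // DDL-style column list; columns not listed keep the dialect's default type
  .option("createTableColumnTypes", "name VARCHAR(128), comments VARCHAR(1024)")
  .mode(SaveMode.Overwrite)
  .save()
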
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 88f6cb0021..74dcfb06f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -69,7 +69,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
dropTable(conn, options.table)
- createTable(conn, df.schema, options)
+ createTable(conn, df, options)
saveTable(df, Some(df.schema), isCaseSensitive, options)
}
@@ -87,7 +87,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
// Therefore, it is okay to do nothing here and then just return the relation below.
}
} else {
- createTable(conn, df.schema, options)
+ createTable(conn, df, options)
saveTable(df, Some(df.schema), isCaseSensitive, options)
}
} finally {
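
Worth noting from the branch above: when truncate is enabled the existing table is kept, so createTableColumnTypes only takes effect on the drop-and-recreate path (or when the table does not exist yet). A hedged sketch, reusing the hypothetical names from the earlier example:

// With truncate=true, Spark truncates and reuses the existing table, so the
// column types already defined in the database are preserved and the
// createTableColumnTypes option has no effect on this write.
df.write
  .format("jdbc")
  .option("url", "jdbc:postgresql://localhost/test")   // assumed URL
  .option("dbtable", "people")
  .option("truncate", "true")
  .mode(SaveMode.Overwrite)
  .save()
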
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index d89f600874..774d1ba194 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -30,7 +30,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -680,12 +681,19 @@ object JdbcUtils extends Logging {
/**
* Compute the schema string for this RDD.
*/
- def schemaString(schema: StructType, url: String): String = {
+ def schemaString(
+ df: DataFrame,
+ url: String,
+ createTableColumnTypes: Option[String] = None): String = {
val sb = new StringBuilder()
val dialect = JdbcDialects.get(url)
- schema.fields foreach { field =>
+ val userSpecifiedColTypesMap = createTableColumnTypes
+ .map(parseUserSpecifiedCreateTableColumnTypes(df, _))
+ .getOrElse(Map.empty[String, String])
+ df.schema.fields.foreach { field =>
val name = dialect.quoteIdentifier(field.name)
- val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
+ val typ = userSpecifiedColTypesMap
+ .getOrElse(field.name, getJdbcType(field.dataType, dialect).databaseTypeDefinition)
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}
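
To make the per-field resolution above concrete, a small REPL-style sketch; the function and parameter names are illustrative, not the actual helpers. The user-supplied map wins, otherwise the dialect's default JDBC type definition is used:

// Illustrative sketch of the lookup performed for each field in schemaString.
import org.apache.spark.sql.types.{StringType, StructField}

def resolveColumnType(
    field: StructField,
    userTypes: Map[String, String],
    dialectDefault: String): String = {
  userTypes.getOrElse(field.name, dialectDefault)
}

// e.g. map "name" to VARCHAR(128) instead of a hypothetical dialect default TEXT
resolveColumnType(
  StructField("name", StringType),
  Map("name" -> "VARCHAR(128)"),
  dialectDefault = "TEXT")   // returns "VARCHAR(128)"
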
@@ -693,6 +701,51 @@ object JdbcUtils extends Logging {
}
/**
+ * Parses the user-specified createTableColumnTypes option value, which uses the same column
+ * type syntax as CREATE TABLE DDL, and returns a Map from field name to the data type to use
+ * in place of the default data type.
+ */
+ private def parseUserSpecifiedCreateTableColumnTypes(
+ df: DataFrame,
+ createTableColumnTypes: String): Map[String, String] = {
+ def typeName(f: StructField): String = {
+ // char/varchar gets translated to string type. Real data type specified by the user
+ // is available in the field metadata as HIVE_TYPE_STRING
+ if (f.metadata.contains(HIVE_TYPE_STRING)) {
+ f.metadata.getString(HIVE_TYPE_STRING)
+ } else {
+ f.dataType.catalogString
+ }
+ }
+
+ val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes)
+ val nameEquality = df.sparkSession.sessionState.conf.resolver
+
+ // checks for duplicate columns in the user-specified column types.
+ userSchema.fieldNames.foreach { col =>
+ val duplicatesCols = userSchema.fieldNames.filter(nameEquality(_, col))
+ if (duplicatesCols.size >= 2) {
+ throw new AnalysisException(
+ "Found duplicate column(s) in createTableColumnTypes option value: " +
+ duplicatesCols.mkString(", "))
+ }
+ }
+
+ // checks that user-specified column names exist in the DataFrame schema
+ userSchema.fieldNames.foreach { col =>
+ df.schema.find(f => nameEquality(f.name, col)).getOrElse {
+ throw new AnalysisException(
+ s"createTableColumnTypes option column $col not found in schema " +
+ df.schema.catalogString)
+ }
+ }
+
+ val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap
+ val isCaseSensitive = df.sparkSession.sessionState.conf.caseSensitiveAnalysis
+ if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap)
+ }
+
+ /**
* Saves the RDD to the database in a single transaction.
*/
def saveTable(
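
For a feel of the two building blocks used above, a REPL-style sketch; both are internal Catalyst APIs, so this is illustrative rather than a supported usage:

// parseTableSchema accepts the same column list syntax as CREATE TABLE DDL.
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types.StructField

val userSchema = CatalystSqlParser.parseTableSchema("name VARCHAR(128), comments VARCHAR(1024)")

// char/varchar parse to string; the raw type survives in the field metadata,
// which is what the typeName helper above reads back.
def rawType(f: StructField): String =
  if (f.metadata.contains("HIVE_TYPE_STRING")) f.metadata.getString("HIVE_TYPE_STRING")
  else f.dataType.catalogString

// Case-insensitive lookup, matching the default (case-insensitive) analysis mode.
val userTypes = CaseInsensitiveMap(userSchema.fields.map(f => f.name -> rawType(f)).toMap)
userTypes.get("NAME")   // expected Some("varchar(128)") when metadata carries the raw type
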
@@ -726,9 +779,10 @@ object JdbcUtils extends Logging {
*/
def createTable(
conn: Connection,
- schema: StructType,
+ df: DataFrame,
options: JDBCOptions): Unit = {
- val strSchema = schemaString(schema, options.url)
+ val strSchema = schemaString(
+ df, options.url, options.createTableColumnTypes)
val table = options.table
val createTableOptions = options.createTableOptions
// Create the table if the table does not exist.
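
Putting the pieces together, a hedged end-to-end sketch that calls the (internal) helper directly to see the column list it would produce; the DataFrame, URL, and the resulting DDL shown in comments are illustrative and depend on the JdbcDialect:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils

val spark = SparkSession.builder().master("local[*]").appName("ddl-sketch").getOrCreate()
import spark.implicits._

val df = Seq((1L, "alice", "first user")).toDF("id", "name", "comments")

// Internal helper; called here only to illustrate the generated column list.
val colList = JdbcUtils.schemaString(
  df,
  url = "jdbc:postgresql://localhost/test",                                   // assumed URL
  createTableColumnTypes = Some("name VARCHAR(128), comments VARCHAR(1024)"))

// colList is roughly: "id" BIGINT NOT NULL, "name" VARCHAR(128), "comments" VARCHAR(1024)
// createTable then issues roughly:
//   CREATE TABLE people ("id" BIGINT NOT NULL, "name" VARCHAR(128), "comments" VARCHAR(1024))
// Columns not listed in the option ("id" here) keep the dialect's default mapping.
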