diff options
Diffstat (limited to 'sql/core/src/main')
3 files changed, 64 insertions, 8 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index d4d3464654..89fe86c038 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -119,6 +119,7 @@ class JDBCOptions( // E.g., "CREATE TABLE t (name string) ENGINE=InnoDB DEFAULT CHARSET=utf8" // TODO: to reuse the existing partition parameters for those partition specific options val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "") + val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES) val batchSize = { val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt require(size >= 1, @@ -154,6 +155,7 @@ object JDBCOptions { val JDBC_BATCH_FETCH_SIZE = newOption("fetchsize") val JDBC_TRUNCATE = newOption("truncate") val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions") + val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes") val JDBC_BATCH_INSERT_SIZE = newOption("batchsize") val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index 88f6cb0021..74dcfb06f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -69,7 +69,7 @@ class JdbcRelationProvider extends CreatableRelationProvider } else { // Otherwise, do not truncate the table, instead drop and recreate it dropTable(conn, options.table) - createTable(conn, df.schema, options) + createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } @@ -87,7 +87,7 @@ class JdbcRelationProvider extends CreatableRelationProvider // Therefore, it is okay to do nothing here and then just return the relation below. } } else { - createTable(conn, df.schema, options) + createTable(conn, df, options) saveTable(df, Some(df.schema), isCaseSensitive, options) } } finally { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index d89f600874..774d1ba194 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -30,7 +30,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -680,12 +681,19 @@ object JdbcUtils extends Logging { /** * Compute the schema string for this RDD. */ - def schemaString(schema: StructType, url: String): String = { + def schemaString( + df: DataFrame, + url: String, + createTableColumnTypes: Option[String] = None): String = { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) - schema.fields foreach { field => + val userSpecifiedColTypesMap = createTableColumnTypes + .map(parseUserSpecifiedCreateTableColumnTypes(df, _)) + .getOrElse(Map.empty[String, String]) + df.schema.fields.foreach { field => val name = dialect.quoteIdentifier(field.name) - val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition + val typ = userSpecifiedColTypesMap + .getOrElse(field.name, getJdbcType(field.dataType, dialect).databaseTypeDefinition) val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") } @@ -693,6 +701,51 @@ object JdbcUtils extends Logging { } /** + * Parses the user specified createTableColumnTypes option value string specified in the same + * format as create table ddl column types, and returns Map of field name and the data type to + * use in-place of the default data type. + */ + private def parseUserSpecifiedCreateTableColumnTypes( + df: DataFrame, + createTableColumnTypes: String): Map[String, String] = { + def typeName(f: StructField): String = { + // char/varchar gets translated to string type. Real data type specified by the user + // is available in the field metadata as HIVE_TYPE_STRING + if (f.metadata.contains(HIVE_TYPE_STRING)) { + f.metadata.getString(HIVE_TYPE_STRING) + } else { + f.dataType.catalogString + } + } + + val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes) + val nameEquality = df.sparkSession.sessionState.conf.resolver + + // checks duplicate columns in the user specified column types. + userSchema.fieldNames.foreach { col => + val duplicatesCols = userSchema.fieldNames.filter(nameEquality(_, col)) + if (duplicatesCols.size >= 2) { + throw new AnalysisException( + "Found duplicate column(s) in createTableColumnTypes option value: " + + duplicatesCols.mkString(", ")) + } + } + + // checks if user specified column names exist in the DataFrame schema + userSchema.fieldNames.foreach { col => + df.schema.find(f => nameEquality(f.name, col)).getOrElse { + throw new AnalysisException( + s"createTableColumnTypes option column $col not found in schema " + + df.schema.catalogString) + } + } + + val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap + val isCaseSensitive = df.sparkSession.sessionState.conf.caseSensitiveAnalysis + if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap) + } + + /** * Saves the RDD to the database in a single transaction. */ def saveTable( @@ -726,9 +779,10 @@ object JdbcUtils extends Logging { */ def createTable( conn: Connection, - schema: StructType, + df: DataFrame, options: JDBCOptions): Unit = { - val strSchema = schemaString(schema, options.url) + val strSchema = schemaString( + df, options.url, options.createTableColumnTypes) val table = options.table val createTableOptions = options.createTableOptions // Create the table if the table does not exist. |