author    Dongjoon Hyun <dongjoon@apache.org>    2016-12-30 10:27:14 -0800
committer gatorsmile <gatorsmile@gmail.com>    2016-12-30 10:27:14 -0800
commit    b85e29437d570118f5980a1d6ba56c1f06a3dfd1 (patch)
tree      c0e0b305a8fe60c4a535517fbdd489e735ae7777 /sql
parent    852782b83c89c358cf429e3913b71a1a6c44f27a (diff)
[SPARK-18123][SQL] Use db column names instead of RDD column ones during JDBC Writing
## What changes were proposed in this pull request?

Apache Spark supports the following cases **by quoting RDD column names** while saving through JDBC:
- Allow a reserved keyword as a column name, e.g., 'order'.
- Allow mixed-case column names, e.g., `[a: int, A: int]`.

``` scala
scala> val df = sql("select 1 a, 1 A")
df: org.apache.spark.sql.DataFrame = [a: int, A: int]
...
scala> df.write.mode("overwrite").format("jdbc").options(option).save()
scala> df.write.mode("append").format("jdbc").options(option).save()
```

This PR aims to use **database column names** instead of RDD column names in order to additionally support the following case. Note that this case succeeds with `MySQL`, but previously failed on `Postgres`/`Oracle`.

``` scala
val df1 = sql("select 1 a")
val df2 = sql("select 1 A")
...
df1.write.mode("overwrite").format("jdbc").options(option).save()
df2.write.mode("append").format("jdbc").options(option).save()
```

## How was this patch tested?

Pass the Jenkins tests with a new test case.

Author: Dongjoon Hyun <dongjoon@apache.org>
Author: gatorsmile <gatorsmile@gmail.com>

Closes #15664 from dongjoon-hyun/SPARK-18123.
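For illustration only (not part of the patch): a minimal, self-contained sketch of the column-resolution rule that `getInsertStatement` implements in the diff below. Each RDD column is matched against the target table's schema with either case-sensitive or case-insensitive equality, and the table's spelling wins. The object and helper names here are hypothetical; the real code uses Spark's `caseSensitiveResolution`/`caseInsensitiveResolution` and throws `AnalysisException`.

``` scala
// Illustrative sketch only; names are not Spark's actual API.
object ColumnResolutionSketch {
  type Resolver = (String, String) => Boolean

  val caseSensitive: Resolver = (a, b) => a == b
  val caseInsensitive: Resolver = (a, b) => a.equalsIgnoreCase(b)

  /** Map each RDD column name to the matching column name in the target table. */
  def resolveColumns(
      rddColumns: Seq[String],
      tableColumns: Seq[String],
      isCaseSensitive: Boolean): Seq[String] = {
    val eq = if (isCaseSensitive) caseSensitive else caseInsensitive
    rddColumns.map { c =>
      // Use the table's own spelling; the real patch throws AnalysisException here.
      tableColumns.find(t => eq(t, c)).getOrElse(
        throw new IllegalArgumentException(s"""Column "$c" not found in table schema"""))
    }
  }

  def main(args: Array[String]): Unit = {
    // Case-insensitive: RDD column "A" resolves to existing table column "a".
    println(resolveColumns(Seq("A"), Seq("a"), isCaseSensitive = false)) // List(a)
    // Case-sensitive: the same lookup fails, mirroring the error path
    // exercised by the new test in JDBCWriteSuite.
    // resolveColumns(Seq("A"), Seq("a"), isCaseSensitive = true) // would throw
  }
}
```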
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 74
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 35
3 files changed, 95 insertions, 25 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index 74f397c01e..e39d936f39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -57,6 +57,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
val table = jdbcOptions.table
val createTableOptions = jdbcOptions.createTableOptions
val isTruncate = jdbcOptions.isTruncate
+ val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
try {
@@ -67,16 +68,18 @@ class JdbcRelationProvider extends CreatableRelationProvider
if (isTruncate && isCascadingTruncateTable(url) == Some(false)) {
// In this case, we should truncate table and then load.
truncateTable(conn, table)
- saveTable(df, url, table, jdbcOptions)
+ val tableSchema = JdbcUtils.getSchemaOption(conn, url, table)
+ saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions)
} else {
// Otherwise, do not truncate the table, instead drop and recreate it
dropTable(conn, table)
createTable(df.schema, url, table, createTableOptions, conn)
- saveTable(df, url, table, jdbcOptions)
+ saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions)
}
case SaveMode.Append =>
- saveTable(df, url, table, jdbcOptions)
+ val tableSchema = JdbcUtils.getSchemaOption(conn, url, table)
+ saveTable(df, url, table, tableSchema, isCaseSensitive, jdbcOptions)
case SaveMode.ErrorIfExists =>
throw new AnalysisException(
@@ -89,7 +92,7 @@ class JdbcRelationProvider extends CreatableRelationProvider
}
} else {
createTable(df.schema, url, table, createTableOptions, conn)
- saveTable(df, url, table, jdbcOptions)
+ saveTable(df, url, table, Some(df.schema), isCaseSensitive, jdbcOptions)
}
} finally {
conn.close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index ff29a15960..b138494758 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
import org.apache.spark.TaskContext
import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
@@ -108,14 +108,36 @@ object JdbcUtils extends Logging {
}
/**
- * Returns a PreparedStatement that inserts a row into table via conn.
+ * Returns an Insert SQL statement for inserting a row into the target table via JDBC conn.
*/
- def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
- : PreparedStatement = {
- val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
+ def getInsertStatement(
+ table: String,
+ rddSchema: StructType,
+ tableSchema: Option[StructType],
+ isCaseSensitive: Boolean,
+ dialect: JdbcDialect): String = {
+ val columns = if (tableSchema.isEmpty) {
+ rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
+ } else {
+ val columnNameEquality = if (isCaseSensitive) {
+ org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+ } else {
+ org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+ }
+ // The generated insert statement needs to follow rddSchema's column sequence and
+ // tableSchema's column names. When appending data into some case-sensitive DBMSs like
+ // PostgreSQL/Oracle, we need to respect the existing case-sensitive column names instead of
+ // RDD column names for user convenience.
+ val tableColumnNames = tableSchema.get.fieldNames
+ rddSchema.fields.map { col =>
+ val normalizedName = tableColumnNames.find(f => columnNameEquality(f, col.name)).getOrElse {
+ throw new AnalysisException(s"""Column "${col.name}" not found in schema $tableSchema""")
+ }
+ dialect.quoteIdentifier(normalizedName)
+ }.mkString(",")
+ }
val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
- val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
- conn.prepareStatement(sql)
+ s"INSERT INTO $table ($columns) VALUES ($placeholders)"
}
/**
@@ -211,6 +233,26 @@ object JdbcUtils extends Logging {
}
/**
+ * Returns the schema if the table already exists in the JDBC database.
+ */
+ def getSchemaOption(conn: Connection, url: String, table: String): Option[StructType] = {
+ val dialect = JdbcDialects.get(url)
+
+ try {
+ val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
+ try {
+ Some(getSchema(statement.executeQuery(), dialect))
+ } catch {
+ case _: SQLException => None
+ } finally {
+ statement.close()
+ }
+ } catch {
+ case _: SQLException => None
+ }
+ }
+
+ /**
* Takes a [[ResultSet]] and returns its Catalyst schema.
*
* @return A [[StructType]] giving the Catalyst schema.
@@ -531,7 +573,7 @@ object JdbcUtils extends Logging {
table: String,
iterator: Iterator[Row],
rddSchema: StructType,
- nullTypes: Array[Int],
+ insertStmt: String,
batchSize: Int,
dialect: JdbcDialect,
isolationLevel: Int): Iterator[Byte] = {
@@ -568,9 +610,9 @@ object JdbcUtils extends Logging {
conn.setAutoCommit(false) // Everything in the same db transaction.
conn.setTransactionIsolation(finalIsolationLevel)
}
- val stmt = insertStatement(conn, table, rddSchema, dialect)
- val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
- .map(makeSetter(conn, dialect, _)).toArray
+ val stmt = conn.prepareStatement(insertStmt)
+ val setters = rddSchema.fields.map(f => makeSetter(conn, dialect, f.dataType))
+ val nullTypes = rddSchema.fields.map(f => getJdbcType(f.dataType, dialect).jdbcNullType)
val numFields = rddSchema.fields.length
try {
@@ -657,16 +699,16 @@ object JdbcUtils extends Logging {
df: DataFrame,
url: String,
table: String,
+ tableSchema: Option[StructType],
+ isCaseSensitive: Boolean,
options: JDBCOptions): Unit = {
val dialect = JdbcDialects.get(url)
- val nullTypes: Array[Int] = df.schema.fields.map { field =>
- getJdbcType(field.dataType, dialect).jdbcNullType
- }
-
val rddSchema = df.schema
val getConnection: () => Connection = createConnectionFactory(options)
val batchSize = options.batchSize
val isolationLevel = options.isolationLevel
+
+ val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)
val repartitionedDF = options.numPartitions match {
case Some(n) if n <= 0 => throw new IllegalArgumentException(
s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
@@ -675,7 +717,7 @@ object JdbcUtils extends Logging {
case _ => df
}
repartitionedDF.foreachPartition(iterator => savePartition(
- getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
+ getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index f49ac23149..354af29d42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -24,9 +24,9 @@ import scala.collection.JavaConverters.propertiesAsScalaMapConverter
import org.scalatest.BeforeAndAfter
-import org.apache.spark.SparkException
-import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.{AnalysisException, Row, SaveMode}
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -96,6 +96,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
StructField("id", IntegerType) ::
StructField("seq", IntegerType) :: Nil)
+ private lazy val schema4 = StructType(
+ StructField("NAME", StringType) ::
+ StructField("ID", IntegerType) :: Nil)
+
test("Basic CREATE") {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
@@ -165,6 +169,26 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
}
+ test("SPARK-18123 Append with column names with different cases") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+ val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema4)
+
+ df.write.jdbc(url, "TEST.APPENDTEST", new Properties())
+
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ val m = intercept[AnalysisException] {
+ df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
+ }.getMessage
+ assert(m.contains("Column \"NAME\" not found"))
+ }
+
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
+ assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).count())
+ assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
+ }
+ }
+
test("Truncate") {
JdbcDialects.registerDialect(testH2Dialect)
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
@@ -177,7 +201,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
- val m = intercept[SparkException] {
+ val m = intercept[AnalysisException] {
df3.write.mode(SaveMode.Overwrite).option("truncate", true)
.jdbc(url1, "TEST.TRUNCATETEST", properties)
}.getMessage
@@ -203,9 +227,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
- intercept[org.apache.spark.SparkException] {
+ val m = intercept[AnalysisException] {
df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
- }
+ }.getMessage
+ assert(m.contains("Column \"seq\" not found"))
}
test("INSERT to JDBC Datasource") {