diff options
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 9 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 6 |
2 files changed, 11 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 3529ee6e3b..d3e1efc562 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -100,8 +100,9 @@ object JdbcUtils extends Logging { /** * Returns a PreparedStatement that inserts a row into table via conn. */ - def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { - val columns = rddSchema.fields.map(_.name).mkString(",") + def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect) + : PreparedStatement = { + val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",") val placeholders = rddSchema.fields.map(_ => "?").mkString(",") val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)" conn.prepareStatement(sql) @@ -177,7 +178,7 @@ object JdbcUtils extends Logging { if (supportsTransactions) { conn.setAutoCommit(false) // Everything in the same db transaction. } - val stmt = insertStatement(conn, table, rddSchema) + val stmt = insertStatement(conn, table, rddSchema, dialect) try { var rowCount = 0 while (iterator.hasNext) { @@ -260,7 +261,7 @@ object JdbcUtils extends Logging { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) df.schema.fields foreach { field => - val name = field.name + val name = dialect.quoteIdentifier(field.name) val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 11e66ad080..228e4250f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -764,4 +764,10 @@ class JDBCSuite extends SparkFunSuite assertEmptyQuery(s"SELECT * FROM tempFrame where $FALSE2") } } + + test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") { + val df = spark.createDataset(Seq("a", "b", "c")).toDF("order") + val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp") + assert(schema.contains("`order` TEXT")) + } } |