aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala9
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala20
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala14
4 files changed, 41 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index b2a66dd417..745bb4ec9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -255,7 +255,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
val conn = JdbcUtils.createConnection(url, props)
try {
- var tableExists = JdbcUtils.tableExists(conn, table)
+ var tableExists = JdbcUtils.tableExists(conn, url, table)
if (mode == SaveMode.Ignore && tableExists) {
return
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 26788b2a4f..f89d55b20e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -42,10 +42,13 @@ object JdbcUtils extends Logging {
/**
* Returns true if the table already exists in the JDBC database.
*/
- def tableExists(conn: Connection, table: String): Boolean = {
+ def tableExists(conn: Connection, url: String, table: String): Boolean = {
+ val dialect = JdbcDialects.get(url)
+
// Somewhat hacky, but there isn't a good way to identify whether a table exists for all
- // SQL database systems, considering "table" could also include the database name.
- Try(conn.prepareStatement(s"SELECT 1 FROM $table LIMIT 1").executeQuery().next()).isSuccess
+ // SQL database systems using JDBC meta data calls, considering "table" could also include
+ // the database name. Query used to find table exists can be overriden by the dialects.
+ Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index c6d05c9b83..68ebaaca6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -88,6 +88,17 @@ abstract class JdbcDialect {
def quoteIdentifier(colName: String): String = {
s""""$colName""""
}
+
+ /**
+ * Get the SQL query that should be used to find if the given table exists. Dialects can
+ * override this method to return a query that works best in a particular database.
+ * @param table The name of the table.
+ * @return The SQL query to use for checking the table.
+ */
+ def getTableExistsQuery(table: String): String = {
+ s"SELECT * FROM $table WHERE 1=0"
+ }
+
}
/**
@@ -198,6 +209,11 @@ case object PostgresDialect extends JdbcDialect {
case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
case _ => None
}
+
+ override def getTableExistsQuery(table: String): String = {
+ s"SELECT 1 FROM $table LIMIT 1"
+ }
+
}
/**
@@ -222,6 +238,10 @@ case object MySQLDialect extends JdbcDialect {
override def quoteIdentifier(colName: String): String = {
s"`$colName`"
}
+
+ override def getTableExistsQuery(table: String): String = {
+ s"SELECT 1 FROM $table LIMIT 1"
+ }
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index ed710689cc..5ab9381de4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -450,4 +450,18 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
assert(db2Dialect.getJDBCType(StringType).map(_.databaseTypeDefinition).get == "CLOB")
assert(db2Dialect.getJDBCType(BooleanType).map(_.databaseTypeDefinition).get == "CHAR(1)")
}
+
+ test("table exists query by jdbc dialect") {
+ val MySQL = JdbcDialects.get("jdbc:mysql://127.0.0.1/db")
+ val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db")
+ val db2 = JdbcDialects.get("jdbc:db2://127.0.0.1/db")
+ val h2 = JdbcDialects.get(url)
+ val table = "weblogs"
+ val defaultQuery = s"SELECT * FROM $table WHERE 1=0"
+ val limitQuery = s"SELECT 1 FROM $table LIMIT 1"
+ assert(MySQL.getTableExistsQuery(table) == limitQuery)
+ assert(Postgres.getTableExistsQuery(table) == limitQuery)
+ assert(db2.getTableExistsQuery(table) == defaultQuery)
+ assert(h2.getTableExistsQuery(table) == defaultQuery)
+ }
}