diff options
3 files changed, 39 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 89c850ce23..f9b72597dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -224,6 +224,7 @@ private[sql] object JDBCRDD extends Logging { quotedColumns, filters, parts, + url, properties) } } @@ -241,6 +242,7 @@ private[sql] class JDBCRDD( columns: Array[String], filters: Array[Filter], partitions: Array[Partition], + url: String, properties: Properties) extends RDD[InternalRow](sc, Nil) { @@ -361,6 +363,9 @@ private[sql] class JDBCRDD( context.addTaskCompletionListener{ context => close() } val part = thePart.asInstanceOf[JDBCPartition] val conn = getConnection() + val dialect = JdbcDialects.get(url) + import scala.collection.JavaConverters._ + dialect.beforeFetch(conn, properties.asScala.toMap) // H2's JDBC driver does not support the setSchema() method. We pass a // fully-qualified table name in the SELECT statement. I don't know how to @@ -489,6 +494,13 @@ private[sql] class JDBCRDD( } try { if (null != conn) { + if (!conn.getAutoCommit && !conn.isClosed) { + try { + conn.commit() + } catch { + case e: Throwable => logWarning("Exception committing transaction", e) + } + } conn.close() } logInfo("closed connection") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index b3b2cb6178..13db141f27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.jdbc +import java.sql.Connection + import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi @@ -97,6 +99,15 @@ abstract class JdbcDialect extends Serializable { s"SELECT * FROM $table WHERE 1=0" } + /** + * Override connection specific properties to run before a select is made. This is in place to + * allow dialects that need special treatment to optimize behavior. + * @param connection The connection object + * @param properties The connection properties. This is passed through from the relation. + */ + def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + } + } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index ed3faa1268..3cf80f576e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{Connection, Types} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.types._ @@ -70,4 +70,19 @@ private object PostgresDialect extends JdbcDialect { override def getTableExistsQuery(table: String): String = { s"SELECT 1 FROM $table LIMIT 1" } + + override def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { + super.beforeFetch(connection, properties) + + // According to the postgres jdbc documentation we need to be in autocommit=false if we actually + // want to have fetchsize be non 0 (all the rows). This allows us to not have to cache all the + // rows inside the driver when fetching. + // + // See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor + // + if (properties.getOrElse("fetchsize", "0").toInt > 0) { + connection.setAutoCommit(false) + } + + } } |