diff options
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 17 |
1 file changed, 14 insertions, 3 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 2d0e736ee4..26788b2a4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -88,13 +88,15 @@ object JdbcUtils extends Logging { table: String, iterator: Iterator[Row], rddSchema: StructType, - nullTypes: Array[Int]): Iterator[Byte] = { + nullTypes: Array[Int], + batchSize: Int): Iterator[Byte] = { val conn = getConnection() var committed = false try { conn.setAutoCommit(false) // Everything in the same db transaction. val stmt = insertStatement(conn, table, rddSchema) try { + var rowCount = 0 while (iterator.hasNext) { val row = iterator.next() val numFields = rddSchema.fields.length @@ -122,7 +124,15 @@ object JdbcUtils extends Logging { } i = i + 1 } - stmt.executeUpdate() + stmt.addBatch() + rowCount += 1 + if (rowCount % batchSize == 0) { + stmt.executeBatch() + rowCount = 0 + } + } + if (rowCount > 0) { + stmt.executeBatch() } } finally { stmt.close() @@ -211,8 +221,9 @@ object JdbcUtils extends Logging { val rddSchema = df.schema val driver: String = DriverRegistry.getDriverClassName(url) val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + val batchSize = properties.getProperty("batchsize", "1000").toInt df.foreachPartition { iterator => - savePartition(getConnection, table, iterator, rddSchema, nullTypes) + savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize) } } |