-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala  17
1 file changed, 14 insertions(+), 3 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 2d0e736ee4..26788b2a4f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -88,13 +88,15 @@ object JdbcUtils extends Logging {
       table: String,
       iterator: Iterator[Row],
       rddSchema: StructType,
-      nullTypes: Array[Int]): Iterator[Byte] = {
+      nullTypes: Array[Int],
+      batchSize: Int): Iterator[Byte] = {
     val conn = getConnection()
     var committed = false
     try {
       conn.setAutoCommit(false) // Everything in the same db transaction.
       val stmt = insertStatement(conn, table, rddSchema)
       try {
+        var rowCount = 0
         while (iterator.hasNext) {
           val row = iterator.next()
           val numFields = rddSchema.fields.length
@@ -122,7 +124,15 @@
             }
             i = i + 1
           }
-          stmt.executeUpdate()
+          stmt.addBatch()
+          rowCount += 1
+          if (rowCount % batchSize == 0) {
+            stmt.executeBatch()
+            rowCount = 0
+          }
+        }
+        if (rowCount > 0) {
+          stmt.executeBatch()
         }
       } finally {
         stmt.close()
@@ -211,8 +221,9 @@ object JdbcUtils extends Logging {
     val rddSchema = df.schema
     val driver: String = DriverRegistry.getDriverClassName(url)
     val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
+    val batchSize = properties.getProperty("batchsize", "1000").toInt
     df.foreachPartition { iterator =>
-      savePartition(getConnection, table, iterator, rddSchema, nullTypes)
+      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
     }
   }
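
For context, a minimal usage sketch (not part of the patch; the URL, table name, and credentials are placeholders): with this change the JDBC writer reads an optional "batchsize" key from the connection properties, falling back to 1000 rows per executeBatch() call, so a caller going through the standard DataFrameWriter.jdbc path that ends in JdbcUtils.saveTable could tune it like this:

import java.util.Properties

import org.apache.spark.sql.DataFrame

def writeWithBatching(df: DataFrame): Unit = {
  val props = new Properties()
  props.setProperty("user", "spark")        // placeholder credentials
  props.setProperty("password", "secret")
  // Buffer 5000 rows per executeBatch() call; omit the key to use the default of 1000.
  props.setProperty("batchsize", "5000")

  // DataFrameWriter.jdbc forwards these properties to JdbcUtils.saveTable,
  // which now parses "batchsize" and passes it on to savePartition.
  df.write.jdbc("jdbc:postgresql://dbhost:5432/testdb", "target_table", props)
}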