author     hyukjinkwon <gurwls223@gmail.com>      2016-07-26 17:14:58 +0800
committer  Wenchen Fan <wenchen@databricks.com>   2016-07-26 17:14:58 +0800
commit     3b2b785ece4394ca332377647a6305ea493f411b (patch)
tree       caf1f45dcc784d570b57b346bbbb065f0bc3b435
parent     03c27435aee4e319abe290771ba96e69469109ac (diff)
[SPARK-16675][SQL] Avoid per-record type dispatch in JDBC when writing
## What changes were proposed in this pull request?

Currently, `JdbcUtils.savePartition` performs type-based dispatch for every row in order to write each value with the appropriate setter. Instead, the appropriate setters for `PreparedStatement` can be created once, according to the schema, and then applied to each row. This approach is similar to `CatalystWriteSupport`. This PR simply builds the setters up front to avoid the per-record dispatch.

## How was this patch tested?

Existing tests should cover this.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #14323 from HyukjinKwon/SPARK-16675.
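A rough standalone sketch of the idea (`ValueSetter`, `makeSetter`, and `writeRows` below are illustrative stand-ins, not Spark's actual internals): dispatch on each column's type once to build an array of setter closures, then run only those closures in the per-row loop.

```scala
import java.sql.PreparedStatement

object SetterSketch {
  // Minimal stand-ins for Spark's DataType and Row, for illustration only.
  sealed trait DataType
  case object IntegerType extends DataType
  case object StringType  extends DataType

  type Row = IndexedSeq[Any]
  type ValueSetter = (PreparedStatement, Row, Int) => Unit

  // Type dispatch happens here, once per column, not once per value.
  // JDBC parameter indices are 1-based, hence `pos + 1`.
  def makeSetter(dt: DataType): ValueSetter = dt match {
    case IntegerType =>
      (stmt, row, pos) => stmt.setInt(pos + 1, row(pos).asInstanceOf[Int])
    case StringType =>
      (stmt, row, pos) => stmt.setString(pos + 1, row(pos).asInstanceOf[String])
  }

  def writeRows(stmt: PreparedStatement, schema: Seq[DataType], rows: Iterator[Row]): Unit = {
    // Resolve one setter per column up front ...
    val setters: Array[ValueSetter] = schema.map(makeSetter).toArray
    // ... so the hot per-row loop does no type matching at all.
    for (row <- rows) {
      var i = 0
      while (i < setters.length) {
        setters(i)(stmt, row, i)
        i += 1
      }
      stmt.addBatch()
    }
  }
}
```

The read path in `JDBCRDD` (first file below) applies the same pattern in reverse, with per-column getters copying values from a `ResultSet` into a mutable row.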
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala    22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala  102
2 files changed, 88 insertions(+), 36 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 4c98430363..e267e77c52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -322,19 +322,19 @@ private[sql] class JDBCRDD(
}
}
- // A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet`
- // into a field for `MutableRow`. The last argument `Int` means the index for the
- // value to be set in the row and also used for the value to retrieve from `ResultSet`.
- private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit
+ // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
+ // of `MutableRow`. The last argument `Int` is the index of the field to set in the row;
+ // shifted by one, it is also the column index from which the value is read in the `ResultSet`.
+ private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
/**
- * Creates `JDBCValueSetter`s according to [[StructType]], which can set
+ * Creates `JDBCValueGetter`s according to [[StructType]], which can set
* each value from `ResultSet` to each field of [[MutableRow]] correctly.
*/
- def makeSetters(schema: StructType): Array[JDBCValueSetter] =
- schema.fields.map(sf => makeSetter(sf.dataType, sf.metadata))
+ def makeGetters(schema: StructType): Array[JDBCValueGetter] =
+ schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
- private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match {
+ private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
case BooleanType =>
(rs: ResultSet, row: MutableRow, pos: Int) =>
row.setBoolean(pos, rs.getBoolean(pos + 1))
@@ -489,15 +489,15 @@ private[sql] class JDBCRDD(
stmt.setFetchSize(fetchSize)
val rs = stmt.executeQuery()
- val setters: Array[JDBCValueSetter] = makeSetters(schema)
+ val getters: Array[JDBCValueGetter] = makeGetters(schema)
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
def getNext(): InternalRow = {
if (rs.next()) {
inputMetrics.incRecordsRead(1)
var i = 0
- while (i < setters.length) {
- setters(i).apply(rs, mutableRow, i)
+ while (i < getters.length) {
+ getters(i).apply(rs, mutableRow, i)
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index cb474cbd0a..81d38e3699 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -154,6 +154,79 @@ object JdbcUtils extends Logging {
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
}
+ // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field of
+ // `PreparedStatement`. The last argument `Int` is the index of the value in the `Row`;
+ // shifted by one, it is also the parameter index at which the value is set in the SQL statement.
+ private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
+
+ private def makeSetter(
+ conn: Connection,
+ dialect: JdbcDialect,
+ dataType: DataType): JDBCValueSetter = dataType match {
+ case IntegerType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setInt(pos + 1, row.getInt(pos))
+
+ case LongType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setLong(pos + 1, row.getLong(pos))
+
+ case DoubleType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setDouble(pos + 1, row.getDouble(pos))
+
+ case FloatType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setFloat(pos + 1, row.getFloat(pos))
+
+ case ShortType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setInt(pos + 1, row.getShort(pos))
+
+ case ByteType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setInt(pos + 1, row.getByte(pos))
+
+ case BooleanType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setBoolean(pos + 1, row.getBoolean(pos))
+
+ case StringType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setString(pos + 1, row.getString(pos))
+
+ case BinaryType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
+
+ case TimestampType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
+
+ case DateType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
+
+ case t: DecimalType =>
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
+
+ case ArrayType(et, _) =>
+ // remove type length parameters from end of type name
+ val typeName = getJdbcType(et, dialect).databaseTypeDefinition
+ .toLowerCase.split("\\(")(0)
+ (stmt: PreparedStatement, row: Row, pos: Int) =>
+ val array = conn.createArrayOf(
+ typeName,
+ row.getSeq[AnyRef](pos).toArray)
+ stmt.setArray(pos + 1, array)
+
+ case _ =>
+ (_: PreparedStatement, _: Row, pos: Int) =>
+ throw new IllegalArgumentException(
+ s"Can't translate non-null value for field $pos")
+ }
+
/**
* Saves a partition of a DataFrame to the JDBC database. This is done in
* a single database transaction (unless isolation level is "NONE")
@@ -215,6 +288,9 @@ object JdbcUtils extends Logging {
conn.setTransactionIsolation(finalIsolationLevel)
}
val stmt = insertStatement(conn, table, rddSchema, dialect)
+ val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
+ .map(makeSetter(conn, dialect, _)).toArray
+
try {
var rowCount = 0
while (iterator.hasNext) {
@@ -225,30 +301,7 @@ object JdbcUtils extends Logging {
if (row.isNullAt(i)) {
stmt.setNull(i + 1, nullTypes(i))
} else {
- rddSchema.fields(i).dataType match {
- case IntegerType => stmt.setInt(i + 1, row.getInt(i))
- case LongType => stmt.setLong(i + 1, row.getLong(i))
- case DoubleType => stmt.setDouble(i + 1, row.getDouble(i))
- case FloatType => stmt.setFloat(i + 1, row.getFloat(i))
- case ShortType => stmt.setInt(i + 1, row.getShort(i))
- case ByteType => stmt.setInt(i + 1, row.getByte(i))
- case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i))
- case StringType => stmt.setString(i + 1, row.getString(i))
- case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
- case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
- case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
- case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
- case ArrayType(et, _) =>
- // remove type length parameters from end of type name
- val typeName = getJdbcType(et, dialect).databaseTypeDefinition
- .toLowerCase.split("\\(")(0)
- val array = conn.createArrayOf(
- typeName,
- row.getSeq[AnyRef](i).toArray)
- stmt.setArray(i + 1, array)
- case _ => throw new IllegalArgumentException(
- s"Can't translate non-null value for field $i")
- }
+ setters(i).apply(stmt, row, i)
}
i = i + 1
}
@@ -333,5 +386,4 @@ object JdbcUtils extends Logging {
getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
)
}
-
}