author     hyukjinkwon <gurwls223@gmail.com>      2016-07-25 19:57:47 +0800
committer  Wenchen Fan <wenchen@databricks.com>   2016-07-25 19:57:47 +0800
commit     7ffd99ec5f267730734431097cbb700ad074bebe (patch)
tree       be4f281ee7b1af42e60245deb0fd72c968ab1309 /sql
parent     68b4020d0c0d4f063facfbf4639ef4251dcfda8b (diff)
download   spark-7ffd99ec5f267730734431097cbb700ad074bebe.tar.gz
           spark-7ffd99ec5f267730734431097cbb700ad074bebe.tar.bz2
           spark-7ffd99ec5f267730734431097cbb700ad074bebe.zip
[SPARK-16674][SQL] Avoid per-record type dispatch in JDBC when reading
## What changes were proposed in this pull request?

Currently, `JDBCRDD.compute` does type dispatch for each row to read the appropriate values. This does not have to happen per record because the schema is already kept in `JDBCRDD`: the appropriate converters can be created once from the schema and then applied to each row.

## How was this patch tested?

Existing tests should cover this.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #14313 from HyukjinKwon/SPARK-16674.
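To make the pattern concrete, here is a minimal, self-contained Scala sketch of the idea (illustrative only; the names `Converter`, `makeConverter`, and `readRows` are not from the patch): type dispatch happens once per column while building the converters, and the per-record loop only invokes the prebuilt functions.

```scala
import java.sql.ResultSet

object ConverterSketch {
  // One converter per column; `pos` is the 1-based JDBC column index.
  type Converter = (ResultSet, Int) => Any

  // Hypothetical mini-schema of type names, to keep the sketch self-contained.
  def makeConverter(typeName: String): Converter = typeName match {
    case "int"    => (rs, pos) => rs.getInt(pos)
    case "long"   => (rs, pos) => rs.getLong(pos)
    case "string" => (rs, pos) => rs.getString(pos)
    case other    => throw new IllegalArgumentException(s"Unsupported type $other")
  }

  def readRows(rs: ResultSet, typeNames: Seq[String]): Seq[Array[Any]] = {
    // Dispatch on the schema once, up front ...
    val converters = typeNames.map(makeConverter).toArray
    val rows = scala.collection.mutable.ArrayBuffer.empty[Array[Any]]
    while (rs.next()) {
      // ... so the per-record loop only calls the precomputed converters.
      rows += Array.tabulate(converters.length)(i => converters(i)(rs, i + 1))
    }
    rows
  }
}
```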
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala  245
1 file changed, 129 insertions(+), 116 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 24e2c1a5fd..4c98430363 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
@@ -322,43 +322,134 @@ private[sql] class JDBCRDD(
}
}
- // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that
- // we don't have to potentially poke around in the Metadata once for every
- // row.
- // Is there a better way to do this? I'd rather be using a type that
- // contains only the tags I define.
- abstract class JDBCConversion
- case object BooleanConversion extends JDBCConversion
- case object DateConversion extends JDBCConversion
- case class DecimalConversion(precision: Int, scale: Int) extends JDBCConversion
- case object DoubleConversion extends JDBCConversion
- case object FloatConversion extends JDBCConversion
- case object IntegerConversion extends JDBCConversion
- case object LongConversion extends JDBCConversion
- case object BinaryLongConversion extends JDBCConversion
- case object StringConversion extends JDBCConversion
- case object TimestampConversion extends JDBCConversion
- case object BinaryConversion extends JDBCConversion
- case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
+ // A `JDBCValueSetter` converts a value from a `ResultSet` and sets it into a field of a
+ // `MutableRow`. The last `Int` argument is the index of the field to set in the row; the
+ // corresponding `ResultSet` column is read at that index plus one.
+ private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit
/**
- * Maps a StructType to a type tag list.
+ * Creates a `JDBCValueSetter` for each field of the given [[StructType]]; each setter
+ * copies the corresponding value from a `ResultSet` into a field of a [[MutableRow]].
*/
- def getConversions(schema: StructType): Array[JDBCConversion] =
- schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))
-
- private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
- case BooleanType => BooleanConversion
- case DateType => DateConversion
- case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
- case DoubleType => DoubleConversion
- case FloatType => FloatConversion
- case IntegerType => IntegerConversion
- case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
- case StringType => StringConversion
- case TimestampType => TimestampConversion
- case BinaryType => BinaryConversion
- case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
+ def makeSetters(schema: StructType): Array[JDBCValueSetter] =
+ schema.fields.map(sf => makeSetter(sf.dataType, sf.metadata))
+
+ private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match {
+ case BooleanType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setBoolean(pos, rs.getBoolean(pos + 1))
+
+ case DateType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
+ val dateVal = rs.getDate(pos + 1)
+ if (dateVal != null) {
+ row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
+ } else {
+ row.update(pos, null)
+ }
+
+ // When connecting to Oracle through JDBC, the precision and scale of the BigDecimal
+ // object returned by ResultSet.getBigDecimal do not always match the table schema
+ // reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
+ // If you insert a value such as 19999 into a column of type NUMBER(12, 2), you get back
+ // a BigDecimal object with scale 0, while the DataFrame schema has the correct type
+ // DecimalType(12, 2). Thus, after saving the DataFrame to a Parquet file and reading it
+ // back, you will get the wrong result 199.99.
+ // So we need to set the precision and scale of the Decimal based on the JDBC metadata.
+ case DecimalType.Fixed(p, s) =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val decimal =
+ nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
+ row.update(pos, decimal)
+
+ case DoubleType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setDouble(pos, rs.getDouble(pos + 1))
+
+ case FloatType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setFloat(pos, rs.getFloat(pos + 1))
+
+ case IntegerType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setInt(pos, rs.getInt(pos + 1))
+
+ case LongType if metadata.contains("binarylong") =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val bytes = rs.getBytes(pos + 1)
+ var ans = 0L
+ var j = 0
+ while (j < bytes.size) {
+ ans = 256 * ans + (255 & bytes(j))
+ j = j + 1
+ }
+ row.setLong(pos, ans)
+
+ case LongType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setLong(pos, rs.getLong(pos + 1))
+
+ case StringType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
+ row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
+
+ case TimestampType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val t = rs.getTimestamp(pos + 1)
+ if (t != null) {
+ row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
+ } else {
+ row.update(pos, null)
+ }
+
+ case BinaryType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.update(pos, rs.getBytes(pos + 1))
+
+ case ArrayType(et, _) =>
+ val elementConversion = et match {
+ case TimestampType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+ nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+ }
+
+ case StringType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.lang.String]]
+ .map(UTF8String.fromString)
+
+ case DateType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.sql.Date]].map { date =>
+ nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+ }
+
+ case dt: DecimalType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+ nullSafeConvert[java.math.BigDecimal](
+ decimal, d => Decimal(d, dt.precision, dt.scale))
+ }
+
+ case LongType if metadata.contains("binarylong") =>
+ throw new IllegalArgumentException(s"Unsupported array element " +
+ s"type ${dt.simpleString} based on binary")
+
+ case ArrayType(_, _) =>
+ throw new IllegalArgumentException("Nested arrays unsupported")
+
+ case _ => (array: Object) => array.asInstanceOf[Array[Any]]
+ }
+
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val array = nullSafeConvert[Object](
+ rs.getArray(pos + 1).getArray,
+ array => new GenericArrayData(elementConversion.apply(array)))
+ row.update(pos, array)
+
case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}
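As an aside, the `binarylong` branch above packs a big-endian byte array into a signed long. A standalone sketch of that arithmetic (the name `bytesToLong` is mine, not from the patch):

```scala
// Illustrative helper, equivalent to the loop in the "binarylong" setter above.
def bytesToLong(bytes: Array[Byte]): Long = {
  var ans = 0L
  var j = 0
  while (j < bytes.length) {
    // Shift the accumulator up one byte, then add the next byte as an unsigned value.
    ans = 256 * ans + (255 & bytes(j))
    j += 1
  }
  ans
}

// For example, bytesToLong(Array[Byte](0x01, 0x00)) == 256L.
```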
@@ -398,93 +489,15 @@ private[sql] class JDBCRDD(
stmt.setFetchSize(fetchSize)
val rs = stmt.executeQuery()
- val conversions = getConversions(schema)
+ val setters: Array[JDBCValueSetter] = makeSetters(schema)
val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
def getNext(): InternalRow = {
if (rs.next()) {
inputMetrics.incRecordsRead(1)
var i = 0
- while (i < conversions.length) {
- val pos = i + 1
- conversions(i) match {
- case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
- case DateConversion =>
- // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
- val dateVal = rs.getDate(pos)
- if (dateVal != null) {
- mutableRow.setInt(i, DateTimeUtils.fromJavaDate(dateVal))
- } else {
- mutableRow.update(i, null)
- }
- // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
- // object returned by ResultSet.getBigDecimal is not correctly matched to the table
- // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
- // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
- // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
- // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
- // retrieve it, you will get wrong result 199.99.
- // So it is needed to set precision and scale for Decimal based on JDBC metadata.
- case DecimalConversion(p, s) =>
- val decimalVal = rs.getBigDecimal(pos)
- if (decimalVal == null) {
- mutableRow.update(i, null)
- } else {
- mutableRow.update(i, Decimal(decimalVal, p, s))
- }
- case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
- case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
- case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
- case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
- // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
- case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos)))
- case TimestampConversion =>
- val t = rs.getTimestamp(pos)
- if (t != null) {
- mutableRow.setLong(i, DateTimeUtils.fromJavaTimestamp(t))
- } else {
- mutableRow.update(i, null)
- }
- case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
- case BinaryLongConversion =>
- val bytes = rs.getBytes(pos)
- var ans = 0L
- var j = 0
- while (j < bytes.size) {
- ans = 256 * ans + (255 & bytes(j))
- j = j + 1
- }
- mutableRow.setLong(i, ans)
- case ArrayConversion(elementConversion) =>
- val array = rs.getArray(pos).getArray
- if (array != null) {
- val data = elementConversion match {
- case TimestampConversion =>
- array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
- nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
- }
- case StringConversion =>
- array.asInstanceOf[Array[java.lang.String]]
- .map(UTF8String.fromString)
- case DateConversion =>
- array.asInstanceOf[Array[java.sql.Date]].map { date =>
- nullSafeConvert(date, DateTimeUtils.fromJavaDate)
- }
- case DecimalConversion(p, s) =>
- array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
- nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
- }
- case BinaryLongConversion =>
- throw new IllegalArgumentException(s"Unsupported array element conversion $i")
- case _: ArrayConversion =>
- throw new IllegalArgumentException("Nested arrays unsupported")
- case _ => array.asInstanceOf[Array[Any]]
- }
- mutableRow.update(i, new GenericArrayData(data))
- } else {
- mutableRow.update(i, null)
- }
- }
+ while (i < setters.length) {
+ setters(i).apply(rs, mutableRow, i)
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
}
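Not shown in this diff is the `nullSafeConvert` helper that the new setters rely on; it is defined elsewhere in `JDBCRDD.scala`. A plausible sketch of such a helper (an assumption for illustration, not a quote from the source):

```scala
// Sketch: apply `f` only when the JDBC value is non-null, otherwise keep null.
def nullSafeConvert[T](input: T, f: T => Any): Any = {
  if (input == null) null else f(input)
}
```

With that in place, the rewritten `getNext()` loop simply calls `setters(i).apply(rs, mutableRow, i)` for each column and then checks `rs.wasNull`, so no per-record type dispatch remains.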