author | Josh Rosen <joshrosen@databricks.com> | 2016-09-02 18:53:12 +0200 |
---|---|---|
committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-09-02 18:53:12 +0200 |
commit | 6bcbf9b74351b5ac5221e3c309cb98e6f9cc7c5a (patch) | |
tree | 364adc0465598e60b7d15e3e810fa3875bd98e6c /sql/core/src | |
parent | 806d8a8e980d8ba2f4261bceb393c40bafaa2f73 (diff) | |
[SPARK-17351] Refactor JDBCRDD to expose ResultSet -> Seq[Row] utility methods
This patch refactors the internals of the JDBC data source in order to allow some of its code to be re-used in an automated comparison testing harness. Here are the key changes:
- Move the JDBC `ResultSetMetaData` to `StructType` conversion logic from `JDBCRDD.resolveTable()` to the `JdbcUtils` object (as a new `getSchema(ResultSet, JdbcDialect)` method), allowing it to be applied to `ResultSet`s that are created elsewhere.
- Move the `ResultSet` to `InternalRow` conversion methods from `JDBCRDD` to `JdbcUtils`:
- It makes sense to move the `JDBCValueGetter` type and `makeGetter` functions here given that their write-path counterparts (`JDBCValueSetter`) are already in `JdbcUtils`.
- Add an internal `resultSetToSparkInternalRows` method which takes a `ResultSet` and schema and returns an `Iterator[InternalRow]`. This effectively extracts the main loop of `JDBCRDD` into its own method.
- Add a public `resultSetToRows` method to `JdbcUtils`, which wraps the minimal machinery around `resultSetToSparkInternalRows` needed to call it from outside of a Spark job.
- Make `JdbcDialects.get` into a `DeveloperApi` (`JdbcDialects` itself is already a `DeveloperApi`).
Put together, these changes enable the following testing pattern:
```scala
val jdbcResultSet: ResultSet = conn.prepareStatement(query).executeQuery()
val resultSchema: StructType = JdbcUtils.getSchema(jdbcResultSet, JdbcDialects.get("jdbc:postgresql"))
val jdbcRows: Seq[Row] = JdbcUtils.resultSetToRows(jdbcResultSet, resultSchema).toSeq
checkAnswer(sparkResult, jdbcRows) // in a test case
```
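In the snippet above, `sparkResult` would be the answer computed by Spark's own JDBC data source over the same data. A slightly expanded sketch of the round trip, assuming a test suite where Spark's internal `JdbcUtils` is accessible, `spark` is an in-scope `SparkSession`, `checkAnswer` comes from Spark's `QueryTest` helper, and the connection URL and table name are placeholders:

```scala
import java.sql.{DriverManager, ResultSet}
import java.util.Properties

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.StructType

// Placeholder connection details, for illustration only.
val url = "jdbc:postgresql://localhost/testdb"
val table = "test_table"

// Spark's answer, computed through the JDBC data source.
val sparkResult: DataFrame = spark.read.jdbc(url, table, new Properties())

// The JDBC driver's answer, computed via the newly exposed utilities.
val conn = DriverManager.getConnection(url)
val jdbcResultSet: ResultSet =
  conn.prepareStatement(s"SELECT * FROM $table").executeQuery()
val resultSchema: StructType =
  JdbcUtils.getSchema(jdbcResultSet, JdbcDialects.get(url))
val jdbcRows: Seq[Row] =
  JdbcUtils.resultSetToRows(jdbcResultSet, resultSchema).toSeq

checkAnswer(sparkResult, jdbcRows) // QueryTest.checkAnswer ignores row order
```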
Author: Josh Rosen <joshrosen@databricks.com>
Closes #14907 from JoshRosen/modularize-jdbc-internals.
Diffstat (limited to 'sql/core/src')
3 files changed, 335 insertions, 309 deletions
```diff
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 9b5088fbfd..a7da29f925 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
+import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp}
 import java.util.Properties
 
 import scala.util.control.NonFatal
@@ -28,12 +28,10 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.CompletionIterator
 
 /**
  * Data corresponding to one partition of a JDBCRDD.
@@ -45,68 +43,6 @@ case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
 object JDBCRDD extends Logging {
 
   /**
-   * Maps a JDBC type to a Catalyst type. This function is called only when
-   * the JdbcDialect class corresponding to your database driver returns null.
-   *
-   * @param sqlType - A field of java.sql.Types
-   * @return The Catalyst type corresponding to sqlType.
-   */
-  private def getCatalystType(
-      sqlType: Int,
-      precision: Int,
-      scale: Int,
-      signed: Boolean): DataType = {
-    val answer = sqlType match {
-      // scalastyle:off
-      case java.sql.Types.ARRAY => null
-      case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
-      case java.sql.Types.BINARY => BinaryType
-      case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
-      case java.sql.Types.BLOB => BinaryType
-      case java.sql.Types.BOOLEAN => BooleanType
-      case java.sql.Types.CHAR => StringType
-      case java.sql.Types.CLOB => StringType
-      case java.sql.Types.DATALINK => null
-      case java.sql.Types.DATE => DateType
-      case java.sql.Types.DECIMAL
-        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
-      case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
-      case java.sql.Types.DISTINCT => null
-      case java.sql.Types.DOUBLE => DoubleType
-      case java.sql.Types.FLOAT => FloatType
-      case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
-      case java.sql.Types.JAVA_OBJECT => null
-      case java.sql.Types.LONGNVARCHAR => StringType
-      case java.sql.Types.LONGVARBINARY => BinaryType
-      case java.sql.Types.LONGVARCHAR => StringType
-      case java.sql.Types.NCHAR => StringType
-      case java.sql.Types.NCLOB => StringType
-      case java.sql.Types.NULL => null
-      case java.sql.Types.NUMERIC
-        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
-      case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
-      case java.sql.Types.NVARCHAR => StringType
-      case java.sql.Types.OTHER => null
-      case java.sql.Types.REAL => DoubleType
-      case java.sql.Types.REF => StringType
-      case java.sql.Types.ROWID => LongType
-      case java.sql.Types.SMALLINT => IntegerType
-      case java.sql.Types.SQLXML => StringType
-      case java.sql.Types.STRUCT => StringType
-      case java.sql.Types.TIME => TimestampType
-      case java.sql.Types.TIMESTAMP => TimestampType
-      case java.sql.Types.TINYINT => IntegerType
-      case java.sql.Types.VARBINARY => BinaryType
-      case java.sql.Types.VARCHAR => StringType
-      case _ => null
-      // scalastyle:on
-    }
-
-    if (answer == null) throw new SQLException("Unsupported type " + sqlType)
-    answer
-  }
-
-  /**
    * Takes a (schema, table) specification and returns the table's Catalyst
    * schema.
   *
@@ -126,37 +62,7 @@ object JDBCRDD extends Logging {
     try {
       val rs = statement.executeQuery()
       try {
-        val rsmd = rs.getMetaData
-        val ncols = rsmd.getColumnCount
-        val fields = new Array[StructField](ncols)
-        var i = 0
-        while (i < ncols) {
-          val columnName = rsmd.getColumnLabel(i + 1)
-          val dataType = rsmd.getColumnType(i + 1)
-          val typeName = rsmd.getColumnTypeName(i + 1)
-          val fieldSize = rsmd.getPrecision(i + 1)
-          val fieldScale = rsmd.getScale(i + 1)
-          val isSigned = {
-            try {
-              rsmd.isSigned(i + 1)
-            } catch {
-              // Workaround for HIVE-14684:
-              case e: SQLException if
-                e.getMessage == "Method not supported" &&
-                  rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
-            }
-          }
-          val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
-          val metadata = new MetadataBuilder()
-            .putString("name", columnName)
-            .putLong("scale", fieldScale)
-          val columnType =
-            dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
-              getCatalystType(dataType, fieldSize, fieldScale, isSigned))
-          fields(i) = StructField(columnName, columnType, nullable, metadata.build())
-          i = i + 1
-        }
-        return new StructType(fields)
+        return JdbcUtils.getSchema(rs, dialect)
       } finally {
         rs.close()
       }
@@ -331,195 +237,15 @@ private[jdbc] class JDBCRDD(
     }
   }
 
-  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
-  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
-  // the row and also used for the value in `ResultSet`.
-  private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
-
-  /**
-   * Creates `JDBCValueGetter`s according to [[StructType]], which can set
-   * each value from `ResultSet` to each field of [[MutableRow]] correctly.
-   */
-  def makeGetters(schema: StructType): Array[JDBCValueGetter] =
-    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
-
-  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
-    case BooleanType =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        row.setBoolean(pos, rs.getBoolean(pos + 1))
-
-    case DateType =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
-        val dateVal = rs.getDate(pos + 1)
-        if (dateVal != null) {
-          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
-        } else {
-          row.update(pos, null)
-        }
-
-    // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
-    // object returned by ResultSet.getBigDecimal is not correctly matched to the table
-    // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
-    // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
-    // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
-    // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
-    // retrieve it, you will get wrong result 199.99.
-    // So it is needed to set precision and scale for Decimal based on JDBC metadata.
-    case DecimalType.Fixed(p, s) =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        val decimal =
-          nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
-        row.update(pos, decimal)
-
-    case DoubleType =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        row.setDouble(pos, rs.getDouble(pos + 1))
-
-    case FloatType =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        row.setFloat(pos, rs.getFloat(pos + 1))
-
-    case IntegerType =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        row.setInt(pos, rs.getInt(pos + 1))
-
-    case LongType if metadata.contains("binarylong") =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        val bytes = rs.getBytes(pos + 1)
-        var ans = 0L
-        var j = 0
-        while (j < bytes.size) {
-          ans = 256 * ans + (255 & bytes(j))
-          j = j + 1
-        }
-        row.setLong(pos, ans)
-
-    case LongType =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        row.setLong(pos, rs.getLong(pos + 1))
-
-    case ShortType =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        row.setShort(pos, rs.getShort(pos + 1))
-
-    case StringType =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
-        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
-
-    case TimestampType =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        val t = rs.getTimestamp(pos + 1)
-        if (t != null) {
-          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
-        } else {
-          row.update(pos, null)
-        }
-
-    case BinaryType =>
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        row.update(pos, rs.getBytes(pos + 1))
-
-    case ArrayType(et, _) =>
-      val elementConversion = et match {
-        case TimestampType =>
-          (array: Object) =>
-            array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
-              nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
-            }
-
-        case StringType =>
-          (array: Object) =>
-            array.asInstanceOf[Array[java.lang.String]]
-              .map(UTF8String.fromString)
-
-        case DateType =>
-          (array: Object) =>
-            array.asInstanceOf[Array[java.sql.Date]].map { date =>
-              nullSafeConvert(date, DateTimeUtils.fromJavaDate)
-            }
-
-        case dt: DecimalType =>
-          (array: Object) =>
-            array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
-              nullSafeConvert[java.math.BigDecimal](
-                decimal, d => Decimal(d, dt.precision, dt.scale))
-            }
-
-        case LongType if metadata.contains("binarylong") =>
-          throw new IllegalArgumentException(s"Unsupported array element " +
-            s"type ${dt.simpleString} based on binary")
-
-        case ArrayType(_, _) =>
-          throw new IllegalArgumentException("Nested arrays unsupported")
-
-        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
-      }
-
-      (rs: ResultSet, row: MutableRow, pos: Int) =>
-        val array = nullSafeConvert[Object](
-          rs.getArray(pos + 1).getArray,
-          array => new GenericArrayData(elementConversion.apply(array)))
-        row.update(pos, array)
-
-    case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
-  }
-
   /**
    * Runs the SQL query against the JDBC driver.
    *
    */
-  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
-    new Iterator[InternalRow] {
+  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = {
     var closed = false
-    var finished = false
-    var gotNext = false
-    var nextValue: InternalRow = null
-
-    context.addTaskCompletionListener{ context => close() }
-    val inputMetrics = context.taskMetrics().inputMetrics
-    val part = thePart.asInstanceOf[JDBCPartition]
-    val conn = getConnection()
-    val dialect = JdbcDialects.get(url)
-    import scala.collection.JavaConverters._
-    dialect.beforeFetch(conn, properties.asScala.toMap)
-
-    // H2's JDBC driver does not support the setSchema() method. We pass a
-    // fully-qualified table name in the SELECT statement. I don't know how to
-    // talk about a table in a completely portable way.
-
-    val myWhereClause = getWhereClause(part)
-
-    val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
-    val stmt = conn.prepareStatement(sqlText,
-        ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
-    val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
-    require(fetchSize >= 0,
-      s"Invalid value `${fetchSize.toString}` for parameter " +
-        s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
-        "the JDBC driver ignores the value and does the estimates.")
-    stmt.setFetchSize(fetchSize)
-    val rs = stmt.executeQuery()
-
-    val getters: Array[JDBCValueGetter] = makeGetters(schema)
-    val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
-
-    def getNext(): InternalRow = {
-      if (rs.next()) {
-        inputMetrics.incRecordsRead(1)
-        var i = 0
-        while (i < getters.length) {
-          getters(i).apply(rs, mutableRow, i)
-          if (rs.wasNull) mutableRow.setNullAt(i)
-          i = i + 1
-        }
-        mutableRow
-      } else {
-        finished = true
-        null.asInstanceOf[InternalRow]
-      }
-    }
+    var rs: ResultSet = null
+    var stmt: PreparedStatement = null
+    var conn: Connection = null
 
     def close() {
       if (closed) return
@@ -555,33 +281,33 @@ private[jdbc] class JDBCRDD(
       closed = true
     }
 
-    override def hasNext: Boolean = {
-      if (!finished) {
-        if (!gotNext) {
-          nextValue = getNext()
-          if (finished) {
-            close()
-          }
-          gotNext = true
-        }
-      }
-      !finished
-    }
+    context.addTaskCompletionListener{ context => close() }
 
-    override def next(): InternalRow = {
-      if (!hasNext) {
-        throw new NoSuchElementException("End of stream")
-      }
-      gotNext = false
-      nextValue
-    }
-  }
+    val inputMetrics = context.taskMetrics().inputMetrics
+    val part = thePart.asInstanceOf[JDBCPartition]
+    conn = getConnection()
+    val dialect = JdbcDialects.get(url)
+    import scala.collection.JavaConverters._
+    dialect.beforeFetch(conn, properties.asScala.toMap)
 
-  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
-    if (input == null) {
-      null
-    } else {
-      f(input)
-    }
+    // H2's JDBC driver does not support the setSchema() method. We pass a
+    // fully-qualified table name in the SELECT statement. I don't know how to
+    // talk about a table in a completely portable way.
+
+    val myWhereClause = getWhereClause(part)
+
+    val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
+    stmt = conn.prepareStatement(sqlText,
+        ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
+    val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
+    require(fetchSize >= 0,
+      s"Invalid value `${fetchSize.toString}` for parameter " +
+        s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
+        "the JDBC driver ignores the value and does the estimates.")
+    stmt.setFetchSize(fetchSize)
+    rs = stmt.executeQuery()
+    val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
+
+    CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 37153e545a..132472ad0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -17,17 +17,25 @@
 
 package org.apache.spark.sql.execution.datasources.jdbc
 
-import java.sql.{Connection, Driver, DriverManager, PreparedStatement, SQLException}
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
 import java.util.Properties
 
 import scala.collection.JavaConverters._
 import scala.util.Try
 import scala.util.control.NonFatal
 
+import org.apache.spark.TaskContext
+import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.NextIterator
 
 /**
  * Util functions for JDBC tables.
@@ -127,6 +135,7 @@ object JdbcUtils extends Logging {
 
   /**
    * Retrieve standard jdbc types.
+   *
   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
   * @return The default JdbcType for this DataType
   */
@@ -154,6 +163,297 @@ object JdbcUtils extends Logging {
       throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
   }
 
+  /**
+   * Maps a JDBC type to a Catalyst type. This function is called only when
+   * the JdbcDialect class corresponding to your database driver returns null.
+   *
+   * @param sqlType - A field of java.sql.Types
+   * @return The Catalyst type corresponding to sqlType.
+   */
+  private def getCatalystType(
+      sqlType: Int,
+      precision: Int,
+      scale: Int,
+      signed: Boolean): DataType = {
+    val answer = sqlType match {
+      // scalastyle:off
+      case java.sql.Types.ARRAY => null
+      case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
+      case java.sql.Types.BINARY => BinaryType
+      case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
+      case java.sql.Types.BLOB => BinaryType
+      case java.sql.Types.BOOLEAN => BooleanType
+      case java.sql.Types.CHAR => StringType
+      case java.sql.Types.CLOB => StringType
+      case java.sql.Types.DATALINK => null
+      case java.sql.Types.DATE => DateType
+      case java.sql.Types.DECIMAL
+        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+      case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
+      case java.sql.Types.DISTINCT => null
+      case java.sql.Types.DOUBLE => DoubleType
+      case java.sql.Types.FLOAT => FloatType
+      case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
+      case java.sql.Types.JAVA_OBJECT => null
+      case java.sql.Types.LONGNVARCHAR => StringType
+      case java.sql.Types.LONGVARBINARY => BinaryType
+      case java.sql.Types.LONGVARCHAR => StringType
+      case java.sql.Types.NCHAR => StringType
+      case java.sql.Types.NCLOB => StringType
+      case java.sql.Types.NULL => null
+      case java.sql.Types.NUMERIC
+        if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+      case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
+      case java.sql.Types.NVARCHAR => StringType
+      case java.sql.Types.OTHER => null
+      case java.sql.Types.REAL => DoubleType
+      case java.sql.Types.REF => StringType
+      case java.sql.Types.ROWID => LongType
+      case java.sql.Types.SMALLINT => IntegerType
+      case java.sql.Types.SQLXML => StringType
+      case java.sql.Types.STRUCT => StringType
+      case java.sql.Types.TIME => TimestampType
+      case java.sql.Types.TIMESTAMP => TimestampType
+      case java.sql.Types.TINYINT => IntegerType
+      case java.sql.Types.VARBINARY => BinaryType
+      case java.sql.Types.VARCHAR => StringType
+      case _ => null
+      // scalastyle:on
+    }
+
+    if (answer == null) throw new SQLException("Unsupported type " + sqlType)
+    answer
+  }
+
+  /**
+   * Takes a [[ResultSet]] and returns its Catalyst schema.
+   *
+   * @return A [[StructType]] giving the Catalyst schema.
+   * @throws SQLException if the schema contains an unsupported type.
+   */
+  def getSchema(resultSet: ResultSet, dialect: JdbcDialect): StructType = {
+    val rsmd = resultSet.getMetaData
+    val ncols = rsmd.getColumnCount
+    val fields = new Array[StructField](ncols)
+    var i = 0
+    while (i < ncols) {
+      val columnName = rsmd.getColumnLabel(i + 1)
+      val dataType = rsmd.getColumnType(i + 1)
+      val typeName = rsmd.getColumnTypeName(i + 1)
+      val fieldSize = rsmd.getPrecision(i + 1)
+      val fieldScale = rsmd.getScale(i + 1)
+      val isSigned = {
+        try {
+          rsmd.isSigned(i + 1)
+        } catch {
+          // Workaround for HIVE-14684:
+          case e: SQLException if
+            e.getMessage == "Method not supported" &&
+              rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
+        }
+      }
+      val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
+      val metadata = new MetadataBuilder()
+        .putString("name", columnName)
+        .putLong("scale", fieldScale)
+      val columnType =
+        dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
+          getCatalystType(dataType, fieldSize, fieldScale, isSigned))
+      fields(i) = StructField(columnName, columnType, nullable, metadata.build())
+      i = i + 1
+    }
+    new StructType(fields)
+  }
+
+  /**
+   * Convert a [[ResultSet]] into an iterator of Catalyst Rows.
+   */
+  def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
+    val inputMetrics =
+      Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
+    val encoder = RowEncoder(schema).resolveAndBind()
+    val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
+    internalRows.map(encoder.fromRow)
+  }
+
+  private[spark] def resultSetToSparkInternalRows(
+      resultSet: ResultSet,
+      schema: StructType,
+      inputMetrics: InputMetrics): Iterator[InternalRow] = {
+    new NextIterator[InternalRow] {
+      private[this] val rs = resultSet
+      private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
+      private[this] val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
+
+      override protected def close(): Unit = {
+        try {
+          rs.close()
+        } catch {
+          case e: Exception => logWarning("Exception closing resultset", e)
+        }
+      }
+
+      override protected def getNext(): InternalRow = {
+        if (rs.next()) {
+          inputMetrics.incRecordsRead(1)
+          var i = 0
+          while (i < getters.length) {
+            getters(i).apply(rs, mutableRow, i)
+            if (rs.wasNull) mutableRow.setNullAt(i)
+            i = i + 1
+          }
+          mutableRow
+        } else {
+          finished = true
+          null.asInstanceOf[InternalRow]
+        }
+      }
+    }
+  }
+
+  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
+  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
+  // the row and also used for the value in `ResultSet`.
+  private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
+
+  /**
+   * Creates `JDBCValueGetter`s according to [[StructType]], which can set
+   * each value from `ResultSet` to each field of [[MutableRow]] correctly.
+   */
+  private def makeGetters(schema: StructType): Array[JDBCValueGetter] =
+    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
+
+  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
+    case BooleanType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setBoolean(pos, rs.getBoolean(pos + 1))
+
+    case DateType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
+        val dateVal = rs.getDate(pos + 1)
+        if (dateVal != null) {
+          row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
+        } else {
+          row.update(pos, null)
+        }
+
+    // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
+    // object returned by ResultSet.getBigDecimal is not correctly matched to the table
+    // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
+    // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
+    // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
+    // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
+    // retrieve it, you will get wrong result 199.99.
+    // So it is needed to set precision and scale for Decimal based on JDBC metadata.
+    case DecimalType.Fixed(p, s) =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val decimal =
+          nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
+        row.update(pos, decimal)
+
+    case DoubleType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setDouble(pos, rs.getDouble(pos + 1))
+
+    case FloatType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setFloat(pos, rs.getFloat(pos + 1))
+
+    case IntegerType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setInt(pos, rs.getInt(pos + 1))
+
+    case LongType if metadata.contains("binarylong") =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val bytes = rs.getBytes(pos + 1)
+        var ans = 0L
+        var j = 0
+        while (j < bytes.size) {
+          ans = 256 * ans + (255 & bytes(j))
+          j = j + 1
+        }
+        row.setLong(pos, ans)
+
+    case LongType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setLong(pos, rs.getLong(pos + 1))
+
+    case ShortType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.setShort(pos, rs.getShort(pos + 1))
+
+    case StringType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
+        row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
+
+    case TimestampType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val t = rs.getTimestamp(pos + 1)
+        if (t != null) {
+          row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
+        } else {
+          row.update(pos, null)
+        }
+
+    case BinaryType =>
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        row.update(pos, rs.getBytes(pos + 1))
+
+    case ArrayType(et, _) =>
+      val elementConversion = et match {
+        case TimestampType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+              nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+            }
+
+        case StringType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.lang.String]]
+              .map(UTF8String.fromString)
+
+        case DateType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.sql.Date]].map { date =>
+              nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+            }
+
+        case dt: DecimalType =>
+          (array: Object) =>
+            array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+              nullSafeConvert[java.math.BigDecimal](
+                decimal, d => Decimal(d, dt.precision, dt.scale))
+            }
+
+        case LongType if metadata.contains("binarylong") =>
+          throw new IllegalArgumentException(s"Unsupported array element " +
+            s"type ${dt.simpleString} based on binary")
+
+        case ArrayType(_, _) =>
+          throw new IllegalArgumentException("Nested arrays unsupported")
+
+        case _ => (array: Object) => array.asInstanceOf[Array[Any]]
+      }
+
+      (rs: ResultSet, row: MutableRow, pos: Int) =>
+        val array = nullSafeConvert[Object](
+          rs.getArray(pos + 1).getArray,
+          array => new GenericArrayData(elementConversion.apply(array)))
+        row.update(pos, array)
+
+    case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
+  }
+
+  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
+    if (input == null) {
+      null
+    } else {
+      f(input)
+    }
+  }
+
   // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
   // `PreparedStatement`. The last argument `Int` means the index for the value to be set
   // in the SQL statement and also used for the value in `Row`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 78107809a1..3a6d5b7f1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -162,7 +162,7 @@ object JdbcDialects {
   /**
    * Fetch the JdbcDialect class corresponding to a given database url.
   */
-  private[sql] def get(url: String): JdbcDialect = {
+  def get(url: String): JdbcDialect = {
     val matchingDialects = dialects.filter(_.canHandle(url))
     matchingDialects.length match {
       case 0 => NoopDialect
```
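The rewritten `compute()` above drops the hand-rolled anonymous `Iterator[InternalRow]` in favor of two small Spark-internal helpers, `NextIterator` and `CompletionIterator` from `org.apache.spark.util`. Both are `private[spark]`, so the following is only a rough, simplified sketch of their contract (hypothetical re-implementations, not the real classes):

```scala
// NextIterator folds the "fetch one element, detect end of stream" pattern
// into a single getNext()/finished contract, replacing the hasNext/gotNext/
// nextValue state machine that compute() previously maintained by hand.
abstract class SketchNextIterator[U] extends Iterator[U] {
  private var gotNext = false
  private var nextValue: U = _
  protected var finished = false

  /** Fetch the next element, or set `finished = true` at end of stream. */
  protected def getNext(): U

  /** Release underlying resources (e.g. close the ResultSet). */
  protected def close(): Unit

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      if (finished) close()
      gotNext = true
    }
    !finished
  }

  override def next(): U = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}

// CompletionIterator runs a callback exactly once when the wrapped iterator
// is drained; compute() uses this to close the statement and connection.
class SketchCompletionIterator[A](sub: Iterator[A], completion: => Unit)
    extends Iterator[A] {
  private var completed = false
  override def hasNext: Boolean = {
    val r = sub.hasNext
    if (!r && !completed) {
      completed = true
      completion
    }
    r
  }
  override def next(): A = sub.next()
}
```

Splitting the row-materialization loop into `resultSetToSparkInternalRows` and delegating cleanup to these wrappers is what lets the same loop run both inside a Spark task (with the completion callback handling teardown) and outside a Spark job (via `resultSetToRows`).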