author     Josh Rosen <joshrosen@databricks.com>          2016-09-02 18:53:12 +0200
committer  Herman van Hovell <hvanhovell@databricks.com>  2016-09-02 18:53:12 +0200
commit     6bcbf9b74351b5ac5221e3c309cb98e6f9cc7c5a (patch)
tree       364adc0465598e60b7d15e3e810fa3875bd98e6c /sql
parent     806d8a8e980d8ba2f4261bceb393c40bafaa2f73 (diff)
[SPARK-17351] Refactor JDBCRDD to expose ResultSet -> Seq[Row] utility methods
This patch refactors the internals of the JDBC data source in order to allow some of its code to be re-used in an automated comparison testing harness. Here are the key changes:

- Move the JDBC `ResultSetMetaData` to `StructType` conversion logic from `JDBCRDD.resolveTable()` to the `JdbcUtils` object (as a new `getSchema(ResultSet, JdbcDialect)` method), allowing it to be applied to `ResultSet`s that are created elsewhere.
- Move the `ResultSet` to `InternalRow` conversion methods from `JDBCRDD` to `JdbcUtils`:
  - It makes sense to move the `JDBCValueGetter` type and `makeGetter` functions here given that their write-path counterparts (`JDBCValueSetter`) are already in `JdbcUtils`.
  - Add an internal `resultSetToSparkInternalRows` method which takes a `ResultSet` and a schema and returns an `Iterator[InternalRow]`. This effectively extracts the main loop of `JDBCRDD` into its own method.
- Add a public `resultSetToRows` method to `JdbcUtils`, which wraps the minimal machinery around `resultSetToSparkInternalRows` in order to allow it to be called from outside of a Spark job.
- Make `JdbcDialects.get` a `DeveloperApi` (the `JdbcDialect` class is already a `DeveloperApi`).

Put together, these changes enable the following testing pattern:

```scala
val jdbcResultSet: ResultSet = conn.prepareStatement(query).executeQuery()
val resultSchema: StructType = JdbcUtils.getSchema(jdbcResultSet, JdbcDialects.get("jdbc:postgresql"))
val jdbcRows: Seq[Row] = JdbcUtils.resultSetToRows(jdbcResultSet, resultSchema).toSeq
checkAnswer(sparkResult, jdbcRows) // in a test case
```

Author: Josh Rosen <joshrosen@databricks.com>

Closes #14907 from JoshRosen/modularize-jdbc-internals.
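The snippet above is the core of the pattern; the sketch below fills in the surrounding test scaffolding. It is illustrative only: the names `conn`, `query`, and `sparkDF` are assumed to exist in the enclosing test, the PostgreSQL URL is just an example value, and the final assertion stands in for whatever row-set comparison the harness actually uses.

```scala
import java.sql.{Connection, ResultSet}

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.StructType

// Compare the rows a Spark DataFrame produces against the rows the raw JDBC
// driver returns for the same query, using the helpers exposed by this patch.
def assertMatchesJdbc(conn: Connection, query: String, sparkDF: DataFrame): Unit = {
  val rs: ResultSet = conn.prepareStatement(query).executeQuery()
  // Derive a Catalyst schema from the ResultSet's metadata (new JdbcUtils.getSchema).
  val resultSchema: StructType = JdbcUtils.getSchema(rs, JdbcDialects.get("jdbc:postgresql"))
  // Convert the JDBC rows into Spark Rows (new public JdbcUtils.resultSetToRows).
  val jdbcRows: Seq[Row] = JdbcUtils.resultSetToRows(rs, resultSchema).toSeq
  // Compare against the rows Spark computed for the same data.
  val sparkRows: Seq[Row] = sparkDF.collect().toSeq
  assert(sparkRows.sortBy(_.toString) == jdbcRows.sortBy(_.toString),
    "Spark result does not match raw JDBC result")
}
```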
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala  | 340
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala  | 302
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala  | 2
3 files changed, 335 insertions, 309 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 9b5088fbfd..a7da29f925 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.jdbc
-import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp}
+import java.sql.{Connection, Date, PreparedStatement, ResultSet, SQLException, Timestamp}
import java.util.Properties
import scala.util.control.NonFatal
@@ -28,12 +28,10 @@ import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.CompletionIterator
/**
* Data corresponding to one partition of a JDBCRDD.
@@ -45,68 +43,6 @@ case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
object JDBCRDD extends Logging {
/**
- * Maps a JDBC type to a Catalyst type. This function is called only when
- * the JdbcDialect class corresponding to your database driver returns null.
- *
- * @param sqlType - A field of java.sql.Types
- * @return The Catalyst type corresponding to sqlType.
- */
- private def getCatalystType(
- sqlType: Int,
- precision: Int,
- scale: Int,
- signed: Boolean): DataType = {
- val answer = sqlType match {
- // scalastyle:off
- case java.sql.Types.ARRAY => null
- case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
- case java.sql.Types.BINARY => BinaryType
- case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
- case java.sql.Types.BLOB => BinaryType
- case java.sql.Types.BOOLEAN => BooleanType
- case java.sql.Types.CHAR => StringType
- case java.sql.Types.CLOB => StringType
- case java.sql.Types.DATALINK => null
- case java.sql.Types.DATE => DateType
- case java.sql.Types.DECIMAL
- if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
- case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
- case java.sql.Types.DISTINCT => null
- case java.sql.Types.DOUBLE => DoubleType
- case java.sql.Types.FLOAT => FloatType
- case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
- case java.sql.Types.JAVA_OBJECT => null
- case java.sql.Types.LONGNVARCHAR => StringType
- case java.sql.Types.LONGVARBINARY => BinaryType
- case java.sql.Types.LONGVARCHAR => StringType
- case java.sql.Types.NCHAR => StringType
- case java.sql.Types.NCLOB => StringType
- case java.sql.Types.NULL => null
- case java.sql.Types.NUMERIC
- if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
- case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
- case java.sql.Types.NVARCHAR => StringType
- case java.sql.Types.OTHER => null
- case java.sql.Types.REAL => DoubleType
- case java.sql.Types.REF => StringType
- case java.sql.Types.ROWID => LongType
- case java.sql.Types.SMALLINT => IntegerType
- case java.sql.Types.SQLXML => StringType
- case java.sql.Types.STRUCT => StringType
- case java.sql.Types.TIME => TimestampType
- case java.sql.Types.TIMESTAMP => TimestampType
- case java.sql.Types.TINYINT => IntegerType
- case java.sql.Types.VARBINARY => BinaryType
- case java.sql.Types.VARCHAR => StringType
- case _ => null
- // scalastyle:on
- }
-
- if (answer == null) throw new SQLException("Unsupported type " + sqlType)
- answer
- }
-
- /**
* Takes a (schema, table) specification and returns the table's Catalyst
* schema.
*
@@ -126,37 +62,7 @@ object JDBCRDD extends Logging {
try {
val rs = statement.executeQuery()
try {
- val rsmd = rs.getMetaData
- val ncols = rsmd.getColumnCount
- val fields = new Array[StructField](ncols)
- var i = 0
- while (i < ncols) {
- val columnName = rsmd.getColumnLabel(i + 1)
- val dataType = rsmd.getColumnType(i + 1)
- val typeName = rsmd.getColumnTypeName(i + 1)
- val fieldSize = rsmd.getPrecision(i + 1)
- val fieldScale = rsmd.getScale(i + 1)
- val isSigned = {
- try {
- rsmd.isSigned(i + 1)
- } catch {
- // Workaround for HIVE-14684:
- case e: SQLException if
- e.getMessage == "Method not supported" &&
- rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
- }
- }
- val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
- val metadata = new MetadataBuilder()
- .putString("name", columnName)
- .putLong("scale", fieldScale)
- val columnType =
- dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
- getCatalystType(dataType, fieldSize, fieldScale, isSigned))
- fields(i) = StructField(columnName, columnType, nullable, metadata.build())
- i = i + 1
- }
- return new StructType(fields)
+ return JdbcUtils.getSchema(rs, dialect)
} finally {
rs.close()
}
@@ -331,195 +237,15 @@ private[jdbc] class JDBCRDD(
}
}
- // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
- // for `MutableRow`. The last argument `Int` means the index for the value to be set in
- // the row and also used for the value in `ResultSet`.
- private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
-
- /**
- * Creates `JDBCValueGetter`s according to [[StructType]], which can set
- * each value from `ResultSet` to each field of [[MutableRow]] correctly.
- */
- def makeGetters(schema: StructType): Array[JDBCValueGetter] =
- schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
-
- private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
- case BooleanType =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- row.setBoolean(pos, rs.getBoolean(pos + 1))
-
- case DateType =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
- val dateVal = rs.getDate(pos + 1)
- if (dateVal != null) {
- row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
- } else {
- row.update(pos, null)
- }
-
- // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
- // object returned by ResultSet.getBigDecimal is not correctly matched to the table
- // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
- // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
- // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
- // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
- // retrieve it, you will get wrong result 199.99.
- // So it is needed to set precision and scale for Decimal based on JDBC metadata.
- case DecimalType.Fixed(p, s) =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- val decimal =
- nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
- row.update(pos, decimal)
-
- case DoubleType =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- row.setDouble(pos, rs.getDouble(pos + 1))
-
- case FloatType =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- row.setFloat(pos, rs.getFloat(pos + 1))
-
- case IntegerType =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- row.setInt(pos, rs.getInt(pos + 1))
-
- case LongType if metadata.contains("binarylong") =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- val bytes = rs.getBytes(pos + 1)
- var ans = 0L
- var j = 0
- while (j < bytes.size) {
- ans = 256 * ans + (255 & bytes(j))
- j = j + 1
- }
- row.setLong(pos, ans)
-
- case LongType =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- row.setLong(pos, rs.getLong(pos + 1))
-
- case ShortType =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- row.setShort(pos, rs.getShort(pos + 1))
-
- case StringType =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
- row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
-
- case TimestampType =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- val t = rs.getTimestamp(pos + 1)
- if (t != null) {
- row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
- } else {
- row.update(pos, null)
- }
-
- case BinaryType =>
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- row.update(pos, rs.getBytes(pos + 1))
-
- case ArrayType(et, _) =>
- val elementConversion = et match {
- case TimestampType =>
- (array: Object) =>
- array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
- nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
- }
-
- case StringType =>
- (array: Object) =>
- array.asInstanceOf[Array[java.lang.String]]
- .map(UTF8String.fromString)
-
- case DateType =>
- (array: Object) =>
- array.asInstanceOf[Array[java.sql.Date]].map { date =>
- nullSafeConvert(date, DateTimeUtils.fromJavaDate)
- }
-
- case dt: DecimalType =>
- (array: Object) =>
- array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
- nullSafeConvert[java.math.BigDecimal](
- decimal, d => Decimal(d, dt.precision, dt.scale))
- }
-
- case LongType if metadata.contains("binarylong") =>
- throw new IllegalArgumentException(s"Unsupported array element " +
- s"type ${dt.simpleString} based on binary")
-
- case ArrayType(_, _) =>
- throw new IllegalArgumentException("Nested arrays unsupported")
-
- case _ => (array: Object) => array.asInstanceOf[Array[Any]]
- }
-
- (rs: ResultSet, row: MutableRow, pos: Int) =>
- val array = nullSafeConvert[Object](
- rs.getArray(pos + 1).getArray,
- array => new GenericArrayData(elementConversion.apply(array)))
- row.update(pos, array)
-
- case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
- }
-
/**
* Runs the SQL query against the JDBC driver.
*
*/
- override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] =
- new Iterator[InternalRow] {
+ override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = {
var closed = false
- var finished = false
- var gotNext = false
- var nextValue: InternalRow = null
-
- context.addTaskCompletionListener{ context => close() }
- val inputMetrics = context.taskMetrics().inputMetrics
- val part = thePart.asInstanceOf[JDBCPartition]
- val conn = getConnection()
- val dialect = JdbcDialects.get(url)
- import scala.collection.JavaConverters._
- dialect.beforeFetch(conn, properties.asScala.toMap)
-
- // H2's JDBC driver does not support the setSchema() method. We pass a
- // fully-qualified table name in the SELECT statement. I don't know how to
- // talk about a table in a completely portable way.
-
- val myWhereClause = getWhereClause(part)
-
- val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
- val stmt = conn.prepareStatement(sqlText,
- ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
- val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
- require(fetchSize >= 0,
- s"Invalid value `${fetchSize.toString}` for parameter " +
- s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
- "the JDBC driver ignores the value and does the estimates.")
- stmt.setFetchSize(fetchSize)
- val rs = stmt.executeQuery()
-
- val getters: Array[JDBCValueGetter] = makeGetters(schema)
- val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
-
- def getNext(): InternalRow = {
- if (rs.next()) {
- inputMetrics.incRecordsRead(1)
- var i = 0
- while (i < getters.length) {
- getters(i).apply(rs, mutableRow, i)
- if (rs.wasNull) mutableRow.setNullAt(i)
- i = i + 1
- }
- mutableRow
- } else {
- finished = true
- null.asInstanceOf[InternalRow]
- }
- }
+ var rs: ResultSet = null
+ var stmt: PreparedStatement = null
+ var conn: Connection = null
def close() {
if (closed) return
@@ -555,33 +281,33 @@ private[jdbc] class JDBCRDD(
closed = true
}
- override def hasNext: Boolean = {
- if (!finished) {
- if (!gotNext) {
- nextValue = getNext()
- if (finished) {
- close()
- }
- gotNext = true
- }
- }
- !finished
- }
+ context.addTaskCompletionListener{ context => close() }
- override def next(): InternalRow = {
- if (!hasNext) {
- throw new NoSuchElementException("End of stream")
- }
- gotNext = false
- nextValue
- }
- }
+ val inputMetrics = context.taskMetrics().inputMetrics
+ val part = thePart.asInstanceOf[JDBCPartition]
+ conn = getConnection()
+ val dialect = JdbcDialects.get(url)
+ import scala.collection.JavaConverters._
+ dialect.beforeFetch(conn, properties.asScala.toMap)
- private def nullSafeConvert[T](input: T, f: T => Any): Any = {
- if (input == null) {
- null
- } else {
- f(input)
- }
+ // H2's JDBC driver does not support the setSchema() method. We pass a
+ // fully-qualified table name in the SELECT statement. I don't know how to
+ // talk about a table in a completely portable way.
+
+ val myWhereClause = getWhereClause(part)
+
+ val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
+ stmt = conn.prepareStatement(sqlText,
+ ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
+ val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
+ require(fetchSize >= 0,
+ s"Invalid value `${fetchSize.toString}` for parameter " +
+ s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
+ "the JDBC driver ignores the value and does the estimates.")
+ stmt.setFetchSize(fetchSize)
+ rs = stmt.executeQuery()
+ val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
+
+ CompletionIterator[InternalRow, Iterator[InternalRow]](rowsIterator, close())
}
}
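The rewritten `compute()` above now just prepares the statement, hands row conversion to `JdbcUtils.resultSetToSparkInternalRows`, and wraps the result in `CompletionIterator` so that `close()` runs once the iterator is exhausted. The following is a minimal sketch of that wrapper idea, not Spark's actual `org.apache.spark.util.CompletionIterator`:

```scala
// Sketch of the "run cleanup when iteration completes" idea behind
// CompletionIterator. Simplified; the real Spark class lives in
// org.apache.spark.util and has a slightly richer interface.
class CompletingIterator[A](sub: Iterator[A], completion: () => Unit) extends Iterator[A] {
  private var completed = false

  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !completed) {
      completed = true
      completion() // e.g. the close() defined in compute()
    }
    more
  }

  override def next(): A = sub.next()
}

// Hypothetical usage mirroring the tail of compute():
//   val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)
//   new CompletingIterator(rowsIterator, () => close())
```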
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 37153e545a..132472ad0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -17,17 +17,25 @@
package org.apache.spark.sql.execution.datasources.jdbc
-import java.sql.{Connection, Driver, DriverManager, PreparedStatement, SQLException}
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
import java.util.Properties
import scala.collection.JavaConverters._
import scala.util.Try
import scala.util.control.NonFatal
+import org.apache.spark.TaskContext
+import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.NextIterator
/**
* Util functions for JDBC tables.
@@ -127,6 +135,7 @@ object JdbcUtils extends Logging {
/**
* Retrieve standard jdbc types.
+ *
* @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
* @return The default JdbcType for this DataType
*/
@@ -154,6 +163,297 @@ object JdbcUtils extends Logging {
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
}
+ /**
+ * Maps a JDBC type to a Catalyst type. This function is called only when
+ * the JdbcDialect class corresponding to your database driver returns null.
+ *
+ * @param sqlType - A field of java.sql.Types
+ * @return The Catalyst type corresponding to sqlType.
+ */
+ private def getCatalystType(
+ sqlType: Int,
+ precision: Int,
+ scale: Int,
+ signed: Boolean): DataType = {
+ val answer = sqlType match {
+ // scalastyle:off
+ case java.sql.Types.ARRAY => null
+ case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
+ case java.sql.Types.BINARY => BinaryType
+ case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
+ case java.sql.Types.BLOB => BinaryType
+ case java.sql.Types.BOOLEAN => BooleanType
+ case java.sql.Types.CHAR => StringType
+ case java.sql.Types.CLOB => StringType
+ case java.sql.Types.DATALINK => null
+ case java.sql.Types.DATE => DateType
+ case java.sql.Types.DECIMAL
+ if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+ case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
+ case java.sql.Types.DISTINCT => null
+ case java.sql.Types.DOUBLE => DoubleType
+ case java.sql.Types.FLOAT => FloatType
+ case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
+ case java.sql.Types.JAVA_OBJECT => null
+ case java.sql.Types.LONGNVARCHAR => StringType
+ case java.sql.Types.LONGVARBINARY => BinaryType
+ case java.sql.Types.LONGVARCHAR => StringType
+ case java.sql.Types.NCHAR => StringType
+ case java.sql.Types.NCLOB => StringType
+ case java.sql.Types.NULL => null
+ case java.sql.Types.NUMERIC
+ if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+ case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
+ case java.sql.Types.NVARCHAR => StringType
+ case java.sql.Types.OTHER => null
+ case java.sql.Types.REAL => DoubleType
+ case java.sql.Types.REF => StringType
+ case java.sql.Types.ROWID => LongType
+ case java.sql.Types.SMALLINT => IntegerType
+ case java.sql.Types.SQLXML => StringType
+ case java.sql.Types.STRUCT => StringType
+ case java.sql.Types.TIME => TimestampType
+ case java.sql.Types.TIMESTAMP => TimestampType
+ case java.sql.Types.TINYINT => IntegerType
+ case java.sql.Types.VARBINARY => BinaryType
+ case java.sql.Types.VARCHAR => StringType
+ case _ => null
+ // scalastyle:on
+ }
+
+ if (answer == null) throw new SQLException("Unsupported type " + sqlType)
+ answer
+ }
+
+ /**
+ * Takes a [[ResultSet]] and returns its Catalyst schema.
+ *
+ * @return A [[StructType]] giving the Catalyst schema.
+ * @throws SQLException if the schema contains an unsupported type.
+ */
+ def getSchema(resultSet: ResultSet, dialect: JdbcDialect): StructType = {
+ val rsmd = resultSet.getMetaData
+ val ncols = rsmd.getColumnCount
+ val fields = new Array[StructField](ncols)
+ var i = 0
+ while (i < ncols) {
+ val columnName = rsmd.getColumnLabel(i + 1)
+ val dataType = rsmd.getColumnType(i + 1)
+ val typeName = rsmd.getColumnTypeName(i + 1)
+ val fieldSize = rsmd.getPrecision(i + 1)
+ val fieldScale = rsmd.getScale(i + 1)
+ val isSigned = {
+ try {
+ rsmd.isSigned(i + 1)
+ } catch {
+ // Workaround for HIVE-14684:
+ case e: SQLException if
+ e.getMessage == "Method not supported" &&
+ rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
+ }
+ }
+ val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
+ val metadata = new MetadataBuilder()
+ .putString("name", columnName)
+ .putLong("scale", fieldScale)
+ val columnType =
+ dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
+ getCatalystType(dataType, fieldSize, fieldScale, isSigned))
+ fields(i) = StructField(columnName, columnType, nullable, metadata.build())
+ i = i + 1
+ }
+ new StructType(fields)
+ }
+
+ /**
+ * Convert a [[ResultSet]] into an iterator of Catalyst Rows.
+ */
+ def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
+ val inputMetrics =
+ Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
+ val encoder = RowEncoder(schema).resolveAndBind()
+ val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
+ internalRows.map(encoder.fromRow)
+ }
+
+ private[spark] def resultSetToSparkInternalRows(
+ resultSet: ResultSet,
+ schema: StructType,
+ inputMetrics: InputMetrics): Iterator[InternalRow] = {
+ new NextIterator[InternalRow] {
+ private[this] val rs = resultSet
+ private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
+ private[this] val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
+
+ override protected def close(): Unit = {
+ try {
+ rs.close()
+ } catch {
+ case e: Exception => logWarning("Exception closing resultset", e)
+ }
+ }
+
+ override protected def getNext(): InternalRow = {
+ if (rs.next()) {
+ inputMetrics.incRecordsRead(1)
+ var i = 0
+ while (i < getters.length) {
+ getters(i).apply(rs, mutableRow, i)
+ if (rs.wasNull) mutableRow.setNullAt(i)
+ i = i + 1
+ }
+ mutableRow
+ } else {
+ finished = true
+ null.asInstanceOf[InternalRow]
+ }
+ }
+ }
+ }
+
+ // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
+ // for `MutableRow`. The last argument `Int` means the index for the value to be set in
+ // the row and also used for the value in `ResultSet`.
+ private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
+
+ /**
+ * Creates `JDBCValueGetter`s according to [[StructType]], which can set
+ * each value from `ResultSet` to each field of [[MutableRow]] correctly.
+ */
+ private def makeGetters(schema: StructType): Array[JDBCValueGetter] =
+ schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
+
+ private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
+ case BooleanType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setBoolean(pos, rs.getBoolean(pos + 1))
+
+ case DateType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
+ val dateVal = rs.getDate(pos + 1)
+ if (dateVal != null) {
+ row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
+ } else {
+ row.update(pos, null)
+ }
+
+ // When connecting with Oracle DB through JDBC, the precision and scale of BigDecimal
+ // object returned by ResultSet.getBigDecimal is not correctly matched to the table
+ // schema reported by ResultSetMetaData.getPrecision and ResultSetMetaData.getScale.
+ // If inserting values like 19999 into a column with NUMBER(12, 2) type, you get through
+ // a BigDecimal object with scale as 0. But the dataframe schema has correct type as
+ // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then
+ // retrieve it, you will get wrong result 199.99.
+ // So it is needed to set precision and scale for Decimal based on JDBC metadata.
+ case DecimalType.Fixed(p, s) =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val decimal =
+ nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
+ row.update(pos, decimal)
+
+ case DoubleType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setDouble(pos, rs.getDouble(pos + 1))
+
+ case FloatType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setFloat(pos, rs.getFloat(pos + 1))
+
+ case IntegerType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setInt(pos, rs.getInt(pos + 1))
+
+ case LongType if metadata.contains("binarylong") =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val bytes = rs.getBytes(pos + 1)
+ var ans = 0L
+ var j = 0
+ while (j < bytes.size) {
+ ans = 256 * ans + (255 & bytes(j))
+ j = j + 1
+ }
+ row.setLong(pos, ans)
+
+ case LongType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setLong(pos, rs.getLong(pos + 1))
+
+ case ShortType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setShort(pos, rs.getShort(pos + 1))
+
+ case StringType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
+ row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
+
+ case TimestampType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val t = rs.getTimestamp(pos + 1)
+ if (t != null) {
+ row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
+ } else {
+ row.update(pos, null)
+ }
+
+ case BinaryType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.update(pos, rs.getBytes(pos + 1))
+
+ case ArrayType(et, _) =>
+ val elementConversion = et match {
+ case TimestampType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+ nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+ }
+
+ case StringType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.lang.String]]
+ .map(UTF8String.fromString)
+
+ case DateType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.sql.Date]].map { date =>
+ nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+ }
+
+ case dt: DecimalType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+ nullSafeConvert[java.math.BigDecimal](
+ decimal, d => Decimal(d, dt.precision, dt.scale))
+ }
+
+ case LongType if metadata.contains("binarylong") =>
+ throw new IllegalArgumentException(s"Unsupported array element " +
+ s"type ${dt.simpleString} based on binary")
+
+ case ArrayType(_, _) =>
+ throw new IllegalArgumentException("Nested arrays unsupported")
+
+ case _ => (array: Object) => array.asInstanceOf[Array[Any]]
+ }
+
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val array = nullSafeConvert[Object](
+ rs.getArray(pos + 1).getArray,
+ array => new GenericArrayData(elementConversion.apply(array)))
+ row.update(pos, array)
+
+ case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
+ }
+
+ private def nullSafeConvert[T](input: T, f: T => Any): Any = {
+ if (input == null) {
+ null
+ } else {
+ f(input)
+ }
+ }
+
// A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
// `PreparedStatement`. The last argument `Int` means the index for the value to be set
// in the SQL statement and also used for the value in `Row`.
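The new `resultSetToSparkInternalRows` above is built on Spark's internal `NextIterator`, which replaces the hand-rolled `hasNext`/`getNext` bookkeeping that previously lived inline in `JDBCRDD.compute()`. A rough sketch of that pattern (simplified; not the actual `org.apache.spark.util.NextIterator`) follows:

```scala
// Sketch of the NextIterator pattern: subclasses implement getNext() and set
// `finished` when the source is exhausted; the base class handles the
// hasNext/next bookkeeping that JDBCRDD.compute() previously wrote by hand.
abstract class SimpleNextIterator[A] extends Iterator[A] {
  protected var finished = false
  private var gotNext = false
  private var nextValue: A = _

  /** Produce the next element, or set `finished = true` if there are no more. */
  protected def getNext(): A

  override def hasNext: Boolean = {
    if (!finished && !gotNext) {
      nextValue = getNext()
      gotNext = true
    }
    !finished
  }

  override def next(): A = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    gotNext = false
    nextValue
  }
}
```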
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 78107809a1..3a6d5b7f1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -162,7 +162,7 @@ object JdbcDialects {
/**
* Fetch the JdbcDialect class corresponding to a given database url.
*/
- private[sql] def get(url: String): JdbcDialect = {
+ def get(url: String): JdbcDialect = {
val matchingDialects = dialects.filter(_.canHandle(url))
matchingDialects.length match {
case 0 => NoopDialect
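With `get` now public, code outside `org.apache.spark.sql` can resolve a dialect directly from a connection URL and pass it to `JdbcUtils.getSchema`, as in the testing pattern from the commit message. A minimal sketch (the URL is just an example value):

```scala
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

// Look up the dialect registered for a given JDBC URL; when nothing matches,
// the zero-match case above falls back to a no-op dialect.
val dialect: JdbcDialect = JdbcDialects.get("jdbc:postgresql://localhost/testdb")
```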