Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 302
1 file changed, 301 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 37153e545a..132472ad0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -17,17 +17,25 @@
package org.apache.spark.sql.execution.datasources.jdbc
-import java.sql.{Connection, Driver, DriverManager, PreparedStatement, SQLException}
+import java.sql.{Connection, Driver, DriverManager, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
import java.util.Properties
import scala.collection.JavaConverters._
import scala.util.Try
import scala.util.control.NonFatal
+import org.apache.spark.TaskContext
+import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.NextIterator
/**
* Util functions for JDBC tables.
@@ -127,6 +135,7 @@ object JdbcUtils extends Logging {
/**
* Retrieve standard jdbc types.
+ *
* @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
* @return The default JdbcType for this DataType
*/
@@ -154,6 +163,297 @@ object JdbcUtils extends Logging {
throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
}
+ /**
+ * Maps a JDBC type to a Catalyst type. This function is called only when
+ * the JdbcDialect corresponding to your database driver does not define a mapping.
+ *
+ * @param sqlType - A field of java.sql.Types
+ * @return The Catalyst type corresponding to sqlType.
+ */
+ private def getCatalystType(
+ sqlType: Int,
+ precision: Int,
+ scale: Int,
+ signed: Boolean): DataType = {
+ val answer = sqlType match {
+ // scalastyle:off
+ case java.sql.Types.ARRAY => null
+ case java.sql.Types.BIGINT => if (signed) { LongType } else { DecimalType(20,0) }
+ case java.sql.Types.BINARY => BinaryType
+ case java.sql.Types.BIT => BooleanType // @see JdbcDialect for quirks
+ case java.sql.Types.BLOB => BinaryType
+ case java.sql.Types.BOOLEAN => BooleanType
+ case java.sql.Types.CHAR => StringType
+ case java.sql.Types.CLOB => StringType
+ case java.sql.Types.DATALINK => null
+ case java.sql.Types.DATE => DateType
+ case java.sql.Types.DECIMAL
+ if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+ case java.sql.Types.DECIMAL => DecimalType.SYSTEM_DEFAULT
+ case java.sql.Types.DISTINCT => null
+ case java.sql.Types.DOUBLE => DoubleType
+ case java.sql.Types.FLOAT => FloatType
+ case java.sql.Types.INTEGER => if (signed) { IntegerType } else { LongType }
+ case java.sql.Types.JAVA_OBJECT => null
+ case java.sql.Types.LONGNVARCHAR => StringType
+ case java.sql.Types.LONGVARBINARY => BinaryType
+ case java.sql.Types.LONGVARCHAR => StringType
+ case java.sql.Types.NCHAR => StringType
+ case java.sql.Types.NCLOB => StringType
+ case java.sql.Types.NULL => null
+ case java.sql.Types.NUMERIC
+ if precision != 0 || scale != 0 => DecimalType.bounded(precision, scale)
+ case java.sql.Types.NUMERIC => DecimalType.SYSTEM_DEFAULT
+ case java.sql.Types.NVARCHAR => StringType
+ case java.sql.Types.OTHER => null
+ case java.sql.Types.REAL => DoubleType
+ case java.sql.Types.REF => StringType
+ case java.sql.Types.ROWID => LongType
+ case java.sql.Types.SMALLINT => IntegerType
+ case java.sql.Types.SQLXML => StringType
+ case java.sql.Types.STRUCT => StringType
+ case java.sql.Types.TIME => TimestampType
+ case java.sql.Types.TIMESTAMP => TimestampType
+ case java.sql.Types.TINYINT => IntegerType
+ case java.sql.Types.VARBINARY => BinaryType
+ case java.sql.Types.VARCHAR => StringType
+ case _ => null
+ // scalastyle:on
+ }
+
+ if (answer == null) throw new SQLException("Unsupported type " + sqlType)
+ answer
+ }
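
For illustration, two of the fallback decisions above written out as a sketch (these calls assume we are inside JdbcUtils, since getCatalystType is private):

    // Unsigned BIGINT values do not fit in LongType, so they widen to DecimalType(20, 0).
    getCatalystType(java.sql.Types.BIGINT, 0, 0, signed = false)   // DecimalType(20, 0)
    // NUMERIC with no reported precision/scale falls back to the system default decimal.
    getCatalystType(java.sql.Types.NUMERIC, 0, 0, signed = true)   // DecimalType.SYSTEM_DEFAULT
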
+
+ /**
+ * Takes a [[ResultSet]] and returns its Catalyst schema.
+ *
+ * @return A [[StructType]] giving the Catalyst schema.
+ * @throws SQLException if the schema contains an unsupported type.
+ */
+ def getSchema(resultSet: ResultSet, dialect: JdbcDialect): StructType = {
+ val rsmd = resultSet.getMetaData
+ val ncols = rsmd.getColumnCount
+ val fields = new Array[StructField](ncols)
+ var i = 0
+ while (i < ncols) {
+ val columnName = rsmd.getColumnLabel(i + 1)
+ val dataType = rsmd.getColumnType(i + 1)
+ val typeName = rsmd.getColumnTypeName(i + 1)
+ val fieldSize = rsmd.getPrecision(i + 1)
+ val fieldScale = rsmd.getScale(i + 1)
+ val isSigned = {
+ try {
+ rsmd.isSigned(i + 1)
+ } catch {
+ // Workaround for HIVE-14684:
+ case e: SQLException if
+ e.getMessage == "Method not supported" &&
+ rsmd.getClass.getName == "org.apache.hive.jdbc.HiveResultSetMetaData" => true
+ }
+ }
+ val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
+ val metadata = new MetadataBuilder()
+ .putString("name", columnName)
+ .putLong("scale", fieldScale)
+ val columnType =
+ dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
+ getCatalystType(dataType, fieldSize, fieldScale, isSigned))
+ fields(i) = StructField(columnName, columnType, nullable, metadata.build())
+ i = i + 1
+ }
+ new StructType(fields)
+ }
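
A minimal usage sketch for getSchema (the java.sql.Connection `conn` and the JDBC URL are hypothetical; JdbcDialects.get picks the dialect registered for the URL):

    val dialect = JdbcDialects.get("jdbc:postgresql://localhost/test")
    val stmt = conn.prepareStatement("SELECT * FROM people WHERE 1 = 0")
    val schema: StructType = JdbcUtils.getSchema(stmt.executeQuery(), dialect)
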
+
+ /**
+ * Convert a [[ResultSet]] into an iterator of Catalyst Rows.
+ */
+ def resultSetToRows(resultSet: ResultSet, schema: StructType): Iterator[Row] = {
+ val inputMetrics =
+ Option(TaskContext.get()).map(_.taskMetrics().inputMetrics).getOrElse(new InputMetrics)
+ val encoder = RowEncoder(schema).resolveAndBind()
+ val internalRows = resultSetToSparkInternalRows(resultSet, schema, inputMetrics)
+ internalRows.map(encoder.fromRow)
+ }
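
Continuing the sketch above, the same statement can be re-executed and decoded into external Rows (again assuming the hypothetical `stmt` and `schema`):

    val rows: Iterator[Row] = JdbcUtils.resultSetToRows(stmt.executeQuery(), schema)
    rows.foreach(r => println(r))
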
+
+ private[spark] def resultSetToSparkInternalRows(
+ resultSet: ResultSet,
+ schema: StructType,
+ inputMetrics: InputMetrics): Iterator[InternalRow] = {
+ new NextIterator[InternalRow] {
+ private[this] val rs = resultSet
+ private[this] val getters: Array[JDBCValueGetter] = makeGetters(schema)
+ private[this] val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
+
+ override protected def close(): Unit = {
+ try {
+ rs.close()
+ } catch {
+ case e: Exception => logWarning("Exception closing resultset", e)
+ }
+ }
+
+ override protected def getNext(): InternalRow = {
+ if (rs.next()) {
+ inputMetrics.incRecordsRead(1)
+ var i = 0
+ while (i < getters.length) {
+ getters(i).apply(rs, mutableRow, i)
+ if (rs.wasNull) mutableRow.setNullAt(i)
+ i = i + 1
+ }
+ mutableRow
+ } else {
+ finished = true
+ null.asInstanceOf[InternalRow]
+ }
+ }
+ }
+ }
+
+ // A `JDBCValueGetter` is responsible for copying a value from a `ResultSet` into a field
+ // of a `MutableRow`. The last `Int` argument is the index of the field to set in the row;
+ // the same index (offset by one, since JDBC columns are 1-based) is used to read the value
+ // from the `ResultSet`.
+ private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
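
As a sketch of this contract, a hand-written getter for an integer field might look as follows (`rs` and `mutableRow` are hypothetical values of the obvious types):

    val intGetter: JDBCValueGetter =
      (rs: ResultSet, row: MutableRow, pos: Int) => row.setInt(pos, rs.getInt(pos + 1))
    intGetter(rs, mutableRow, 0)   // reads JDBC column 1 into row field 0
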
+
+ /**
+ * Creates one `JDBCValueGetter` per field of the given [[StructType]]; each getter
+ * copies the corresponding column value from a `ResultSet` into a [[MutableRow]].
+ */
+ private def makeGetters(schema: StructType): Array[JDBCValueGetter] =
+ schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
+
+ private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
+ case BooleanType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setBoolean(pos, rs.getBoolean(pos + 1))
+
+ case DateType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ // DateTimeUtils.fromJavaDate does not handle null value, so we need to check it.
+ val dateVal = rs.getDate(pos + 1)
+ if (dateVal != null) {
+ row.setInt(pos, DateTimeUtils.fromJavaDate(dateVal))
+ } else {
+ row.update(pos, null)
+ }
+
+ // When connecting to Oracle through JDBC, the precision and scale of the BigDecimal
+ // returned by ResultSet.getBigDecimal do not always match the table schema reported by
+ // ResultSetMetaData.getPrecision and ResultSetMetaData.getScale. For example, inserting
+ // 19999 into a NUMBER(12, 2) column yields a BigDecimal with scale 0, while the DataFrame
+ // schema correctly has DecimalType(12, 2). Saving that DataFrame to Parquet and reading it
+ // back would then produce the wrong value 199.99. We therefore force the precision and
+ // scale of the Decimal using the JDBC metadata (see the worked sketch after this method).
+ case DecimalType.Fixed(p, s) =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val decimal =
+ nullSafeConvert[java.math.BigDecimal](rs.getBigDecimal(pos + 1), d => Decimal(d, p, s))
+ row.update(pos, decimal)
+
+ case DoubleType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setDouble(pos, rs.getDouble(pos + 1))
+
+ case FloatType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setFloat(pos, rs.getFloat(pos + 1))
+
+ case IntegerType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setInt(pos, rs.getInt(pos + 1))
+
+ case LongType if metadata.contains("binarylong") =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val bytes = rs.getBytes(pos + 1)
+ var ans = 0L
+ var j = 0
+ while (j < bytes.size) {
+ ans = 256 * ans + (255 & bytes(j))
+ j = j + 1
+ }
+ row.setLong(pos, ans)
+
+ case LongType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setLong(pos, rs.getLong(pos + 1))
+
+ case ShortType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.setShort(pos, rs.getShort(pos + 1))
+
+ case StringType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ // TODO(davies): use getBytes for better performance, if the encoding is UTF-8
+ row.update(pos, UTF8String.fromString(rs.getString(pos + 1)))
+
+ case TimestampType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val t = rs.getTimestamp(pos + 1)
+ if (t != null) {
+ row.setLong(pos, DateTimeUtils.fromJavaTimestamp(t))
+ } else {
+ row.update(pos, null)
+ }
+
+ case BinaryType =>
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ row.update(pos, rs.getBytes(pos + 1))
+
+ case ArrayType(et, _) =>
+ val elementConversion = et match {
+ case TimestampType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+ nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+ }
+
+ case StringType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.lang.String]]
+ .map(UTF8String.fromString)
+
+ case DateType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.sql.Date]].map { date =>
+ nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+ }
+
+ case dt: DecimalType =>
+ (array: Object) =>
+ array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+ nullSafeConvert[java.math.BigDecimal](
+ decimal, d => Decimal(d, dt.precision, dt.scale))
+ }
+
+ case LongType if metadata.contains("binarylong") =>
+ throw new IllegalArgumentException(s"Unsupported array element " +
+ s"type ${dt.simpleString} based on binary")
+
+ case ArrayType(_, _) =>
+ throw new IllegalArgumentException("Nested arrays unsupported")
+
+ case _ => (array: Object) => array.asInstanceOf[Array[Any]]
+ }
+
+ (rs: ResultSet, row: MutableRow, pos: Int) =>
+ val array = nullSafeConvert[Object](
+ rs.getArray(pos + 1).getArray,
+ array => new GenericArrayData(elementConversion.apply(array)))
+ row.update(pos, array)
+
+ case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
+ }
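
A worked sketch of the NUMBER(12, 2) case described in the DecimalType branch above: forcing the precision and scale from the JDBC metadata keeps the value correct even when the driver reports scale 0.

    val fromDriver = new java.math.BigDecimal("19999")   // scale 0, as some drivers return it
    Decimal(fromDriver, 12, 2)                           // Decimal 19999.00, not 199.99
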
+
+ private def nullSafeConvert[T](input: T, f: T => Any): Any = {
+ if (input == null) {
+ null
+ } else {
+ f(input)
+ }
+ }
+
// A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
// `PreparedStatement`. The last argument `Int` means the index for the value to be set
// in the SQL statement and also used for the value in `Row`.