Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala     | 21
2 files changed, 51 insertions, 16 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index ce8731efd1..f541996b65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -28,28 +28,42 @@ private case object OracleDialect extends JdbcDialect {
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
-    // Handle NUMBER fields that have no precision/scale in special way
-    // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale
-    // For more details, please see
-    // https://github.com/apache/spark/pull/8780#issuecomment-145598968
-    // and
-    // https://github.com/apache/spark/pull/8780#issuecomment-144541760
-    if (sqlType == Types.NUMERIC && size == 0) {
-      // This is sub-optimal as we have to pick a precision/scale in advance whereas the data
-      // in Oracle is allowed to have different precision/scale for each value.
-      Option(DecimalType(DecimalType.MAX_PRECISION, 10))
-    } else if (sqlType == Types.NUMERIC && md.build().getLong("scale") == -127) {
-      // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts
-      // this to NUMERIC with -127 scale
-      // Not sure if there is a more robust way to identify the field as a float (or other
-      // numeric types that do not specify a scale.
-      Option(DecimalType(DecimalType.MAX_PRECISION, 10))
+    if (sqlType == Types.NUMERIC) {
+      val scale = if (null != md) md.build().getLong("scale") else 0L
+      size match {
+        // Handle NUMBER fields that have no precision/scale in a special way,
+        // because JDBC ResultSetMetaData converts them to 0 precision and -127 scale.
+        // For more details, please see
+        // https://github.com/apache/spark/pull/8780#issuecomment-145598968
+        // and
+        // https://github.com/apache/spark/pull/8780#issuecomment-144541760
+        case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
+        // Handle FLOAT fields in a special way, because JDBC ResultSetMetaData converts
+        // them to NUMERIC with -127 scale. There may be a more robust way to identify
+        // the field as a float (or other numeric types that do not specify a scale).
+        case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10))
+        case 1 => Option(BooleanType)
+        case 3 | 5 | 10 => Option(IntegerType)
+        case 19 if scale == 0L => Option(LongType)
+        case 19 if scale == 4L => Option(FloatType)
+        case _ => None
+      }
     } else {
       None
     }
   }

   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
+    // For more details, please see
+    // https://docs.oracle.com/cd/E19501-01/819-3659/gcmaz/
+    case BooleanType => Some(JdbcType("NUMBER(1)", java.sql.Types.BOOLEAN))
+    case IntegerType => Some(JdbcType("NUMBER(10)", java.sql.Types.INTEGER))
+    case LongType => Some(JdbcType("NUMBER(19)", java.sql.Types.BIGINT))
+    case FloatType => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.FLOAT))
+    case DoubleType => Some(JdbcType("NUMBER(19, 4)", java.sql.Types.DOUBLE))
+    case ByteType => Some(JdbcType("NUMBER(3)", java.sql.Types.SMALLINT))
+    case ShortType => Some(JdbcType("NUMBER(5)", java.sql.Types.SMALLINT))
     case StringType => Some(JdbcType("VARCHAR2(255)", java.sql.Types.VARCHAR))
     case _ => None
   }
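
Not part of the patch: a minimal sketch of how the new NUMBER mapping above resolves for a few representative columns, assuming code that lives in the org.apache.spark.sql.jdbc package (OracleDialect is package-private there). The numberType helper is hypothetical; it just builds the "scale" metadata that the JDBC read path would normally supply.

import java.sql.Types
import org.apache.spark.sql.types._

// Hypothetical helper: feed getCatalystType a NUMBER column of the given
// precision (size) and scale, the way the JDBC read path would.
def numberType(size: Int, scale: Long): Option[DataType] =
  OracleDialect.getCatalystType(
    Types.NUMERIC, "NUMBER", size, new MetadataBuilder().putLong("scale", scale))

numberType(1, 0L)     // Some(BooleanType)         NUMBER(1)
numberType(10, 0L)    // Some(IntegerType)         NUMBER(10)
numberType(19, 0L)    // Some(LongType)            NUMBER(19)
numberType(19, 4L)    // Some(FloatType)           NUMBER(19, 4)
numberType(0, -127L)  // Some(DecimalType(38, 10)) unqualified NUMBER
numberType(7, 2L)     // None

Any NUMBER shape not matched above returns None, so the caller falls back to the default JDBC decimal mapping instead.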
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 995b1200a2..2d8ee338a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -739,6 +739,27 @@ class JDBCSuite extends SparkFunSuite
       map(_.databaseTypeDefinition).get == "VARCHAR2(255)")
   }

+  test("SPARK-16625: General data types to be mapped to Oracle") {
+
+    def getJdbcType(dialect: JdbcDialect, dt: DataType): String = {
+      dialect.getJDBCType(dt).orElse(JdbcUtils.getCommonJDBCType(dt)).
+        map(_.databaseTypeDefinition).get
+    }
+
+    val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db")
+    assert(getJdbcType(oracleDialect, BooleanType) == "NUMBER(1)")
+    assert(getJdbcType(oracleDialect, IntegerType) == "NUMBER(10)")
+    assert(getJdbcType(oracleDialect, LongType) == "NUMBER(19)")
+    assert(getJdbcType(oracleDialect, FloatType) == "NUMBER(19, 4)")
+    assert(getJdbcType(oracleDialect, DoubleType) == "NUMBER(19, 4)")
+    assert(getJdbcType(oracleDialect, ByteType) == "NUMBER(3)")
+    assert(getJdbcType(oracleDialect, ShortType) == "NUMBER(5)")
+    assert(getJdbcType(oracleDialect, StringType) == "VARCHAR2(255)")
+    assert(getJdbcType(oracleDialect, BinaryType) == "BLOB")
+    assert(getJdbcType(oracleDialect, DateType) == "DATE")
+    assert(getJdbcType(oracleDialect, TimestampType) == "TIMESTAMP")
+  }
+
   private def assertEmptyQuery(sqlString: String): Unit = {
     assert(sql(sqlString).collect().isEmpty)
   }
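
Not part of the patch: the last three asserts pass because OracleDialect.getJDBCType returns None for BinaryType, DateType, and TimestampType, so the getJdbcType helper falls through to JdbcUtils.getCommonJDBCType. A minimal sketch of that fallback; the import path for JdbcUtils is assumed from what JDBCSuite already uses.

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._

// The Oracle dialect declines these types, so the shared JDBC mapping answers.
JdbcUtils.getCommonJDBCType(BinaryType).map(_.databaseTypeDefinition)    // Some("BLOB")
JdbcUtils.getCommonJDBCType(DateType).map(_.databaseTypeDefinition)      // Some("DATE")
JdbcUtils.getCommonJDBCType(TimestampType).map(_.databaseTypeDefinition) // Some("TIMESTAMP")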