author     Brandon Bradley <bradleytastic@gmail.com>  2016-02-19 14:43:21 -0800
committer  Michael Armbrust <michael@databricks.com>  2016-02-19 14:43:21 -0800
commit     dbb08cdd5ae320082cdbcc9cfb8155f5a9da8b8c (patch)
tree       87af57431d10b42eb8a498df99028808351d4b95 /sql
parent     c7c55637bfc523237f5cc5c5b61837b1e3d5fdfc (diff)
[SPARK-12966][SQL] ArrayType(DecimalType) support in Postgres JDBC
Fixes the error `org.postgresql.util.PSQLException: Unable to find server array type for provided name decimal(38,18)`.

* Passes scale metadata to the JDBC dialect for use in type conversions.
* Strips the length/scale/precision suffix from the `typeName` passed to `createArrayOf` (for writing).
* Adds configurable precision and scale to the Postgres `DecimalType` mapping (for reading).
* Adds a new kind of test that verifies the schema written by `DataFrame.write.jdbc`.

Author: Brandon Bradley <bradleytastic@gmail.com>

Closes #10928 from blbradley/spark-12966.
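A minimal sketch of the round trip this patch enables, assuming a reachable Postgres instance; the URL, table name, and the 1.6-era `sqlContext` entry point are illustrative, not part of the patch:

    import java.util.Properties
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    // Hypothetical connection details.
    val url = "jdbc:postgresql://localhost/testdb"

    val schema = StructType(Seq(
      StructField("vals", ArrayType(DecimalType(38, 18)))))
    val df = sqlContext.createDataFrame(
      sqlContext.sparkContext.parallelize(
        Seq(Row(Seq(BigDecimal("1.5"), BigDecimal("2.25"))))),
      schema)

    // Previously failed with the PSQLException above, because the full
    // definition "decimal(38,18)" was handed to Connection.createArrayOf,
    // which expects a bare type name. With this patch the column is
    // created as NUMERIC(38,18)[] and precision/scale survive a read back.
    df.write.jdbc(url, "decimals", new Properties)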
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala    |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala  |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala                  | 18
3 files changed, 19 insertions(+), 8 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index befba867bc..ed02b3f95f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -137,7 +137,9 @@ private[sql] object JDBCRDD extends Logging {
val fieldScale = rsmd.getScale(i + 1)
val isSigned = rsmd.isSigned(i + 1)
val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
- val metadata = new MetadataBuilder().putString("name", columnName)
+ val metadata = new MetadataBuilder()
+ .putString("name", columnName)
+ .putLong("scale", fieldScale)
val columnType =
dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
getCatalystType(dataType, fieldSize, fieldScale, isSigned))
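The scale recorded here rides along in the column's `Metadata`, so a `JdbcDialect` can recover it inside `getCatalystType`. A one-line sketch of the retrieval, which is essentially what `PostgresDialect` does below:

    // `md` is the MetadataBuilder that JDBCRDD passes to
    // JdbcDialect.getCatalystType(sqlType, typeName, size, md).
    val scale = md.build.getLong("scale").toInt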
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 69ba84646f..e295722cac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -194,8 +194,11 @@ object JdbcUtils extends Logging {
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
case ArrayType(et, _) =>
+ // remove type length parameters from end of type name
+ val typeName = getJdbcType(et, dialect).databaseTypeDefinition
+ .toLowerCase.split("\\(")(0)
val array = conn.createArrayOf(
- getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase,
+ typeName,
row.getSeq[AnyRef](i).toArray)
stmt.setArray(i + 1, array)
case _ => throw new IllegalArgumentException(
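The `split("\\(")` drops the parenthesized length/precision suffix because JDBC's `Connection.createArrayOf` expects a bare server type name, not a full column definition. A self-contained illustration of the transformation (the helper name is hypothetical):

    // "NUMERIC(38,18)" -> "numeric"; names without parameters pass through.
    def baseTypeName(databaseTypeDefinition: String): String =
      databaseTypeDefinition.toLowerCase.split("\\(")(0)

    assert(baseTypeName("NUMERIC(38,18)") == "numeric")
    assert(baseTypeName("VARCHAR(255)") == "varchar")
    assert(baseTypeName("float8") == "float8")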
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 8d43966480..2d6c3974a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -32,14 +32,18 @@ private object PostgresDialect extends JdbcDialect {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
Some(BinaryType)
} else if (sqlType == Types.OTHER) {
- toCatalystType(typeName).filter(_ == StringType)
- } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') {
- toCatalystType(typeName.drop(1)).map(ArrayType(_))
+ Some(StringType)
+ } else if (sqlType == Types.ARRAY) {
+ val scale = md.build.getLong("scale").toInt
+ // postgres array type names start with underscore
+ toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_))
} else None
}
- // TODO: support more type names.
- private def toCatalystType(typeName: String): Option[DataType] = typeName match {
+ private def toCatalystType(
+ typeName: String,
+ precision: Int,
+ scale: Int): Option[DataType] = typeName match {
case "bool" => Some(BooleanType)
case "bit" => Some(BinaryType)
case "int2" => Some(ShortType)
@@ -52,7 +56,7 @@ private object PostgresDialect extends JdbcDialect {
case "bytea" => Some(BinaryType)
case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
case "date" => Some(DateType)
- case "numeric" => Some(DecimalType.SYSTEM_DEFAULT)
+ case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale))
case _ => None
}
@@ -62,6 +66,8 @@ private object PostgresDialect extends JdbcDialect {
case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT))
case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE))
+ case t: DecimalType => Some(
+ JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC))
case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
getJDBCType(et).map(_.databaseTypeDefinition)
.orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
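On the read side, `DecimalType.bounded` clamps the precision and scale reported by the driver to Spark's maximum of 38. Since `bounded` is a `private[sql]` helper, here is a sketch of its semantics in terms of the public `DecimalType` API:

    import org.apache.spark.sql.types.DecimalType

    // Clamp precision and scale to DecimalType.MAX_PRECISION / MAX_SCALE
    // (both 38), mirroring what DecimalType.bounded does above.
    def bounded(precision: Int, scale: Int): DecimalType =
      DecimalType(math.min(precision, DecimalType.MAX_PRECISION),
                  math.min(scale, DecimalType.MAX_SCALE))

    assert(bounded(38, 18) == DecimalType(38, 18))  // within bounds
    assert(bounded(50, 40) == DecimalType(38, 38))  // clamped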