diff options
author | Brandon Bradley <bradleytastic@gmail.com> | 2016-02-19 14:43:21 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2016-02-19 14:43:21 -0800 |
commit | dbb08cdd5ae320082cdbcc9cfb8155f5a9da8b8c (patch) | |
tree | 87af57431d10b42eb8a498df99028808351d4b95 /sql | |
parent | c7c55637bfc523237f5cc5c5b61837b1e3d5fdfc (diff) | |
download | spark-dbb08cdd5ae320082cdbcc9cfb8155f5a9da8b8c.tar.gz spark-dbb08cdd5ae320082cdbcc9cfb8155f5a9da8b8c.tar.bz2 spark-dbb08cdd5ae320082cdbcc9cfb8155f5a9da8b8c.zip |
[SPARK-12966][SQL] ArrayType(DecimalType) support in Postgres JDBC
Fixes error `org.postgresql.util.PSQLException: Unable to find server array type for provided name decimal(38,18)`.
* Passes scale metadata to JDBC dialect for usage in type conversions.
* Strips the length/scale/precision suffix (e.g. `(38,18)`) from the element type name passed to `createArrayOf` (for writing), since Postgres cannot resolve array element types that include it.
* Adds configurable precision and scale to Postgres `DecimalType` (for reading).
* Adds a new kind of test that verifies the schema written by `DataFrame.write.jdbc`.
Author: Brandon Bradley <bradleytastic@gmail.com>
Closes #10928 from blbradley/spark-12966.
Diffstat (limited to 'sql')
3 files changed, 19 insertions, 8 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index befba867bc..ed02b3f95f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -137,7 +137,9 @@ private[sql] object JDBCRDD extends Logging { val fieldScale = rsmd.getScale(i + 1) val isSigned = rsmd.isSigned(i + 1) val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls - val metadata = new MetadataBuilder().putString("name", columnName) + val metadata = new MetadataBuilder() + .putString("name", columnName) + .putLong("scale", fieldScale) val columnType = dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( getCatalystType(dataType, fieldSize, fieldScale, isSigned)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 69ba84646f..e295722cac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -194,8 +194,11 @@ object JdbcUtils extends Logging { case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) case ArrayType(et, _) => + // remove type length parameters from end of type name + val typeName = getJdbcType(et, dialect).databaseTypeDefinition + .toLowerCase.split("\\(")(0) val array = conn.createArrayOf( - getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase, + typeName, row.getSeq[AnyRef](i).toArray) stmt.setArray(i + 1, array) case _ => throw new IllegalArgumentException( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 8d43966480..2d6c3974a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -32,14 +32,18 @@ private object PostgresDialect extends JdbcDialect { if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { Some(BinaryType) } else if (sqlType == Types.OTHER) { - toCatalystType(typeName).filter(_ == StringType) - } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') { - toCatalystType(typeName.drop(1)).map(ArrayType(_)) + Some(StringType) + } else if (sqlType == Types.ARRAY) { + val scale = md.build.getLong("scale").toInt + // postgres array type names start with underscore + toCatalystType(typeName.drop(1), size, scale).map(ArrayType(_)) } else None } - // TODO: support more type names. - private def toCatalystType(typeName: String): Option[DataType] = typeName match { + private def toCatalystType( + typeName: String, + precision: Int, + scale: Int): Option[DataType] = typeName match { case "bool" => Some(BooleanType) case "bit" => Some(BinaryType) case "int2" => Some(ShortType) @@ -52,7 +56,7 @@ private object PostgresDialect extends JdbcDialect { case "bytea" => Some(BinaryType) case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType) case "date" => Some(DateType) - case "numeric" => Some(DecimalType.SYSTEM_DEFAULT) + case "numeric" | "decimal" => Some(DecimalType.bounded(precision, scale)) case _ => None } @@ -62,6 +66,8 @@ private object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT)) case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE)) + case t: DecimalType => Some( + JdbcType(s"NUMERIC(${t.precision},${t.scale})", 
java.sql.Types.NUMERIC)) case ArrayType(et, _) if et.isInstanceOf[AtomicType] => getJDBCType(et).map(_.databaseTypeDefinition) .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) |