From d9251496640a77568a1e9ed5045ce2dfba4b437b Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Tue, 17 Nov 2015 11:29:02 -0800
Subject: [SPARK-10186][SQL] support Postgres array type in JDBCRDD

Add ARRAY support to `PostgresDialect`. Nested ARRAY is not allowed for now
because it's hard to get the array dimension info. See
http://stackoverflow.com/questions/16619113/how-to-get-array-base-type-in-postgres-via-jdbc

Thanks to mariusvniekerk for the initial work!

Closes https://github.com/apache/spark/pull/9137

Author: Wenchen Fan

Closes #9662 from cloud-fan/postgre.
---
 .../sql/execution/datasources/jdbc/JDBCRDD.scala   | 76 +++++++++++++++------
 .../sql/execution/datasources/jdbc/JdbcUtils.scala | 77 +++++++++++-----------
 .../org/apache/spark/sql/jdbc/JdbcDialects.scala   |  2 +-
 .../apache/spark/sql/jdbc/PostgresDialect.scala    | 43 ++++++++----
 4 files changed, 129 insertions(+), 69 deletions(-)

(limited to 'sql/core/src/main/scala/org')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 018a009fbd..89c850ce23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,7 +25,7 @@ import org.apache.commons.lang3.StringUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils}
 import org.apache.spark.sql.jdbc.JdbcDialects
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -324,25 +324,27 @@ private[sql] class JDBCRDD(
   case object StringConversion extends JDBCConversion
   case object TimestampConversion extends JDBCConversion
   case object BinaryConversion extends JDBCConversion
+  case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
 
   /**
    * Maps a StructType to a type tag list.
    */
-  def getConversions(schema: StructType): Array[JDBCConversion] = {
-    schema.fields.map(sf => sf.dataType match {
-      case BooleanType => BooleanConversion
-      case DateType => DateConversion
-      case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
-      case DoubleType => DoubleConversion
-      case FloatType => FloatConversion
-      case IntegerType => IntegerConversion
-      case LongType =>
-        if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion
-      case StringType => StringConversion
-      case TimestampType => TimestampConversion
-      case BinaryType => BinaryConversion
-      case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
-    }).toArray
+  def getConversions(schema: StructType): Array[JDBCConversion] =
+    schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))
+
+  private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
+    case BooleanType => BooleanConversion
+    case DateType => DateConversion
+    case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
+    case DoubleType => DoubleConversion
+    case FloatType => FloatConversion
+    case IntegerType => IntegerConversion
+    case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
+    case StringType => StringConversion
+    case TimestampType => TimestampConversion
+    case BinaryType => BinaryConversion
+    case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
+    case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
   }
 
   /**
@@ -420,16 +422,44 @@ private[sql] class JDBCRDD(
             mutableRow.update(i, null)
           }
         case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
-        case BinaryLongConversion => {
+        case BinaryLongConversion =>
           val bytes = rs.getBytes(pos)
           var ans = 0L
           var j = 0
           while (j < bytes.size) {
             ans = 256 * ans + (255 & bytes(j))
-            j = j + 1;
+            j = j + 1
           }
           mutableRow.setLong(i, ans)
-        }
+        case ArrayConversion(elementConversion) =>
+          val array = rs.getArray(pos).getArray
+          if (array != null) {
+            val data = elementConversion match {
+              case TimestampConversion =>
+                array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+                  nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+                }
+              case StringConversion =>
+                array.asInstanceOf[Array[java.lang.String]]
+                  .map(UTF8String.fromString)
+              case DateConversion =>
+                array.asInstanceOf[Array[java.sql.Date]].map { date =>
+                  nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+                }
+              case DecimalConversion(p, s) =>
+                array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+                  nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
+                }
+              case BinaryLongConversion =>
+                throw new IllegalArgumentException(s"Unsupported array element conversion $i")
+              case _: ArrayConversion =>
+                throw new IllegalArgumentException("Nested arrays unsupported")
+              case _ => array.asInstanceOf[Array[Any]]
+            }
+            mutableRow.update(i, new GenericArrayData(data))
+          } else {
+            mutableRow.update(i, null)
+          }
       }
       if (rs.wasNull) mutableRow.setNullAt(i)
       i = i + 1
@@ -488,4 +518,12 @@ private[sql] class JDBCRDD(
       nextValue
     }
   }
+
+  private def nullSafeConvert[T](input: T, f: T => Any): Any = {
+    if (input == null) {
+      null
+    } else {
+      f(input)
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index f89d55b20e..32d28e5937 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -23,7 +23,7 @@ import java.util.Properties
 import scala.util.Try
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row}
@@ -72,6 +72,35 @@ object JdbcUtils extends Logging {
     conn.prepareStatement(sql.toString())
   }
 
+  /**
+   * Retrieve standard jdbc types.
+   * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+   * @return The default JdbcType for this DataType
+   */
+  def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
+    dt match {
+      case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
+      case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
+      case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
+      case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
+      case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
+      case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+      case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
+      case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
+      case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
+      case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
+      case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
+      case t: DecimalType => Option(
+        JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
+      case _ => None
+    }
+  }
+
+  private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
+    dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
+      throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
+  }
+
   /**
    * Saves a partition of a DataFrame to the JDBC database. This is done in
This is done in * a single database transaction in order to avoid repeatedly inserting @@ -92,7 +121,8 @@ object JdbcUtils extends Logging { iterator: Iterator[Row], rddSchema: StructType, nullTypes: Array[Int], - batchSize: Int): Iterator[Byte] = { + batchSize: Int, + dialect: JdbcDialect): Iterator[Byte] = { val conn = getConnection() var committed = false try { @@ -121,6 +151,11 @@ object JdbcUtils extends Logging { case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) + case ArrayType(et, _) => + val array = conn.createArrayOf( + getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase, + row.getSeq[AnyRef](i).toArray) + stmt.setArray(i + 1, array) case _ => throw new IllegalArgumentException( s"Can't translate non-null value for field $i") } @@ -169,23 +204,7 @@ object JdbcUtils extends Logging { val dialect = JdbcDialects.get(url) df.schema.fields foreach { field => { val name = field.name - val typ: String = - dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse( - field.dataType match { - case IntegerType => "INTEGER" - case LongType => "BIGINT" - case DoubleType => "DOUBLE PRECISION" - case FloatType => "REAL" - case ShortType => "INTEGER" - case ByteType => "BYTE" - case BooleanType => "BIT(1)" - case StringType => "TEXT" - case BinaryType => "BLOB" - case TimestampType => "TIMESTAMP" - case DateType => "DATE" - case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})" - case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC") - }) + val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") }} @@ -202,23 +221,7 @@ object JdbcUtils extends Logging { properties: Properties = new Properties()) { val dialect = JdbcDialects.get(url) val nullTypes: Array[Int] = df.schema.fields.map { field => - dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse( - field.dataType match { - case IntegerType => java.sql.Types.INTEGER - case LongType => java.sql.Types.BIGINT - case DoubleType => java.sql.Types.DOUBLE - case FloatType => java.sql.Types.REAL - case ShortType => java.sql.Types.INTEGER - case ByteType => java.sql.Types.INTEGER - case BooleanType => java.sql.Types.BIT - case StringType => java.sql.Types.CLOB - case BinaryType => java.sql.Types.BLOB - case TimestampType => java.sql.Types.TIMESTAMP - case DateType => java.sql.Types.DATE - case t: DecimalType => java.sql.Types.DECIMAL - case _ => throw new IllegalArgumentException( - s"Can't translate null value for field $field") - }) + getJdbcType(field.dataType, dialect).jdbcNullType } val rddSchema = df.schema @@ -226,7 +229,7 @@ object JdbcUtils extends Logging { val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt df.foreachPartition { iterator => - savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize) + savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 14bfea4e3e..b3b2cb6178 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ 
@@ -51,7 +51,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
  * for the given Catalyst type.
  */
 @DeveloperApi
-abstract class JdbcDialect {
+abstract class JdbcDialect extends Serializable {
   /**
    * Check if this dialect instance can handle a certain jdbc url.
    * @param url the jdbc url.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index e701a7fcd9..ed3faa1268 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc
 
 import java.sql.Types
 
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
 import org.apache.spark.sql.types._
 
 
@@ -29,22 +30,40 @@ private object PostgresDialect extends JdbcDialect {
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
     if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
-      Option(BinaryType)
-    } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
-      Option(StringType)
-    } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
-      Option(StringType)
-    } else if (sqlType == Types.OTHER && typeName.equals("json")) {
-      Option(StringType)
-    } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) {
-      Option(StringType)
+      Some(BinaryType)
+    } else if (sqlType == Types.OTHER) {
+      toCatalystType(typeName).filter(_ == StringType)
+    } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') {
+      toCatalystType(typeName.drop(1)).map(ArrayType(_))
     } else None
   }
 
+  // TODO: support more type names.
+  private def toCatalystType(typeName: String): Option[DataType] = typeName match {
+    case "bool" => Some(BooleanType)
+    case "bit" => Some(BinaryType)
+    case "int2" => Some(ShortType)
+    case "int4" => Some(IntegerType)
+    case "int8" | "oid" => Some(LongType)
+    case "float4" => Some(FloatType)
+    case "money" | "float8" => Some(DoubleType)
+    case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
+      Some(StringType)
+    case "bytea" => Some(BinaryType)
+    case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
+    case "date" => Some(DateType)
+    case "numeric" => Some(DecimalType.SYSTEM_DEFAULT)
+    case _ => None
+  }
+
   override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
-    case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
-    case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
-    case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
+    case StringType => Some(JdbcType("TEXT", Types.CHAR))
+    case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
+    case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
+    case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
+      getJDBCType(et).map(_.databaseTypeDefinition)
+        .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
+        .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
     case _ => None
   }
--
cgit v1.2.3
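
To illustrate the new mappings: the PostgreSQL JDBC driver reports an array
column's type name as the element type name prefixed with an underscore (e.g.
"_int4" for an integer[] column), which is why getCatalystType strips one
leading '_' and cannot recover nested array dimensions. Below is a minimal
sketch of both directions, using only APIs touched by this patch; the JDBC
URL is a placeholder and the "_int4" input is a hypothetical driver report.

  import java.sql.Types
  import org.apache.spark.sql.jdbc.JdbcDialects
  import org.apache.spark.sql.types._

  // Resolves to PostgresDialect via its canHandle(url) check.
  val dialect = JdbcDialects.get("jdbc:postgresql://localhost/testdb")

  // Read side: Types.ARRAY plus the "_"-prefixed element type name becomes
  // an ArrayType over the element's Catalyst type.
  val arrayType = dialect.getCatalystType(Types.ARRAY, "_int4", 0, new MetadataBuilder)
  // arrayType == Some(ArrayType(IntegerType))

  // Write side: the dialect appends "[]" to the element's database type
  // definition, falling back to JdbcUtils.getCommonJDBCType when needed.
  val columnType = dialect.getJDBCType(ArrayType(StringType))
  // columnType.map(_.databaseTypeDefinition) == Some("TEXT[]")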
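
End to end, this is what the patch enables, written against the Spark 1.6-era
SQLContext API. The sketch assumes the PostgreSQL JDBC driver is on the
classpath and that a hypothetical table "events" with a text[] column "tags"
exists at the placeholder URL; none of these names come from the patch itself.

  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.sql.SQLContext

  val sc = new SparkContext(new SparkConf().setAppName("pg-array-demo").setMaster("local[*]"))
  val sqlContext = new SQLContext(sc)

  // Reading: an ARRAY column is now mapped to a Catalyst ArrayType instead
  // of failing schema resolution with "Unsupported type".
  val df = sqlContext.read
    .format("jdbc")
    .option("url", "jdbc:postgresql://localhost/testdb") // placeholder URL
    .option("dbtable", "events")                         // hypothetical table
    .load()

  df.printSchema()
  // For the assumed schema this would print, e.g.:
  // root
  //  |-- id: integer (nullable = true)
  //  |-- tags: array (nullable = true)
  //  |    |-- element: string (containsNull = true)

  // Writing: savePartition now binds array cells with
  // Connection.createArrayOf, and the generated CREATE TABLE uses "TEXT[]".
  df.write.jdbc("jdbc:postgresql://localhost/testdb", "events_copy", new java.util.Properties)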