author    Wenchen Fan <wenchen@databricks.com>        2015-11-17 11:29:02 -0800
committer Michael Armbrust <michael@databricks.com>   2015-11-17 11:29:02 -0800
commit    d9251496640a77568a1e9ed5045ce2dfba4b437b (patch)
tree      fb13126d187e531922b378d9bd528dd2652f2ab7
parent    0158ff7737d10e68be2e289533241da96b496e89 (diff)
[SPARK-10186][SQL] support postgre array type in JDBCRDD
Add ARRAY support to `PostgresDialect`. Nested ARRAY is not allowed for now because it's hard to get the array dimension info. See http://stackoverflow.com/questions/16619113/how-to-get-array-base-type-in-postgres-via-jdbc

Thanks for the initial work from mariusvniekerk! Close https://github.com/apache/spark/pull/9137

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9662 from cloud-fan/postgre.
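For context, a minimal sketch of how the new mapping looks from the user side, assuming a reachable Postgres instance at `jdbcUrl` and the `bar` table created by the integration suite below (the method name `readArrays` is only for illustration):

import java.util.Properties

import org.apache.spark.sql.SQLContext

// Assumes a running Postgres with the suite's `bar` table; columns 10 and 11
// are the integer[] and text[] columns (c10, c11).
def readArrays(sqlContext: SQLContext, jdbcUrl: String): Unit = {
  val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
  df.printSchema()                                 // c10: array<int>, c11: array<string>
  val row = df.collect().head
  val ints: Seq[Int] = row.getSeq[Int](10)         // Seq(1, 2)
  val texts: Seq[String] = row.getSeq[String](11)  // Seq("a", null, "b")
  println(s"c10 = $ints, c11 = $texts")
}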
-rw-r--r--  docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala |  44
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala            |  76
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala          |  77
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala                             |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala                          |  43
5 files changed, 157 insertions(+), 85 deletions(-)
diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
index 164a7f3962..2e18d0a2ba 100644
--- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
+++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc
import java.sql.Connection
import java.util.Properties
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.{Literal, If}
import org.apache.spark.tags.DockerTest
@DockerTest
@@ -37,28 +39,32 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
override def dataPreparation(conn: Connection): Unit = {
conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
conn.setCatalog("foo")
- conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
- + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
+ conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, "
+ + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, "
+ + "c10 integer[], c11 text[])").executeUpdate()
conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
- + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+ + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', "
+ + """'{1, 2}', '{"a", null, "b"}')""").executeUpdate()
}
test("Type mapping for various types") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
val rows = df.collect()
assert(rows.length == 1)
- val types = rows(0).toSeq.map(x => x.getClass.toString)
- assert(types.length == 10)
- assert(types(0).equals("class java.lang.String"))
- assert(types(1).equals("class java.lang.Integer"))
- assert(types(2).equals("class java.lang.Double"))
- assert(types(3).equals("class java.lang.Long"))
- assert(types(4).equals("class java.lang.Boolean"))
- assert(types(5).equals("class [B"))
- assert(types(6).equals("class [B"))
- assert(types(7).equals("class java.lang.Boolean"))
- assert(types(8).equals("class java.lang.String"))
- assert(types(9).equals("class java.lang.String"))
+ val types = rows(0).toSeq.map(x => x.getClass)
+ assert(types.length == 12)
+ assert(classOf[String].isAssignableFrom(types(0)))
+ assert(classOf[java.lang.Integer].isAssignableFrom(types(1)))
+ assert(classOf[java.lang.Double].isAssignableFrom(types(2)))
+ assert(classOf[java.lang.Long].isAssignableFrom(types(3)))
+ assert(classOf[java.lang.Boolean].isAssignableFrom(types(4)))
+ assert(classOf[Array[Byte]].isAssignableFrom(types(5)))
+ assert(classOf[Array[Byte]].isAssignableFrom(types(6)))
+ assert(classOf[java.lang.Boolean].isAssignableFrom(types(7)))
+ assert(classOf[String].isAssignableFrom(types(8)))
+ assert(classOf[String].isAssignableFrom(types(9)))
+ assert(classOf[Seq[Int]].isAssignableFrom(types(10)))
+ assert(classOf[Seq[String]].isAssignableFrom(types(11)))
assert(rows(0).getString(0).equals("hello"))
assert(rows(0).getInt(1) == 42)
assert(rows(0).getDouble(2) == 1.25)
@@ -72,11 +78,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(rows(0).getBoolean(7) == true)
assert(rows(0).getString(8) == "172.16.0.42")
assert(rows(0).getString(9) == "192.168.0.0/16")
+ assert(rows(0).getSeq(10) == Seq(1, 2))
+ assert(rows(0).getSeq(11) == Seq("a", null, "b"))
}
test("Basic write test") {
val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties)
- df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
// Test only that it doesn't crash.
+ df.write.jdbc(jdbcUrl, "public.barcopy", new Properties)
+ // Test write null values.
+ df.select(df.queryExecution.analyzed.output.map { a =>
+ Column(If(Literal(true), Literal(null), a)).as(a.name)
+ }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties)
}
}
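The null-write check above replaces every column with a null literal of the same type while keeping the schema, so the new ArrayType branch of the write path is exercised with null values. A rough equivalent using only the public Column API (a sketch, not what the suite itself uses):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.lit

// Replace every column with a typed null, preserving column names and the schema.
def allNulls(df: DataFrame): DataFrame =
  df.select(df.schema.fields.map { f =>
    lit(null).cast(f.dataType).as(f.name)
  }: _*)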
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 018a009fbd..89c850ce23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,7 +25,7 @@ import org.apache.commons.lang3.StringUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -324,25 +324,27 @@ private[sql] class JDBCRDD(
case object StringConversion extends JDBCConversion
case object TimestampConversion extends JDBCConversion
case object BinaryConversion extends JDBCConversion
+ case class ArrayConversion(elementConversion: JDBCConversion) extends JDBCConversion
/**
* Maps a StructType to a type tag list.
*/
- def getConversions(schema: StructType): Array[JDBCConversion] = {
- schema.fields.map(sf => sf.dataType match {
- case BooleanType => BooleanConversion
- case DateType => DateConversion
- case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
- case DoubleType => DoubleConversion
- case FloatType => FloatConversion
- case IntegerType => IntegerConversion
- case LongType =>
- if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion
- case StringType => StringConversion
- case TimestampType => TimestampConversion
- case BinaryType => BinaryConversion
- case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
- }).toArray
+ def getConversions(schema: StructType): Array[JDBCConversion] =
+ schema.fields.map(sf => getConversions(sf.dataType, sf.metadata))
+
+ private def getConversions(dt: DataType, metadata: Metadata): JDBCConversion = dt match {
+ case BooleanType => BooleanConversion
+ case DateType => DateConversion
+ case DecimalType.Fixed(p, s) => DecimalConversion(p, s)
+ case DoubleType => DoubleConversion
+ case FloatType => FloatConversion
+ case IntegerType => IntegerConversion
+ case LongType => if (metadata.contains("binarylong")) BinaryLongConversion else LongConversion
+ case StringType => StringConversion
+ case TimestampType => TimestampConversion
+ case BinaryType => BinaryConversion
+ case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata))
+ case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}")
}
/**
@@ -420,16 +422,44 @@ private[sql] class JDBCRDD(
mutableRow.update(i, null)
}
case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
- case BinaryLongConversion => {
+ case BinaryLongConversion =>
val bytes = rs.getBytes(pos)
var ans = 0L
var j = 0
while (j < bytes.size) {
ans = 256 * ans + (255 & bytes(j))
- j = j + 1;
+ j = j + 1
}
mutableRow.setLong(i, ans)
- }
+ case ArrayConversion(elementConversion) =>
+ val array = rs.getArray(pos).getArray
+ if (array != null) {
+ val data = elementConversion match {
+ case TimestampConversion =>
+ array.asInstanceOf[Array[java.sql.Timestamp]].map { timestamp =>
+ nullSafeConvert(timestamp, DateTimeUtils.fromJavaTimestamp)
+ }
+ case StringConversion =>
+ array.asInstanceOf[Array[java.lang.String]]
+ .map(UTF8String.fromString)
+ case DateConversion =>
+ array.asInstanceOf[Array[java.sql.Date]].map { date =>
+ nullSafeConvert(date, DateTimeUtils.fromJavaDate)
+ }
+ case DecimalConversion(p, s) =>
+ array.asInstanceOf[Array[java.math.BigDecimal]].map { decimal =>
+ nullSafeConvert[java.math.BigDecimal](decimal, d => Decimal(d, p, s))
+ }
+ case BinaryLongConversion =>
+ throw new IllegalArgumentException(s"Unsupported array element conversion $i")
+ case _: ArrayConversion =>
+ throw new IllegalArgumentException("Nested arrays unsupported")
+ case _ => array.asInstanceOf[Array[Any]]
+ }
+ mutableRow.update(i, new GenericArrayData(data))
+ } else {
+ mutableRow.update(i, null)
+ }
}
if (rs.wasNull) mutableRow.setNullAt(i)
i = i + 1
@@ -488,4 +518,12 @@ private[sql] class JDBCRDD(
nextValue
}
}
+
+ private def nullSafeConvert[T](input: T, f: T => Any): Any = {
+ if (input == null) {
+ null
+ } else {
+ f(input)
+ }
+ }
}
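The ArrayConversion branch above unwraps the JDBC java.sql.Array into a plain JVM array and converts each element to its Catalyst representation, using nullSafeConvert so null elements survive. A standalone sketch of that per-element pattern for a timestamp[] column (the helper name convertTimestampArray is made up; DateTimeUtils and GenericArrayData are the internal Spark classes imported above):

import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}

// Mirrors the TimestampConversion case: null elements stay null, non-null
// elements become Catalyst's internal microsecond longs.
def convertTimestampArray(raw: Array[java.sql.Timestamp]): GenericArrayData = {
  val converted: Array[Any] = raw.map { ts =>
    if (ts == null) null else DateTimeUtils.fromJavaTimestamp(ts)
  }
  new GenericArrayData(converted)
}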
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index f89d55b20e..32d28e5937 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -23,7 +23,7 @@ import java.util.Properties
import scala.util.Try
import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row}
@@ -73,6 +73,35 @@ object JdbcUtils extends Logging {
}
/**
+ * Retrieve standard jdbc types.
+ * @param dt The datatype (e.g. [[org.apache.spark.sql.types.StringType]])
+ * @return The default JdbcType for this DataType
+ */
+ def getCommonJDBCType(dt: DataType): Option[JdbcType] = {
+ dt match {
+ case IntegerType => Option(JdbcType("INTEGER", java.sql.Types.INTEGER))
+ case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT))
+ case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE))
+ case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT))
+ case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
+ case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT))
+ case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT))
+ case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB))
+ case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB))
+ case TimestampType => Option(JdbcType("TIMESTAMP", java.sql.Types.TIMESTAMP))
+ case DateType => Option(JdbcType("DATE", java.sql.Types.DATE))
+ case t: DecimalType => Option(
+ JdbcType(s"DECIMAL(${t.precision},${t.scale})", java.sql.Types.DECIMAL))
+ case _ => None
+ }
+ }
+
+ private def getJdbcType(dt: DataType, dialect: JdbcDialect): JdbcType = {
+ dialect.getJDBCType(dt).orElse(getCommonJDBCType(dt)).getOrElse(
+ throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
+ }
+
+ /**
* Saves a partition of a DataFrame to the JDBC database. This is done in
* a single database transaction in order to avoid repeatedly inserting
* data as much as possible.
@@ -92,7 +121,8 @@ object JdbcUtils extends Logging {
iterator: Iterator[Row],
rddSchema: StructType,
nullTypes: Array[Int],
- batchSize: Int): Iterator[Byte] = {
+ batchSize: Int,
+ dialect: JdbcDialect): Iterator[Byte] = {
val conn = getConnection()
var committed = false
try {
@@ -121,6 +151,11 @@ object JdbcUtils extends Logging {
case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
+ case ArrayType(et, _) =>
+ val array = conn.createArrayOf(
+ getJdbcType(et, dialect).databaseTypeDefinition.toLowerCase,
+ row.getSeq[AnyRef](i).toArray)
+ stmt.setArray(i + 1, array)
case _ => throw new IllegalArgumentException(
s"Can't translate non-null value for field $i")
}
@@ -169,23 +204,7 @@ object JdbcUtils extends Logging {
val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field => {
val name = field.name
- val typ: String =
- dialect.getJDBCType(field.dataType).map(_.databaseTypeDefinition).getOrElse(
- field.dataType match {
- case IntegerType => "INTEGER"
- case LongType => "BIGINT"
- case DoubleType => "DOUBLE PRECISION"
- case FloatType => "REAL"
- case ShortType => "INTEGER"
- case ByteType => "BYTE"
- case BooleanType => "BIT(1)"
- case StringType => "TEXT"
- case BinaryType => "BLOB"
- case TimestampType => "TIMESTAMP"
- case DateType => "DATE"
- case t: DecimalType => s"DECIMAL(${t.precision},${t.scale})"
- case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
- })
+ val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
}}
@@ -202,23 +221,7 @@ object JdbcUtils extends Logging {
properties: Properties = new Properties()) {
val dialect = JdbcDialects.get(url)
val nullTypes: Array[Int] = df.schema.fields.map { field =>
- dialect.getJDBCType(field.dataType).map(_.jdbcNullType).getOrElse(
- field.dataType match {
- case IntegerType => java.sql.Types.INTEGER
- case LongType => java.sql.Types.BIGINT
- case DoubleType => java.sql.Types.DOUBLE
- case FloatType => java.sql.Types.REAL
- case ShortType => java.sql.Types.INTEGER
- case ByteType => java.sql.Types.INTEGER
- case BooleanType => java.sql.Types.BIT
- case StringType => java.sql.Types.CLOB
- case BinaryType => java.sql.Types.BLOB
- case TimestampType => java.sql.Types.TIMESTAMP
- case DateType => java.sql.Types.DATE
- case t: DecimalType => java.sql.Types.DECIMAL
- case _ => throw new IllegalArgumentException(
- s"Can't translate null value for field $field")
- })
+ getJdbcType(field.dataType, dialect).jdbcNullType
}
val rddSchema = df.schema
@@ -226,7 +229,7 @@ object JdbcUtils extends Logging {
val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
val batchSize = properties.getProperty("batchsize", "1000").toInt
df.foreachPartition { iterator =>
- savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
+ savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
}
}
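JdbcUtils now routes both DDL generation and the null-type lookup through getCommonJDBCType/getJdbcType, and savePartition receives the dialect so ArrayType columns can be written via Connection.createArrayOf. A sketch of what the common lookup yields for a made-up schema (the field names are illustrative; getCommonJDBCType is the helper added above):

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._

val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("scores", ArrayType(IntegerType))))

schema.fields.foreach { f =>
  // The common mapping knows nothing about arrays, so "scores" yields None here;
  // only a dialect override (like the Postgres one below) can map ArrayType.
  println(s"${f.name} -> ${JdbcUtils.getCommonJDBCType(f.dataType)}")
}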
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 14bfea4e3e..b3b2cb6178 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -51,7 +51,7 @@ case class JdbcType(databaseTypeDefinition : String, jdbcNullType : Int)
* for the given Catalyst type.
*/
@DeveloperApi
-abstract class JdbcDialect {
+abstract class JdbcDialect extends Serializable {
/**
* Check if this dialect instance can handle a certain jdbc url.
* @param url the jdbc url.
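Making JdbcDialect extend Serializable matters because the dialect is now captured by the closure passed to df.foreachPartition (savePartition takes it as a parameter above), so it has to ship to executors. A custom dialect therefore needs no extra boilerplate; a minimal sketch (the MyDialect name and jdbc:mydb prefix are made up):

import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._

object MyDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
  override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
    case StringType => Some(JdbcType("VARCHAR(1024)", java.sql.Types.VARCHAR))
    case _ => None
  }
}

// Registering it lets the JDBC write path pick it up; because it is Serializable,
// it can be captured by the foreachPartition closure without serialization errors.
JdbcDialects.registerDialect(MyDialect)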
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index e701a7fcd9..ed3faa1268 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc
import java.sql.Types
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._
@@ -29,22 +30,40 @@ private object PostgresDialect extends JdbcDialect {
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
- Option(BinaryType)
- } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
- Option(StringType)
- } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
- Option(StringType)
- } else if (sqlType == Types.OTHER && typeName.equals("json")) {
- Option(StringType)
- } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) {
- Option(StringType)
+ Some(BinaryType)
+ } else if (sqlType == Types.OTHER) {
+ toCatalystType(typeName).filter(_ == StringType)
+ } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') {
+ toCatalystType(typeName.drop(1)).map(ArrayType(_))
} else None
}
+ // TODO: support more type names.
+ private def toCatalystType(typeName: String): Option[DataType] = typeName match {
+ case "bool" => Some(BooleanType)
+ case "bit" => Some(BinaryType)
+ case "int2" => Some(ShortType)
+ case "int4" => Some(IntegerType)
+ case "int8" | "oid" => Some(LongType)
+ case "float4" => Some(FloatType)
+ case "money" | "float8" => Some(DoubleType)
+ case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
+ Some(StringType)
+ case "bytea" => Some(BinaryType)
+ case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
+ case "date" => Some(DateType)
+ case "numeric" => Some(DecimalType.SYSTEM_DEFAULT)
+ case _ => None
+ }
+
override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
- case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
- case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
- case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
+ case StringType => Some(JdbcType("TEXT", Types.CHAR))
+ case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
+ case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
+ case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
+ getJDBCType(et).map(_.databaseTypeDefinition)
+ .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
+ .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
case _ => None
}