aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala43
1 files changed, 31 insertions, 12 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index e701a7fcd9..ed3faa1268 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc
import java.sql.Types
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.types._
@@ -29,22 +30,40 @@ private object PostgresDialect extends JdbcDialect {
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
- Option(BinaryType)
- } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
- Option(StringType)
- } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
- Option(StringType)
- } else if (sqlType == Types.OTHER && typeName.equals("json")) {
- Option(StringType)
- } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) {
- Option(StringType)
+ Some(BinaryType)
+ } else if (sqlType == Types.OTHER) {
+ toCatalystType(typeName).filter(_ == StringType)
+ } else if (sqlType == Types.ARRAY && typeName.length > 1 && typeName(0) == '_') {
+ toCatalystType(typeName.drop(1)).map(ArrayType(_))
} else None
}
+ // TODO: support more type names.
+ private def toCatalystType(typeName: String): Option[DataType] = typeName match {
+ case "bool" => Some(BooleanType)
+ case "bit" => Some(BinaryType)
+ case "int2" => Some(ShortType)
+ case "int4" => Some(IntegerType)
+ case "int8" | "oid" => Some(LongType)
+ case "float4" => Some(FloatType)
+ case "money" | "float8" => Some(DoubleType)
+ case "text" | "varchar" | "char" | "cidr" | "inet" | "json" | "jsonb" | "uuid" =>
+ Some(StringType)
+ case "bytea" => Some(BinaryType)
+ case "timestamp" | "timestamptz" | "time" | "timetz" => Some(TimestampType)
+ case "date" => Some(DateType)
+ case "numeric" => Some(DecimalType.SYSTEM_DEFAULT)
+ case _ => None
+ }
+
override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
- case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR))
- case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY))
- case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN))
+ case StringType => Some(JdbcType("TEXT", Types.CHAR))
+ case BinaryType => Some(JdbcType("BYTEA", Types.BINARY))
+ case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN))
+ case ArrayType(et, _) if et.isInstanceOf[AtomicType] =>
+ getJDBCType(et).map(_.databaseTypeDefinition)
+ .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition))
+ .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY))
case _ => None
}