diff options
3 files changed, 28 insertions, 5 deletions
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 79dd70116e..c9325dea0b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -22,7 +22,7 @@ import java.util.Properties import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.types.{ArrayType, DecimalType} +import org.apache.spark.sql.types.{ArrayType, DecimalType, FloatType, ShortType} import org.apache.spark.tags.DockerTest @DockerTest @@ -45,10 +45,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { conn.prepareStatement("CREATE TYPE enum_type AS ENUM ('d1', 'd2')").executeUpdate() conn.prepareStatement("CREATE TABLE bar (c0 text, c1 integer, c2 double precision, c3 bigint, " + "c4 bit(1), c5 bit(10), c6 bytea, c7 boolean, c8 inet, c9 cidr, " - + "c10 integer[], c11 text[], c12 real[], c13 numeric(2,2)[], c14 enum_type)").executeUpdate() + + "c10 integer[], c11 text[], c12 real[], c13 numeric(2,2)[], c14 enum_type, " + + "c15 float4, c16 smallint)").executeUpdate() conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16', " - + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', '{0.11, 0.22}', 'd1')""").executeUpdate() + + """'{1, 2}', '{"a", null, "b"}', '{0.11, 0.22}', '{0.11, 0.22}', 'd1', 1.01, 1)""" + ).executeUpdate() } test("Type mapping for various types") { @@ -56,7 +58,7 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { val rows = df.collect() assert(rows.length == 1) val types = rows(0).toSeq.map(x => x.getClass) - assert(types.length == 15) + assert(types.length == 17) assert(classOf[String].isAssignableFrom(types(0))) assert(classOf[java.lang.Integer].isAssignableFrom(types(1))) assert(classOf[java.lang.Double].isAssignableFrom(types(2))) @@ -72,6 +74,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(classOf[Seq[Double]].isAssignableFrom(types(12))) assert(classOf[Seq[BigDecimal]].isAssignableFrom(types(13))) assert(classOf[String].isAssignableFrom(types(14))) + assert(classOf[java.lang.Float].isAssignableFrom(types(15))) + assert(classOf[java.lang.Short].isAssignableFrom(types(16))) assert(rows(0).getString(0).equals("hello")) assert(rows(0).getInt(1) == 42) assert(rows(0).getDouble(2) == 1.25) @@ -90,6 +94,8 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { assert(rows(0).getSeq(12).toSeq == Seq(0.11f, 0.22f)) assert(rows(0).getSeq(13) == Seq("0.11", "0.22").map(BigDecimal(_).bigDecimal)) assert(rows(0).getString(14) == "d1") + assert(rows(0).getFloat(15) == 1.01f) + assert(rows(0).getShort(16) == 1) } test("Basic write test") { @@ -104,4 +110,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { Column(Literal.create(null, a.dataType)).as(a.name) }: _*).write.jdbc(jdbcUrl, "public.barcopy2", new Properties) } + + test("Creating a table with shorts and floats") { + sqlContext.createDataFrame(Seq((1.0f, 1.toShort))) + .write.jdbc(jdbcUrl, "shortfloat", new Properties) + val schema = sqlContext.read.jdbc(jdbcUrl, "shortfloat", new Properties).schema + assert(schema(0).dataType == FloatType) + assert(schema(1).dataType == ShortType) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 6dad8cbef7..8d9048ab82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -390,6 +390,10 @@ private[jdbc] class JDBCRDD( (rs: ResultSet, row: MutableRow, pos: Int) => row.setLong(pos, rs.getLong(pos + 1)) + case ShortType => + (rs: ResultSet, row: MutableRow, pos: Int) => + row.setShort(pos, rs.getShort(pos + 1)) + case StringType => (rs: ResultSet, row: MutableRow, pos: Int) => // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index fb959d881e..3f540d6258 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -29,7 +29,11 @@ private object PostgresDialect extends JdbcDialect { override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { + if (sqlType == Types.REAL) { + Some(FloatType) + } else if (sqlType == Types.SMALLINT) { + Some(ShortType) + } else if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { Some(BinaryType) } else if (sqlType == Types.OTHER) { Some(StringType) @@ -66,6 +70,7 @@ private object PostgresDialect extends JdbcDialect { case BooleanType => Some(JdbcType("BOOLEAN", Types.BOOLEAN)) case FloatType => Some(JdbcType("FLOAT4", Types.FLOAT)) case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE)) + case ShortType => Some(JdbcType("SMALLINT", Types.SMALLINT)) case t: DecimalType => Some( JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC)) case ArrayType(et, _) if et.isInstanceOf[AtomicType] => |