author     Tor Myklebust <tmyklebu@gmail.com>          2015-02-02 19:50:14 -0800
committer  Michael Armbrust <michael@databricks.com>   2015-02-02 19:50:14 -0800
commit     8f471a66db0571a76a21c0d93312197fee16174a (patch)
tree       fcb1817c8b074956a7af39aa36b0f1629bef5483 /sql
parent     1bcd46574e442e20f55709d70573f271ce44e5b9 (diff)
download   spark-8f471a66db0571a76a21c0d93312197fee16174a.tar.gz
           spark-8f471a66db0571a76a21c0d93312197fee16174a.tar.bz2
           spark-8f471a66db0571a76a21c0d93312197fee16174a.zip
[SPARK-5472][SQL] A JDBC data source for Spark SQL.
This pull request contains a Spark SQL data source that can pull data from, and can put data into, a JDBC database. I have tested both read and write support with H2, MySQL, and Postgres. It would surprise me if both read and write support worked flawlessly out-of-the-box for any other database; different databases have different names for different JDBC data types and different meanings for SQL types with the same name. However, this code is designed (see `DriverQuirks.scala`) to make it *relatively* painless to add support for another database by augmenting the type mapping contained in this PR.

Author: Tor Myklebust <tmyklebu@gmail.com>

Closes #4261 from tmyklebu/master and squashes the following commits:

cf167ce [Tor Myklebust] Work around other Java tests ruining TestSQLContext.
67893bf [Tor Myklebust] Move the jdbcRDD methods into SQLContext itself.
585f95b [Tor Myklebust] Dependencies go into the project's pom.xml.
829d5ba [Tor Myklebust] Merge branch 'master' of https://github.com/apache/spark
41647ef [Tor Myklebust] Hide a couple things that don't need to be public.
7318aea [Tor Myklebust] Fix scalastyle warnings.
a09eeac [Tor Myklebust] JDBC data source for Spark SQL.
176bb98 [Tor Myklebust] Add test deps for JDBC support.
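For orientation, here is a minimal sketch of how the new data source might be exercised from Scala, assuming an existing SparkContext `sc` and the in-memory H2 database that the test suites below use; the URL and table name are illustrative, not part of this commit:

    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)

    // Read path: build a DataFrame backed by a remote table via the new
    // SQLContext.jdbcRDD methods added in this commit.
    val people = sqlContext.jdbcRDD("jdbc:h2:mem:testdb0", "TEST.PEOPLE")

    // The same table can also be registered through the data source API
    // (see DefaultSource in JDBCRelation.scala), as JDBCSuite does:
    sqlContext.sql(
      "CREATE TEMPORARY TABLE people " +
      "USING org.apache.spark.sql.jdbc " +
      "OPTIONS (url 'jdbc:h2:mem:testdb0', dbtable 'TEST.PEOPLE')")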
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/pom.xml                                                              24
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java               59
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                 49
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala          99
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala              417
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala         133
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala    30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala                 235
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java           102
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala           51
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala            248
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala       107
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala     235
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala  149
14 files changed, 1937 insertions(+), 1 deletion(-)
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 3e9ef07df9..1a0c77d282 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -76,6 +76,30 @@
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>com.h2database</groupId>
+ <artifactId>h2</artifactId>
+ <version>1.4.183</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>mysql</groupId>
+ <artifactId>mysql-connector-java</artifactId>
+ <version>5.1.34</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.postgresql</groupId>
+ <artifactId>postgresql</artifactId>
+ <version>9.3-1102-jdbc41</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.spotify</groupId>
+ <artifactId>docker-client</artifactId>
+ <version>2.7.5</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java b/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java
new file mode 100644
index 0000000000..aa441b2096
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc;
+
+import org.apache.spark.Partition;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.DataFrame;
+
+public class JDBCUtils {
+ /**
+ * Construct a DataFrame representing the JDBC table at the database
+ * specified by url with table name table.
+ */
+ public static DataFrame jdbcRDD(SQLContext sql, String url, String table) {
+ Partition[] parts = new Partition[1];
+ parts[0] = new JDBCPartition(null, 0);
+ return sql.baseRelationToDataFrame(
+ new JDBCRelation(url, table, parts, sql));
+ }
+
+ /**
+ * Construct a DataFrame representing the JDBC table at the database
+ * specified by url with table name table partitioned by parts.
+ * Here, parts is an array of expressions suitable for insertion into a WHERE
+ * clause; each one defines one partition.
+ */
+ public static DataFrame jdbcRDD(SQLContext sql, String url, String table, String[] parts) {
+ Partition[] partitions = new Partition[parts.length];
+ for (int i = 0; i < parts.length; i++)
+ partitions[i] = new JDBCPartition(parts[i], i);
+ return sql.baseRelationToDataFrame(
+ new JDBCRelation(url, table, partitions, sql));
+ }
+
+ private static JavaJDBCTrampoline trampoline = new JavaJDBCTrampoline();
+
+ public static void createJDBCTable(DataFrame rdd, String url, String table, boolean allowExisting) {
+ trampoline.createJDBCTable(rdd, url, table, allowExisting);
+ }
+
+ public static void insertIntoJDBC(DataFrame rdd, String url, String table, boolean overwrite) {
+ trampoline.insertIntoJDBC(rdd, url, table, overwrite);
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index d0bbb5f7a3..f4692b3ff5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -24,7 +24,7 @@ import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, Partition}
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
import org.apache.spark.rdd.RDD
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution._
import org.apache.spark.sql.json._
+import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation, DDLParser, DataSourceStrategy}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -335,6 +336,52 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
/**
+ * :: Experimental ::
+ * Construct an RDD representing the database table accessible via JDBC URL
+ * url named table.
+ */
+ @Experimental
+ def jdbcRDD(url: String, table: String): DataFrame = {
+ jdbcRDD(url, table, null.asInstanceOf[JDBCPartitioningInfo])
+ }
+
+ /**
+ * :: Experimental ::
+ * Construct an RDD representing the database table accessible via JDBC URL
+ * url named table. The PartitioningInfo parameter
+ * gives the name of a column of integral type, a number of partitions, and
+ * advisory minimum and maximum values for the column. The RDD is
+ * partitioned according to said column.
+ */
+ @Experimental
+ def jdbcRDD(url: String, table: String, partitioning: JDBCPartitioningInfo):
+ DataFrame = {
+ val parts = JDBCRelation.columnPartition(partitioning)
+ jdbcRDD(url, table, parts)
+ }
+
+ /**
+ * :: Experimental ::
+ * Construct an RDD representing the database table accessible via JDBC URL
+ * url named table. The theParts parameter gives a list of expressions
+ * suitable for inclusion in WHERE clauses; each one defines one partition
+ * of the RDD.
+ */
+ @Experimental
+ def jdbcRDD(url: String, table: String, theParts: Array[String]):
+ DataFrame = {
+ val parts: Array[Partition] = theParts.zipWithIndex.map(
+ x => JDBCPartition(x._1, x._2).asInstanceOf[Partition])
+ jdbcRDD(url, table, parts)
+ }
+
+ private def jdbcRDD(url: String, table: String, parts: Array[Partition]):
+ DataFrame = {
+ val relation = JDBCRelation(url, table, parts)(this)
+ baseRelationToDataFrame(relation)
+ }
+
+ /**
* Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
* during the lifetime of this instance of SQLContext.
*
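The three public `jdbcRDD` overloads added above correspond to the partitioning strategies exercised by JDBCSuite further down. A hedged sketch of the calls, reusing the H2 URL and TEST.PEOPLE table from those tests (note that `JDBCPartitioningInfo` is `private[sql]`, so the second form is only reachable from Spark SQL's own packages, which is how the test suite calls it):

    import org.apache.spark.sql.jdbc.JDBCPartitioningInfo

    val url = "jdbc:h2:mem:testdb0"

    // Whole table as a single partition.
    val df1 = sqlContext.jdbcRDD(url, "TEST.PEOPLE")

    // Range-partitioned on an integral column: three partitions over THEID in [0, 4).
    val df2 = sqlContext.jdbcRDD(url, "TEST.PEOPLE", JDBCPartitioningInfo("THEID", 0, 4, 3))

    // Explicit per-partition WHERE clauses.
    val df3 = sqlContext.jdbcRDD(url, "TEST.PEOPLE", Array("THEID < 2", "THEID >= 2"))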
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
new file mode 100644
index 0000000000..1704be7fcb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DriverQuirks.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.sql.types._
+
+import java.sql.Types
+
+
+/**
+ * Encapsulates workarounds for the extensions, quirks, and bugs in various
+ * databases. Lots of databases define types that aren't explicitly supported
+ * by the JDBC spec. Some JDBC drivers also report inaccurate
+ * information---for instance, BIT(n>1) being reported as a BIT type is quite
+ * common, even though BIT in JDBC is meant for single-bit values. Also, there
+ * does not appear to be a standard name for an unbounded string or binary
+ * type; we use BLOB and CLOB by default but override with database-specific
+ * alternatives when these are absent or do not behave correctly.
+ *
+ * Currently, the only thing DriverQuirks does is handle type mapping.
+ * `getCatalystType` is used when reading from a JDBC table and `getJDBCType`
+ * is used when writing to a JDBC table. If `getCatalystType` returns `null`,
+ * the default type handling is used for the given JDBC type. Similarly,
+ * if `getJDBCType` returns `(null, None)`, the default type handling is used
+ * for the given Catalyst type.
+ */
+private[sql] abstract class DriverQuirks {
+ def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType
+ def getJDBCType(dt: DataType): (String, Option[Int])
+}
+
+private[sql] object DriverQuirks {
+ /**
+ * Fetch the DriverQuirks class corresponding to a given database url.
+ */
+ def get(url: String): DriverQuirks = {
+ if (url.substring(0, 10).equals("jdbc:mysql")) {
+ new MySQLQuirks()
+ } else if (url.substring(0, 15).equals("jdbc:postgresql")) {
+ new PostgresQuirks()
+ } else {
+ new NoQuirks()
+ }
+ }
+}
+
+private[sql] class NoQuirks extends DriverQuirks {
+ def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType =
+ null
+ def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
+}
+
+private[sql] class PostgresQuirks extends DriverQuirks {
+ def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
+ if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) {
+ BinaryType
+ } else if (sqlType == Types.OTHER && typeName.equals("cidr")) {
+ StringType
+ } else if (sqlType == Types.OTHER && typeName.equals("inet")) {
+ StringType
+ } else null
+ }
+
+ def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
+ case StringType => ("TEXT", Some(java.sql.Types.CHAR))
+ case BinaryType => ("BYTEA", Some(java.sql.Types.BINARY))
+ case BooleanType => ("BOOLEAN", Some(java.sql.Types.BOOLEAN))
+ case _ => (null, None)
+ }
+}
+
+private[sql] class MySQLQuirks extends DriverQuirks {
+ def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
+ if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
+ // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as
+ // byte arrays instead of longs.
+ md.putLong("binarylong", 1)
+ LongType
+ } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) {
+ BooleanType
+ } else null
+ }
+ def getJDBCType(dt: DataType): (String, Option[Int]) = (null, None)
+}
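As the pull request description says, supporting another database is mostly a matter of supplying a `DriverQuirks` and teaching `DriverQuirks.get` to recognize the new URL prefix. A hypothetical sketch following the `PostgresQuirks`/`MySQLQuirks` pattern above; the vendor type name and mappings are illustrative assumptions, not tested against any real driver:

    package org.apache.spark.sql.jdbc

    import java.sql.Types

    import org.apache.spark.sql.types._

    // Hypothetical quirks for some other database. Returning null / (null, None)
    // falls back to the default mappings in JDBCRDD.scala and jdbc.scala.
    private[sql] class SomeOtherQuirks extends DriverQuirks {
      def getCatalystType(sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): DataType = {
        if (sqlType == Types.OTHER && typeName.equals("some_vendor_type")) StringType
        else null
      }
      def getJDBCType(dt: DataType): (String, Option[Int]) = dt match {
        case StringType => ("VARCHAR(4000)", Some(Types.VARCHAR))  // assumed vendor-specific string type
        case _ => (null, None)
      }
    }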
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
new file mode 100644
index 0000000000..a2f94675fb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -0,0 +1,417 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.sql.{Connection, DatabaseMetaData, DriverManager, ResultSet, ResultSetMetaData, SQLException}
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.NextIterator
+import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
+import org.apache.spark.sql.catalyst.expressions.{Row, SpecificMutableRow}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.sources._
+
+private[sql] object JDBCRDD extends Logging {
+ /**
+ * Maps a JDBC type to a Catalyst type. This function is called only when
+ * the DriverQuirks class corresponding to your database driver returns null.
+ *
+ * @param sqlType - A field of java.sql.Types
+ * @return The Catalyst type corresponding to sqlType.
+ */
+ private def getCatalystType(sqlType: Int): DataType = {
+ val answer = sqlType match {
+ case java.sql.Types.ARRAY => null
+ case java.sql.Types.BIGINT => LongType
+ case java.sql.Types.BINARY => BinaryType
+ case java.sql.Types.BIT => BooleanType // Per JDBC; Quirks handles quirky drivers.
+ case java.sql.Types.BLOB => BinaryType
+ case java.sql.Types.BOOLEAN => BooleanType
+ case java.sql.Types.CHAR => StringType
+ case java.sql.Types.CLOB => StringType
+ case java.sql.Types.DATALINK => null
+ case java.sql.Types.DATE => DateType
+ case java.sql.Types.DECIMAL => DecimalType.Unlimited
+ case java.sql.Types.DISTINCT => null
+ case java.sql.Types.DOUBLE => DoubleType
+ case java.sql.Types.FLOAT => FloatType
+ case java.sql.Types.INTEGER => IntegerType
+ case java.sql.Types.JAVA_OBJECT => null
+ case java.sql.Types.LONGNVARCHAR => StringType
+ case java.sql.Types.LONGVARBINARY => BinaryType
+ case java.sql.Types.LONGVARCHAR => StringType
+ case java.sql.Types.NCHAR => StringType
+ case java.sql.Types.NCLOB => StringType
+ case java.sql.Types.NULL => null
+ case java.sql.Types.NUMERIC => DecimalType.Unlimited
+ case java.sql.Types.OTHER => null
+ case java.sql.Types.REAL => DoubleType
+ case java.sql.Types.REF => StringType
+ case java.sql.Types.ROWID => LongType
+ case java.sql.Types.SMALLINT => IntegerType
+ case java.sql.Types.SQLXML => StringType
+ case java.sql.Types.STRUCT => StringType
+ case java.sql.Types.TIME => TimestampType
+ case java.sql.Types.TIMESTAMP => TimestampType
+ case java.sql.Types.TINYINT => IntegerType
+ case java.sql.Types.VARBINARY => BinaryType
+ case java.sql.Types.VARCHAR => StringType
+ case _ => null
+ }
+
+ if (answer == null) throw new SQLException("Unsupported type " + sqlType)
+ answer
+ }
+
+ /**
+ * Takes a (schema, table) specification and returns the table's Catalyst
+ * schema.
+ *
+ * @param url - The JDBC url to fetch information from.
+ * @param table - The table name of the desired table. This may also be a
+ * SQL query wrapped in parentheses.
+ *
+ * @return A StructType giving the table's Catalyst schema.
+ * @throws SQLException if the table specification is garbage.
+ * @throws SQLException if the table contains an unsupported type.
+ */
+ def resolveTable(url: String, table: String): StructType = {
+ val quirks = DriverQuirks.get(url)
+ val conn: Connection = DriverManager.getConnection(url)
+ try {
+ val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery()
+ try {
+ val rsmd = rs.getMetaData
+ val ncols = rsmd.getColumnCount
+ var fields = new Array[StructField](ncols);
+ var i = 0
+ while (i < ncols) {
+ val columnName = rsmd.getColumnName(i + 1)
+ val dataType = rsmd.getColumnType(i + 1)
+ val typeName = rsmd.getColumnTypeName(i + 1)
+ val fieldSize = rsmd.getPrecision(i + 1)
+ val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
+ val metadata = new MetadataBuilder().putString("name", columnName)
+ var columnType = quirks.getCatalystType(dataType, typeName, fieldSize, metadata)
+ if (columnType == null) columnType = getCatalystType(dataType)
+ fields(i) = StructField(columnName, columnType, nullable, metadata.build())
+ i = i + 1
+ }
+ return new StructType(fields)
+ } finally {
+ rs.close()
+ }
+ } finally {
+ conn.close()
+ }
+
+ throw new RuntimeException("This line is unreachable.")
+ }
+
+ /**
+ * Prune all but the specified columns from the specified Catalyst schema.
+ *
+ * @param schema - The Catalyst schema of the master table
+ * @param columns - The list of desired columns
+ *
+ * @return A Catalyst schema corresponding to columns in the given order.
+ */
+ private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
+ val fieldMap = Map(schema.fields map { x => x.metadata.getString("name") -> x }: _*)
+ new StructType(columns map { name => fieldMap(name) })
+ }
+
+ /**
+ * Given a driver string and an url, return a function that loads the
+ * specified driver string then returns a connection to the JDBC url.
+ * getConnector is run on the driver code, while the function it returns
+ * is run on the executor.
+ *
+ * @param driver - The class name of the JDBC driver for the given url.
+ * @param url - The JDBC url to connect to.
+ *
+ * @return A function that loads the driver and connects to the url.
+ */
+ def getConnector(driver: String, url: String): () => Connection = {
+ () => {
+ try {
+ if (driver != null) Class.forName(driver)
+ } catch {
+ case e: ClassNotFoundException => {
+ logWarning(s"Couldn't find class $driver", e);
+ }
+ }
+ DriverManager.getConnection(url)
+ }
+ }
+ /**
+ * Build and return JDBCRDD from the given information.
+ *
+ * @param sc - Your SparkContext.
+ * @param schema - The Catalyst schema of the underlying database table.
+ * @param driver - The class name of the JDBC driver for the given url.
+ * @param url - The JDBC url to connect to.
+ * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use.
+ * @param requiredColumns - The names of the columns to SELECT.
+ * @param filters - The filters to include in all WHERE clauses.
+ * @param parts - An array of JDBCPartitions specifying partition ids and
+ * per-partition WHERE clauses.
+ *
+ * @return An RDD representing "SELECT requiredColumns FROM fqTable".
+ */
+ def scanTable(sc: SparkContext,
+ schema: StructType,
+ driver: String,
+ url: String,
+ fqTable: String,
+ requiredColumns: Array[String],
+ filters: Array[Filter],
+ parts: Array[Partition]): RDD[Row] = {
+ val prunedSchema = pruneSchema(schema, requiredColumns)
+
+ return new JDBCRDD(sc,
+ getConnector(driver, url),
+ prunedSchema,
+ fqTable,
+ requiredColumns,
+ filters,
+ parts)
+ }
+}
+
+/**
+ * An RDD representing a table in a database accessed via JDBC. Both the
+ * driver code and the workers must be able to access the database; the driver
+ * needs to fetch the schema while the workers need to fetch the data.
+ */
+private[sql] class JDBCRDD(
+ sc: SparkContext,
+ getConnection: () => Connection,
+ schema: StructType,
+ fqTable: String,
+ columns: Array[String],
+ filters: Array[Filter],
+ partitions: Array[Partition])
+ extends RDD[Row](sc, Nil) {
+
+ /**
+ * Retrieve the list of partitions corresponding to this RDD.
+ */
+ override def getPartitions: Array[Partition] = partitions
+
+ /**
+ * `columns`, but as a String suitable for injection into a SQL query.
+ */
+ private val columnList: String = {
+ val sb = new StringBuilder()
+ columns.foreach(x => sb.append(",").append(x))
+ if (sb.length == 0) "1" else sb.substring(1)
+ }
+
+ /**
+ * Turns a single Filter into a String representing a SQL expression.
+ * Returns null for an unhandled filter.
+ */
+ private def compileFilter(f: Filter): String = f match {
+ case EqualTo(attr, value) => s"$attr = $value"
+ case LessThan(attr, value) => s"$attr < $value"
+ case GreaterThan(attr, value) => s"$attr > $value"
+ case LessThanOrEqual(attr, value) => s"$attr <= $value"
+ case GreaterThanOrEqual(attr, value) => s"$attr >= $value"
+ case _ => null
+ }
+
+ /**
+ * `filters`, but as a WHERE clause suitable for injection into a SQL query.
+ */
+ private val filterWhereClause: String = {
+ val filterStrings = filters map compileFilter filter (_ != null)
+ if (filterStrings.size > 0) {
+ val sb = new StringBuilder("WHERE ")
+ filterStrings.foreach(x => sb.append(x).append(" AND "))
+ sb.substring(0, sb.length - 5)
+ } else ""
+ }
+
+ /**
+ * A WHERE clause representing both `filters`, if any, and the current partition.
+ */
+ private def getWhereClause(part: JDBCPartition): String = {
+ if (part.whereClause != null && filterWhereClause.length > 0) {
+ filterWhereClause + " AND " + part.whereClause
+ } else if (part.whereClause != null) {
+ "WHERE " + part.whereClause
+ } else {
+ filterWhereClause
+ }
+ }
+
+ // Each JDBC-to-Catalyst conversion corresponds to a tag defined here so that
+ // we don't have to potentially poke around in the Metadata once for every
+ // row.
+ // Is there a better way to do this? I'd rather be using a type that
+ // contains only the tags I define.
+ abstract class JDBCConversion
+ case object BooleanConversion extends JDBCConversion
+ case object DateConversion extends JDBCConversion
+ case object DecimalConversion extends JDBCConversion
+ case object DoubleConversion extends JDBCConversion
+ case object FloatConversion extends JDBCConversion
+ case object IntegerConversion extends JDBCConversion
+ case object LongConversion extends JDBCConversion
+ case object BinaryLongConversion extends JDBCConversion
+ case object StringConversion extends JDBCConversion
+ case object TimestampConversion extends JDBCConversion
+ case object BinaryConversion extends JDBCConversion
+
+ /**
+ * Maps a StructType to a type tag list.
+ */
+ def getConversions(schema: StructType): Array[JDBCConversion] = {
+ schema.fields.map(sf => sf.dataType match {
+ case BooleanType => BooleanConversion
+ case DateType => DateConversion
+ case DecimalType.Unlimited => DecimalConversion
+ case DoubleType => DoubleConversion
+ case FloatType => FloatConversion
+ case IntegerType => IntegerConversion
+ case LongType =>
+ if (sf.metadata.contains("binarylong")) BinaryLongConversion else LongConversion
+ case StringType => StringConversion
+ case TimestampType => TimestampConversion
+ case BinaryType => BinaryConversion
+ case _ => throw new IllegalArgumentException(s"Unsupported field $sf")
+ }).toArray
+ }
+
+
+ /**
+ * Runs the SQL query against the JDBC driver.
+ */
+ override def compute(thePart: Partition, context: TaskContext) = new Iterator[Row] {
+ var closed = false
+ var finished = false
+ var gotNext = false
+ var nextValue: Row = null
+
+ context.addTaskCompletionListener{ context => close() }
+ val part = thePart.asInstanceOf[JDBCPartition]
+ val conn = getConnection()
+
+ // H2's JDBC driver does not support the setSchema() method. We pass a
+ // fully-qualified table name in the SELECT statement. I don't know how to
+ // talk about a table in a completely portable way.
+
+ val myWhereClause = getWhereClause(part)
+
+ val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
+ val stmt = conn.prepareStatement(sqlText,
+ ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
+ val rs = stmt.executeQuery()
+
+ val conversions = getConversions(schema)
+ val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
+
+ def getNext(): Row = {
+ if (rs.next()) {
+ var i = 0
+ while (i < conversions.length) {
+ val pos = i + 1
+ conversions(i) match {
+ case BooleanConversion => mutableRow.setBoolean(i, rs.getBoolean(pos))
+ case DateConversion => mutableRow.update(i, rs.getDate(pos))
+ case DecimalConversion => mutableRow.update(i, rs.getBigDecimal(pos))
+ case DoubleConversion => mutableRow.setDouble(i, rs.getDouble(pos))
+ case FloatConversion => mutableRow.setFloat(i, rs.getFloat(pos))
+ case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos))
+ case LongConversion => mutableRow.setLong(i, rs.getLong(pos))
+ case StringConversion => mutableRow.setString(i, rs.getString(pos))
+ case TimestampConversion => mutableRow.update(i, rs.getTimestamp(pos))
+ case BinaryConversion => mutableRow.update(i, rs.getBytes(pos))
+ case BinaryLongConversion => {
+ val bytes = rs.getBytes(pos)
+ var ans = 0L
+ var j = 0
+ while (j < bytes.size) {
+ ans = 256*ans + (255 & bytes(j))
+ j = j + 1;
+ }
+ mutableRow.setLong(i, ans)
+ }
+ }
+ if (rs.wasNull) mutableRow.setNullAt(i)
+ i = i + 1
+ }
+ mutableRow
+ } else {
+ finished = true
+ null.asInstanceOf[Row]
+ }
+ }
+
+ def close() {
+ if (closed) return
+ try {
+ if (null != rs && ! rs.isClosed()) {
+ rs.close()
+ }
+ } catch {
+ case e: Exception => logWarning("Exception closing resultset", e)
+ }
+ try {
+ if (null != stmt && ! stmt.isClosed()) {
+ stmt.close()
+ }
+ } catch {
+ case e: Exception => logWarning("Exception closing statement", e)
+ }
+ try {
+ if (null != conn && ! conn.isClosed()) {
+ conn.close()
+ }
+ logInfo("closed connection")
+ } catch {
+ case e: Exception => logWarning("Exception closing connection", e)
+ }
+ }
+
+ override def hasNext: Boolean = {
+ if (!finished) {
+ if (!gotNext) {
+ nextValue = getNext()
+ if (finished) {
+ close()
+ }
+ gotNext = true
+ }
+ }
+ !finished
+ }
+
+ override def next(): Row = {
+ if (!hasNext) {
+ throw new NoSuchElementException("End of stream")
+ }
+ gotNext = false
+ nextValue
+ }
+
+ }
+}
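Concretely, each task built by `compute` runs `SELECT $columnList FROM $fqTable $whereClause`, where the WHERE clause glues the pushed-down filters to that task's partition predicate. A small illustrative walk-through, assuming the TEST.PEOPLE table from the tests, requested columns NAME and THEID, a pushed-down filter `LessThan("THEID", 3)`, and the partition predicate `THEID >= 1 AND THEID < 2`:

    // Values a single task would end up with (illustrative, not part of the commit):
    val columnList  = "NAME,THEID"                                    // built by columnList
    val whereClause = "WHERE THEID < 3 AND THEID >= 1 AND THEID < 2"  // filters + partition predicate
    val sqlText     = s"SELECT $columnList FROM TEST.PEOPLE $whereClause"
    // => SELECT NAME,THEID FROM TEST.PEOPLE WHERE THEID < 3 AND THEID >= 1 AND THEID < 2

Each partition contributes its own predicate, so the executors fetch their slices of the table independently through forward-only, read-only ResultSets.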
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
new file mode 100644
index 0000000000..e09125e406
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import scala.collection.mutable.ArrayBuffer
+import java.sql.DriverManager
+
+import org.apache.spark.Partition
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.sources._
+
+/**
+ * Data corresponding to one partition of a JDBCRDD.
+ */
+private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
+ override def index: Int = idx
+}
+
+/**
+ * Instructions on how to partition the table among workers.
+ */
+private[sql] case class JDBCPartitioningInfo(
+ column: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int)
+
+private[sql] object JDBCRelation {
+ /**
+ * Given a partitioning schematic (a column of integral type, a number of
+ * partitions, and upper and lower bounds on the column's value), generate
+ * WHERE clauses for each partition so that each row in the table appears
+ * exactly once. The parameters minValue and maxValue are advisory in that
+ * incorrect values may cause the partitioning to be poor, but no data
+ * will fail to be represented.
+ *
+ * @param column - Column name. Must refer to a column of integral type.
+ * @param numPartitions - Number of partitions
+ * @param minValue - Smallest value of column. Advisory.
+ * @param maxValue - Largest value of column. Advisory.
+ */
+ def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
+ if (partitioning == null) return Array[Partition](JDBCPartition(null, 0))
+
+ val numPartitions = partitioning.numPartitions
+ val column = partitioning.column
+ if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0))
+ // Overflow and silliness can happen if you subtract then divide.
+ // Here we get a little roundoff, but that's (hopefully) OK.
+ val stride: Long = (partitioning.upperBound / numPartitions
+ - partitioning.lowerBound / numPartitions)
+ var i: Int = 0
+ var currentValue: Long = partitioning.lowerBound
+ var ans = new ArrayBuffer[Partition]()
+ while (i < numPartitions) {
+ val lowerBound = (if (i != 0) s"$column >= $currentValue" else null)
+ currentValue += stride
+ val upperBound = (if (i != numPartitions - 1) s"$column < $currentValue" else null)
+ val whereClause = (if (upperBound == null) lowerBound
+ else if (lowerBound == null) upperBound
+ else s"$lowerBound AND $upperBound")
+ ans += JDBCPartition(whereClause, i)
+ i = i + 1
+ }
+ ans.toArray
+ }
+}
+
+private[sql] class DefaultSource extends RelationProvider {
+ /** Returns a new base relation with the given parameters. */
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
+ val driver = parameters.getOrElse("driver", null)
+ val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
+ val partitionColumn = parameters.getOrElse("partitionColumn", null)
+ val lowerBound = parameters.getOrElse("lowerBound", null)
+ val upperBound = parameters.getOrElse("upperBound", null)
+ val numPartitions = parameters.getOrElse("numPartitions", null)
+
+ if (driver != null) Class.forName(driver)
+
+ if ( partitionColumn != null
+ && (lowerBound == null || upperBound == null || numPartitions == null)) {
+ sys.error("Partitioning incompletely specified")
+ }
+
+ val partitionInfo = if (partitionColumn == null) {
+ null
+ } else {
+ JDBCPartitioningInfo(partitionColumn,
+ lowerBound.toLong, upperBound.toLong,
+ numPartitions.toInt)
+ }
+ val parts = JDBCRelation.columnPartition(partitionInfo)
+ JDBCRelation(url, table, parts)(sqlContext)
+ }
+}
+
+private[sql] case class JDBCRelation(url: String,
+ table: String,
+ parts: Array[Partition])(
+ @transient val sqlContext: SQLContext)
+ extends PrunedFilteredScan {
+
+ override val schema = JDBCRDD.resolveTable(url, table)
+
+ override def buildScan(requiredColumns: Array[String], filters: Array[Filter]) = {
+ val driver: String = DriverManager.getDriver(url).getClass.getCanonicalName
+ JDBCRDD.scanTable(sqlContext.sparkContext,
+ schema,
+ driver, url,
+ table,
+ requiredColumns, filters,
+ parts)
+ }
+}
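For a concrete sense of what `columnPartition` produces: the partitioning used in JDBCSuite, `JDBCPartitioningInfo("THEID", 0, 4, 3)`, gives a stride of 1 and three partitions, `THEID < 1`, `THEID >= 1 AND THEID < 2`, and `THEID >= 2`. The first partition has no lower bound and the last has no upper bound, so rows outside the advisory range still land in some partition. The same partitioning can be requested declaratively through the options handled by `DefaultSource`, as in this hedged sketch (mirroring the 'parts' table registered in JDBCSuite, with the bounds taken from the API test):

    sqlContext.sql(
      "CREATE TEMPORARY TABLE parts " +
      "USING org.apache.spark.sql.jdbc " +
      "OPTIONS (url 'jdbc:h2:mem:testdb0', dbtable 'TEST.PEOPLE', " +
      "partitionColumn 'THEID', lowerBound '0', upperBound '4', numPartitions '3')")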
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala
new file mode 100644
index 0000000000..86bb67ec74
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JavaJDBCTrampoline.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import org.apache.spark.sql.DataFrame
+
+private[jdbc] class JavaJDBCTrampoline {
+ def createJDBCTable(rdd: DataFrame, url: String, table: String, allowExisting: Boolean) {
+ rdd.createJDBCTable(url, table, allowExisting);
+ }
+
+ def insertIntoJDBC(rdd: DataFrame, url: String, table: String, overwrite: Boolean) {
+ rdd.insertIntoJDBC(url, table, overwrite);
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
new file mode 100644
index 0000000000..34a83f0a5d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.sql.{Connection, DriverManager, PreparedStatement}
+import org.apache.spark.{Logging, Partition}
+import org.apache.spark.sql._
+import org.apache.spark.sql.sources.LogicalRelation
+
+import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartition}
+import org.apache.spark.sql.types._
+
+package object jdbc {
+ object JDBCWriteDetails extends Logging {
+ /**
+ * Returns a PreparedStatement that inserts a row into table via conn.
+ */
+ private def insertStatement(conn: Connection, table: String, rddSchema: StructType):
+ PreparedStatement = {
+ val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
+ var fieldsLeft = rddSchema.fields.length
+ while (fieldsLeft > 0) {
+ sql.append("?")
+ if (fieldsLeft > 1) sql.append(", ") else sql.append(")")
+ fieldsLeft = fieldsLeft - 1
+ }
+ conn.prepareStatement(sql.toString)
+ }
+
+ /**
+ * Saves a partition of a DataFrame to the JDBC database. This is done in
+ * a single database transaction in order to avoid repeatedly inserting
+ * data as much as possible.
+ *
+ * It is still theoretically possible for rows in a DataFrame to be
+ * inserted into the database more than once if a stage somehow fails after
+ * the commit occurs but before the stage can return successfully.
+ *
+ * This is not a closure inside saveTable() because apparently cosmetic
+ * implementation changes elsewhere might easily render such a closure
+ * non-Serializable. Instead, we explicitly close over all variables that
+ * are used.
+ */
+ private[jdbc] def savePartition(url: String, table: String, iterator: Iterator[Row],
+ rddSchema: StructType, nullTypes: Array[Int]): Iterator[Byte] = {
+ val conn = DriverManager.getConnection(url)
+ var committed = false
+ try {
+ conn.setAutoCommit(false) // Everything in the same db transaction.
+ val stmt = insertStatement(conn, table, rddSchema)
+ try {
+ while (iterator.hasNext) {
+ val row = iterator.next()
+ val numFields = rddSchema.fields.length
+ var i = 0
+ while (i < numFields) {
+ if (row.isNullAt(i)) {
+ stmt.setNull(i + 1, nullTypes(i))
+ } else {
+ rddSchema.fields(i).dataType match {
+ case IntegerType => stmt.setInt(i + 1, row.getInt(i))
+ case LongType => stmt.setLong(i + 1, row.getLong(i))
+ case DoubleType => stmt.setDouble(i + 1, row.getDouble(i))
+ case FloatType => stmt.setFloat(i + 1, row.getFloat(i))
+ case ShortType => stmt.setInt(i + 1, row.getShort(i))
+ case ByteType => stmt.setInt(i + 1, row.getByte(i))
+ case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i))
+ case StringType => stmt.setString(i + 1, row.getString(i))
+ case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
+ case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
+ case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
+ case DecimalType.Unlimited => stmt.setBigDecimal(i + 1,
+ row.getAs[java.math.BigDecimal](i))
+ case _ => throw new IllegalArgumentException(
+ s"Can't translate non-null value for field $i")
+ }
+ }
+ i = i + 1
+ }
+ stmt.executeUpdate()
+ }
+ } finally {
+ stmt.close()
+ }
+ conn.commit()
+ committed = true
+ } finally {
+ if (!committed) {
+ // The stage must fail. We got here through an exception path, so
+ // let the exception through unless rollback() or close() want to
+ // tell the user about another problem.
+ conn.rollback()
+ conn.close()
+ } else {
+ // The stage must succeed. We cannot propagate any exception close() might throw.
+ try {
+ conn.close()
+ } catch {
+ case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
+ }
+ }
+ }
+ Array[Byte]().iterator
+ }
+ }
+
+ /**
+ * Make it so that you can call createJDBCTable and insertIntoJDBC on a DataFrame.
+ */
+ implicit class JDBCDataFrame(rdd: DataFrame) {
+ /**
+ * Compute the schema string for this RDD.
+ */
+ private def schemaString(url: String): String = {
+ val sb = new StringBuilder()
+ val quirks = DriverQuirks.get(url)
+ rdd.schema.fields foreach { field => {
+ val name = field.name
+ var typ: String = quirks.getJDBCType(field.dataType)._1
+ if (typ == null) typ = field.dataType match {
+ case IntegerType => "INTEGER"
+ case LongType => "BIGINT"
+ case DoubleType => "DOUBLE PRECISION"
+ case FloatType => "REAL"
+ case ShortType => "INTEGER"
+ case ByteType => "BYTE"
+ case BooleanType => "BIT(1)"
+ case StringType => "TEXT"
+ case BinaryType => "BLOB"
+ case TimestampType => "TIMESTAMP"
+ case DateType => "DATE"
+ case DecimalType.Unlimited => "DECIMAL(40,20)"
+ case _ => throw new IllegalArgumentException(s"Don't know how to save $field to JDBC")
+ }
+ val nullable = if (field.nullable) "" else "NOT NULL"
+ sb.append(s", $name $typ $nullable")
+ }}
+ if (sb.length < 2) "" else sb.substring(2)
+ }
+
+ /**
+ * Saves the RDD to the database in a single transaction.
+ */
+ private def saveTable(url: String, table: String) {
+ val quirks = DriverQuirks.get(url)
+ var nullTypes: Array[Int] = rdd.schema.fields.map(field => {
+ var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2
+ if (nullType.isEmpty) {
+ field.dataType match {
+ case IntegerType => java.sql.Types.INTEGER
+ case LongType => java.sql.Types.BIGINT
+ case DoubleType => java.sql.Types.DOUBLE
+ case FloatType => java.sql.Types.REAL
+ case ShortType => java.sql.Types.INTEGER
+ case ByteType => java.sql.Types.INTEGER
+ case BooleanType => java.sql.Types.BIT
+ case StringType => java.sql.Types.CLOB
+ case BinaryType => java.sql.Types.BLOB
+ case TimestampType => java.sql.Types.TIMESTAMP
+ case DateType => java.sql.Types.DATE
+ case DecimalType.Unlimited => java.sql.Types.DECIMAL
+ case _ => throw new IllegalArgumentException(
+ s"Can't translate null value for field $field")
+ }
+ } else nullType.get
+ }).toArray
+
+ val rddSchema = rdd.schema
+ rdd.mapPartitions(iterator => JDBCWriteDetails.savePartition(
+ url, table, iterator, rddSchema, nullTypes)).collect()
+ }
+
+ /**
+ * Save this RDD to a JDBC database at `url` under the table name `table`.
+ * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements.
+ * If you pass `true` for `allowExisting`, it will drop any table with the
+ * given name; if you pass `false`, it will throw if the table already
+ * exists.
+ */
+ def createJDBCTable(url: String, table: String, allowExisting: Boolean) {
+ val conn = DriverManager.getConnection(url)
+ try {
+ if (allowExisting) {
+ val sql = s"DROP TABLE IF EXISTS $table"
+ conn.prepareStatement(sql).executeUpdate()
+ }
+ val schema = schemaString(url)
+ val sql = s"CREATE TABLE $table ($schema)"
+ conn.prepareStatement(sql).executeUpdate()
+ } finally {
+ conn.close()
+ }
+ saveTable(url, table)
+ }
+
+ /**
+ * Save this RDD to a JDBC database at `url` under the table name `table`.
+ * Assumes the table already exists and has a compatible schema. If you
+ * pass `true` for `overwrite`, it will `TRUNCATE` the table before
+ * performing the `INSERT`s.
+ *
+ * The table must already exist on the database. It must have a schema
+ * that is compatible with the schema of this RDD; inserting the rows of
+ * the RDD in order via the simple statement
+ * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail.
+ */
+ def insertIntoJDBC(url: String, table: String, overwrite: Boolean) {
+ if (overwrite) {
+ val conn = DriverManager.getConnection(url)
+ try {
+ val sql = s"TRUNCATE TABLE $table"
+ conn.prepareStatement(sql).executeUpdate()
+ } finally {
+ conn.close()
+ }
+ }
+ saveTable(url, table)
+ }
+ } // implicit class JDBCDataFrame
+} // package object jdbc
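The write path added here hangs `createJDBCTable` and `insertIntoJDBC` off every DataFrame through the implicit `JDBCDataFrame` class. A hedged sketch of typical usage, mirroring JDBCWriteSuite and JavaJDBCTest below; `people` stands for any DataFrame and the H2 URL is the one used throughout the tests:

    import org.apache.spark.sql.jdbc._   // brings the implicit JDBCDataFrame class into scope

    val url = "jdbc:h2:mem:testdb0"

    // CREATE TABLE (dropping any existing table of that name, because allowExisting = true),
    // then INSERT every row, one transaction per partition.
    people.createJDBCTable(url, "TEST.PEOPLECOPY", allowExisting = true)

    // Append into an existing, schema-compatible table; passing overwrite = true
    // would TRUNCATE it first.
    people.insertIntoJDBC(url, "TEST.PEOPLECOPY", overwrite = false)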
diff --git a/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java b/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java
new file mode 100644
index 0000000000..80bd74f5b5
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc;
+
+import org.junit.*;
+import static org.junit.Assert.*;
+import java.sql.Connection;
+import java.sql.DriverManager;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.api.java.*;
+import org.apache.spark.sql.test.TestSQLContext$;
+
+public class JavaJDBCTest {
+ static String url = "jdbc:h2:mem:testdb1";
+
+ static Connection conn = null;
+
+ // This variable will always be null if TestSQLContext is intact when running
+ // these tests. Some Java tests do not play nicely with others, however;
+ // they create a SparkContext of their own at startup and stop it at exit.
+ // This renders TestSQLContext inoperable, meaning we have to do the same
+ // thing. If this variable is nonnull, that means we allocated a
+ // SparkContext of our own and that we need to stop it at teardown.
+ static JavaSparkContext localSparkContext = null;
+
+ static SQLContext sql = TestSQLContext$.MODULE$;
+
+ @Before
+ public void beforeTest() throws Exception {
+ if (SparkEnv.get() == null) { // A previous test destroyed TestSQLContext.
+ localSparkContext = new JavaSparkContext("local", "JavaAPISuite");
+ sql = new SQLContext(localSparkContext);
+ }
+ Class.forName("org.h2.Driver");
+ conn = DriverManager.getConnection(url);
+ conn.prepareStatement("create schema test").executeUpdate();
+ conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate();
+ conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate();
+ conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate();
+ conn.prepareStatement("insert into test.people values ('joe', 3)").executeUpdate();
+ conn.commit();
+ }
+
+ @After
+ public void afterTest() throws Exception {
+ if (localSparkContext != null) {
+ localSparkContext.stop();
+ localSparkContext = null;
+ }
+ try {
+ conn.close();
+ } finally {
+ conn = null;
+ }
+ }
+
+ @Test
+ public void basicTest() throws Exception {
+ DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE");
+ Row[] rows = rdd.collect();
+ assertEquals(rows.length, 3);
+ }
+
+ @Test
+ public void partitioningTest() throws Exception {
+ String[] parts = new String[2];
+ parts[0] = "THEID < 2";
+ parts[1] = "THEID = 2"; // Deliberately forget about one of them.
+ DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE", parts);
+ Row[] rows = rdd.collect();
+ assertEquals(rows.length, 2);
+ }
+
+ @Test
+ public void writeTest() throws Exception {
+ DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE");
+ JDBCUtils.createJDBCTable(rdd, url, "TEST.PEOPLECOPY", false);
+ DataFrame rdd2 = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLECOPY");
+ Row[] rows = rdd2.collect();
+ assertEquals(rows.length, 3);
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala
new file mode 100644
index 0000000000..f332cb389f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/DockerHacks.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import scala.collection.mutable.MutableList
+
+import com.spotify.docker.client._
+
+/**
+ * A factory and morgue for DockerClient objects. In the DockerClient we use,
+ * calling close() closes the desired DockerClient but also renders all other
+ * DockerClients inoperable. This is inconvenient if we have more than one
+ * open, such as during tests.
+ */
+object DockerClientFactory {
+ var numClients: Int = 0
+ val zombies = new MutableList[DockerClient]()
+
+ def get(): DockerClient = {
+ this.synchronized {
+ numClients = numClients + 1
+ DefaultDockerClient.fromEnv.build()
+ }
+ }
+
+ def close(dc: DockerClient) {
+ this.synchronized {
+ numClients = numClients - 1
+ zombies += dc
+ if (numClients == 0) {
+ zombies.foreach(_.close())
+ zombies.clear()
+ }
+ }
+ }
+}
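Because a single `close()` on this Docker client library tears down every other client, integration tests are expected to go through the factory and release clients symmetrically. A hypothetical usage sketch (the container handling in the middle only gestures at what the MySQL/Postgres integration suites do):

    val docker = DockerClientFactory.get()
    try {
      // ... start a database container and run JDBC queries against it ...
    } finally {
      DockerClientFactory.close(docker)   // the last release closes all accumulated clients
    }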
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
new file mode 100644
index 0000000000..d25c1390db
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import org.apache.spark.sql.test._
+import org.scalatest.{FunSuite, BeforeAndAfter}
+import java.sql.DriverManager
+import TestSQLContext._
+
+class JDBCSuite extends FunSuite with BeforeAndAfter {
+ val url = "jdbc:h2:mem:testdb0"
+ var conn: java.sql.Connection = null
+
+ val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
+
+ before {
+ Class.forName("org.h2.Driver")
+ conn = DriverManager.getConnection(url)
+ conn.prepareStatement("create schema test").executeUpdate()
+ conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
+ conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate()
+ conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate()
+ conn.prepareStatement("insert into test.people values ('joe', 3)").executeUpdate()
+ conn.commit()
+
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE foobar
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable 'TEST.PEOPLE')
+ """.stripMargin.replaceAll("\n", " "))
+
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE parts
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable 'TEST.PEOPLE',
+ |partitionColumn 'THEID', lowerBound '1', upperBound '4', numPartitions '3')
+ """.stripMargin.replaceAll("\n", " "))
+
+ conn.prepareStatement("create table test.inttypes (a INT, b BOOLEAN, c TINYINT, "
+ + "d SMALLINT, e BIGINT)").executeUpdate()
+ conn.prepareStatement("insert into test.inttypes values (1, false, 3, 4, 1234567890123)"
+ ).executeUpdate()
+ conn.prepareStatement("insert into test.inttypes values (null, null, null, null, null)"
+ ).executeUpdate()
+ conn.commit()
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE inttypes
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable 'TEST.INTTYPES')
+ """.stripMargin.replaceAll("\n", " "))
+
+ conn.prepareStatement("create table test.strtypes (a BINARY(20), b VARCHAR(20), "
+ + "c VARCHAR_IGNORECASE(20), d CHAR(20), e BLOB, f CLOB)").executeUpdate()
+ var stmt = conn.prepareStatement("insert into test.strtypes values (?, ?, ?, ?, ?, ?)")
+ stmt.setBytes(1, testBytes)
+ stmt.setString(2, "Sensitive")
+ stmt.setString(3, "Insensitive")
+ stmt.setString(4, "Twenty-byte CHAR")
+ stmt.setBytes(5, testBytes)
+ stmt.setString(6, "I am a clob!")
+ stmt.executeUpdate()
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE strtypes
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable 'TEST.STRTYPES')
+ """.stripMargin.replaceAll("\n", " "))
+
+ conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)"
+ ).executeUpdate()
+ conn.prepareStatement("insert into test.timetypes values ('12:34:56', "
+ + "'1996-01-01', '2002-02-20 11:22:33.543543543')").executeUpdate()
+ conn.commit()
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE timetypes
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable 'TEST.TIMETYPES')
+ """.stripMargin.replaceAll("\n", " "))
+
+
+ conn.prepareStatement("create table test.flttypes (a DOUBLE, b REAL, c DECIMAL(40, 20))"
+ ).executeUpdate()
+ conn.prepareStatement("insert into test.flttypes values ("
+ + "1.0000000000000002220446049250313080847263336181640625, "
+ + "1.00000011920928955078125, "
+ + "123456789012345.543215432154321)").executeUpdate()
+ conn.commit()
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE flttypes
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable 'TEST.FLTTYPES')
+ """.stripMargin.replaceAll("\n", " "))
+
+ // Untested: IDENTITY, OTHER, UUID, ARRAY, and GEOMETRY types.
+ }
+
+ after {
+ conn.close()
+ }
+
+ test("SELECT *") {
+ assert(sql("SELECT * FROM foobar").collect().size == 3)
+ }
+
+ test("SELECT * WHERE (simple predicates)") {
+ assert(sql("SELECT * FROM foobar WHERE THEID < 1").collect().size == 0)
+ assert(sql("SELECT * FROM foobar WHERE THEID != 2").collect().size == 2)
+ assert(sql("SELECT * FROM foobar WHERE THEID = 1").collect().size == 1)
+ }
+
+ test("SELECT first field") {
+ val names = sql("SELECT NAME FROM foobar").collect().map(x => x.getString(0)).sortWith(_ < _)
+ assert(names.size == 3)
+ assert(names(0).equals("fred"))
+ assert(names(1).equals("joe"))
+ assert(names(2).equals("mary"))
+ }
+
+ test("SELECT second field") {
+ val ids = sql("SELECT THEID FROM foobar").collect().map(x => x.getInt(0)).sortWith(_ < _)
+ assert(ids.size == 3)
+ assert(ids(0) == 1)
+ assert(ids(1) == 2)
+ assert(ids(2) == 3)
+ }
+
+ test("SELECT * partitioned") {
+ assert(sql("SELECT * FROM parts").collect().size == 3)
+ }
+
+ test("SELECT WHERE (simple predicates) partitioned") {
+ assert(sql("SELECT * FROM parts WHERE THEID < 1").collect().size == 0)
+ assert(sql("SELECT * FROM parts WHERE THEID != 2").collect().size == 2)
+ assert(sql("SELECT THEID FROM parts WHERE THEID = 1").collect().size == 1)
+ }
+
+ test("SELECT second field partitioned") {
+ val ids = sql("SELECT THEID FROM parts").collect().map(x => x.getInt(0)).sortWith(_ < _)
+ assert(ids.size == 3)
+ assert(ids(0) == 1)
+ assert(ids(1) == 2)
+ assert(ids(2) == 3)
+ }
+
+ test("Basic API") {
+ assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE").collect.size == 3)
+ }
+
+ test("Partitioning via JDBCPartitioningInfo API") {
+ val parts = JDBCPartitioningInfo("THEID", 0, 4, 3)
+ assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE", parts).collect.size == 3)
+ }
+
+ test("Partitioning via list-of-where-clauses API") {
+ val parts = Array[String]("THEID < 2", "THEID >= 2")
+ assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE", parts).collect.size == 3)
+ }
+
+ test("H2 integral types") {
+ val rows = sql("SELECT * FROM inttypes WHERE A IS NOT NULL").collect()
+ assert(rows.size == 1)
+ assert(rows(0).getInt(0) == 1)
+ assert(rows(0).getBoolean(1) == false)
+ assert(rows(0).getInt(2) == 3)
+ assert(rows(0).getInt(3) == 4)
+ assert(rows(0).getLong(4) == 1234567890123L)
+ }
+
+ test("H2 null entries") {
+ val rows = sql("SELECT * FROM inttypes WHERE A IS NULL").collect()
+ assert(rows.size == 1)
+ assert(rows(0).isNullAt(0))
+ assert(rows(0).isNullAt(1))
+ assert(rows(0).isNullAt(2))
+ assert(rows(0).isNullAt(3))
+ assert(rows(0).isNullAt(4))
+ }
+
+ test("H2 string types") {
+ val rows = sql("SELECT * FROM strtypes").collect()
+ assert(rows(0).getAs[Array[Byte]](0).sameElements(testBytes))
+ assert(rows(0).getString(1).equals("Sensitive"))
+ assert(rows(0).getString(2).equals("Insensitive"))
+ assert(rows(0).getString(3).equals("Twenty-byte CHAR"))
+ assert(rows(0).getAs[Array[Byte]](4).sameElements(testBytes))
+ assert(rows(0).getString(5).equals("I am a clob!"))
+ }
+
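+ // The accessors below are the deprecated java.util.Date ones: getYear is offset by 1900
+ // and getMonth is zero-based, so 96/0/1 is 1996-01-01 and 102/1/20 is 2002-02-20.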
+ test("H2 time types") {
+ val rows = sql("SELECT * FROM timetypes").collect()
+ assert(rows(0).getAs[java.sql.Timestamp](0).getHours == 12)
+ assert(rows(0).getAs[java.sql.Timestamp](0).getMinutes == 34)
+ assert(rows(0).getAs[java.sql.Timestamp](0).getSeconds == 56)
+ assert(rows(0).getAs[java.sql.Date](1).getYear == 96)
+ assert(rows(0).getAs[java.sql.Date](1).getMonth == 0)
+ assert(rows(0).getAs[java.sql.Date](1).getDate == 1)
+ assert(rows(0).getAs[java.sql.Timestamp](2).getYear == 102)
+ assert(rows(0).getAs[java.sql.Timestamp](2).getMonth == 1)
+ assert(rows(0).getAs[java.sql.Timestamp](2).getDate == 20)
+ assert(rows(0).getAs[java.sql.Timestamp](2).getHours == 11)
+ assert(rows(0).getAs[java.sql.Timestamp](2).getMinutes == 22)
+ assert(rows(0).getAs[java.sql.Timestamp](2).getSeconds == 33)
+ assert(rows(0).getAs[java.sql.Timestamp](2).getNanos == 543543543)
+ }
+
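+ // Both literals below are exactly representable: the DOUBLE inserted above is 1 + 2^-52
+ // (the smallest double greater than 1) and the REAL is 1 + 2^-23 (the smallest float
+ // greater than 1), so the exact == comparisons in this test are deliberate and safe.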
+ test("H2 floating-point types") {
+ val rows = sql("SELECT * FROM flttypes").collect()
+ assert(rows(0).getDouble(0) == 1.00000000000000022) // Yes, I meant ==.
+ assert(rows(0).getDouble(1) == 1.00000011920928955) // Yes, I meant ==.
+ assert(rows(0).getAs[BigDecimal](2)
+ .equals(new BigDecimal("123456789012345.54321543215432100000")))
+ }
+
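+ // The dbtable option appears to be spliced verbatim into the FROM clause of the generated
+ // query, so a parenthesized subquery can stand in for a table name wherever the database
+ // accepts a derived table.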
+ test("SQL query as table name") {
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE hack
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (url '$url', dbtable '(SELECT B, B*B FROM TEST.FLTTYPES)')
+ """.stripMargin.replaceAll("\n", " "))
+ val rows = sql("SELECT * FROM hack").collect()
+ assert(rows(0).getDouble(0) == 1.00000011920928955) // Yes, I meant ==.
+ // For some reason, H2 computes this square incorrectly...
+ assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
new file mode 100644
index 0000000000..e581ac9b50
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.test._
+import org.scalatest.{FunSuite, BeforeAndAfter}
+import java.sql.DriverManager
+import TestSQLContext._
+
+class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
+ val url = "jdbc:h2:mem:testdb2"
+ var conn: java.sql.Connection = null
+
+ before {
+ Class.forName("org.h2.Driver")
+ conn = DriverManager.getConnection(url)
+ conn.prepareStatement("create schema test").executeUpdate()
+ }
+
+ after {
+ conn.close()
+ }
+
+ val sc = TestSQLContext.sparkContext
+
+ val arr2x2 = Array[Row](Row.apply("dave", 42), Row.apply("mary", 222))
+ val arr1x2 = Array[Row](Row.apply("fred", 3))
+ val schema2 = StructType(
+ StructField("name", StringType) ::
+ StructField("id", IntegerType) :: Nil)
+
+ val arr2x3 = Array[Row](Row.apply("dave", 42, 1), Row.apply("mary", 222, 2))
+ val schema3 = StructType(
+ StructField("name", StringType) ::
+ StructField("id", IntegerType) ::
+ StructField("seq", IntegerType) :: Nil)
+
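+ // Judging from the assertions below, the trailing boolean to createJDBCTable allows an
+ // existing table to be dropped and recreated, while the one to insertIntoJDBC truncates
+ // the target table before inserting.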
+ test("Basic CREATE") {
+ val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+
+ srdd.createJDBCTable(url, "TEST.BASICCREATETEST", false)
+ assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").count)
+ assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").collect()(0).length)
+ }
+
+ test("CREATE with overwrite") {
+ val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
+ val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+
+ srdd.createJDBCTable(url, "TEST.DROPTEST", false)
+ assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count)
+ assert(3 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").collect()(0).length)
+
+ srdd2.createJDBCTable(url, "TEST.DROPTEST", true)
+ assert(1 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count)
+ assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").collect()(0).length)
+ }
+
+ test("CREATE then INSERT to append") {
+ val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+
+ srdd.createJDBCTable(url, "TEST.APPENDTEST", false)
+ srdd2.insertIntoJDBC(url, "TEST.APPENDTEST", false)
+ assert(3 == TestSQLContext.jdbcRDD(url, "TEST.APPENDTEST").count)
+ assert(2 == TestSQLContext.jdbcRDD(url, "TEST.APPENDTEST").collect()(0).length)
+ }
+
+ test("CREATE then INSERT to truncate") {
+ val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+
+ srdd.createJDBCTable(url, "TEST.TRUNCATETEST", false)
+ srdd2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true)
+ assert(1 == TestSQLContext.jdbcRDD(url, "TEST.TRUNCATETEST").count)
+ assert(2 == TestSQLContext.jdbcRDD(url, "TEST.TRUNCATETEST").collect()(0).length)
+ }
+
+ test("Incompatible INSERT to append") {
+ val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+ val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
+
+ srdd.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false)
+ intercept[org.apache.spark.SparkException] {
+ srdd2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true)
+ }
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala
new file mode 100644
index 0000000000..89920f2650
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import java.sql.{Date, DriverManager, Timestamp}
+import com.spotify.docker.client.{DefaultDockerClient, DockerClient}
+import com.spotify.docker.client.messages.ContainerConfig
+import org.scalatest.{FunSuite, BeforeAndAfterAll, Ignore}
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.SparkContext._
+import org.apache.spark.sql._
+import org.apache.spark.sql.test._
+import TestSQLContext._
+
+import org.apache.spark.sql.jdbc._
+
+class MySQLDatabase {
+ val docker: DockerClient = DockerClientFactory.get()
+ val containerId = {
+ println("Pulling mysql")
+ docker.pull("mysql")
+ println("Configuring container")
+ val config = (ContainerConfig.builder().image("mysql")
+ .env("MYSQL_ROOT_PASSWORD=rootpass")
+ .build())
+ println("Creating container")
+ val id = docker.createContainer(config).id
+ println("Starting container " + id)
+ docker.startContainer(id)
+ id
+ }
+ val ip = docker.inspectContainer(containerId).networkSettings.ipAddress
+
+ def close() {
+ try {
+ println("Killing container " + containerId)
+ docker.killContainer(containerId)
+ println("Removing container " + containerId)
+ docker.removeContainer(containerId)
+ println("Closing docker client")
+ DockerClientFactory.close(docker)
+ } catch {
+ case e: Exception => {
+ println(e)
+ println("You may need to clean this up manually.")
+ throw e
+ }
+ }
+ }
+}
+
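+// This suite is @Ignore'd by default: it needs a working local Docker daemon, pulls the
+// stock "mysql" image, and reaches the container directly over its bridge-network IP, so
+// it is effectively a manual integration test.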
+@Ignore class MySQLIntegration extends FunSuite with BeforeAndAfterAll {
+ var ip: String = null
+
+ def url(ip: String): String = url(ip, "mysql")
+ def url(ip: String, db: String): String = s"jdbc:mysql://$ip:3306/$db?user=root&password=rootpass"
+
+ def waitForDatabase(ip: String, maxMillis: Long) {
+ println("Waiting for database to start up.")
+ val before = System.currentTimeMillis()
+ var lastException: java.sql.SQLException = null
+ while (true) {
+ if (System.currentTimeMillis() > before + maxMillis) {
+ throw new java.sql.SQLException(s"Database not up after $maxMillis ms.", lastException)
+ }
+ try {
+ val conn = java.sql.DriverManager.getConnection(url(ip))
+ conn.close()
+ println("Database is up.")
+ return
+ } catch {
+ case e: java.sql.SQLException => {
+ lastException = e
+ java.lang.Thread.sleep(250)
+ }
+ }
+ }
+ }
+
+ def setupDatabase(ip: String) {
+ val conn = java.sql.DriverManager.getConnection(url(ip))
+ try {
+ conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
+ conn.prepareStatement("CREATE TABLE foo.tbl (x INTEGER, y TEXT(8))").executeUpdate()
+ conn.prepareStatement("INSERT INTO foo.tbl VALUES (42,'fred')").executeUpdate()
+ conn.prepareStatement("INSERT INTO foo.tbl VALUES (17,'dave')").executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE foo.numbers (onebit BIT(1), tenbits BIT(10), "
+ + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, "
+ + "dbl DOUBLE)").executeUpdate()
+ conn.prepareStatement("INSERT INTO foo.numbers VALUES (b'0', b'1000100101', "
+ + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, "
+ + "42.75, 1.0000000000000002)").executeUpdate()
+
+ conn.prepareStatement("CREATE TABLE foo.dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, "
+ + "yr YEAR)").executeUpdate()
+ conn.prepareStatement("INSERT INTO foo.dates VALUES ('1991-11-09', '13:31:24', "
+ + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate()
+
+ // TODO: Test locale conversion for strings.
+ conn.prepareStatement("CREATE TABLE foo.strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, "
+ + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)"
+ ).executeUpdate()
+ conn.prepareStatement("INSERT INTO foo.strings VALUES ('the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate()
+ } finally {
+ conn.close()
+ }
+ }
+
+ var db: MySQLDatabase = null
+
+ override def beforeAll() {
+ // If you load the MySQL driver here, DriverManager will deadlock. The
+ // MySQL driver gets loaded when its jar gets loaded, unlike the Postgres
+ // and H2 drivers.
+ //Class.forName("com.mysql.jdbc.Driver")
+
+ db = new MySQLDatabase()
+ waitForDatabase(db.ip, 60000)
+ setupDatabase(db.ip)
+ ip = db.ip
+ }
+
+ override def afterAll() {
+ db.close()
+ }
+
+ test("Basic test") {
+ val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "tbl")
+ val rows = rdd.collect
+ assert(rows.length == 2)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 2)
+ assert(types(0).equals("class java.lang.Integer"))
+ assert(types(1).equals("class java.lang.String"))
+ }
+
+ test("Numeric types") {
+ val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "numbers")
+ val rows = rdd.collect
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 9)
+ assert(types(0).equals("class java.lang.Boolean"))
+ assert(types(1).equals("class java.lang.Long"))
+ assert(types(2).equals("class java.lang.Integer"))
+ assert(types(3).equals("class java.lang.Integer"))
+ assert(types(4).equals("class java.lang.Integer"))
+ assert(types(5).equals("class java.lang.Long"))
+ assert(types(6).equals("class java.math.BigDecimal"))
+ assert(types(7).equals("class java.lang.Double"))
+ assert(types(8).equals("class java.lang.Double"))
+ assert(rows(0).getBoolean(0) == false)
+ assert(rows(0).getLong(1) == 0x225)  // b'1000100101' is 549, i.e. 0x225
+ assert(rows(0).getInt(2) == 17)
+ assert(rows(0).getInt(3) == 77777)
+ assert(rows(0).getInt(4) == 123456789)
+ assert(rows(0).getLong(5) == 123456789012345L)
+ val bd = new BigDecimal("123456789012345.12345678901234500000")
+ assert(rows(0).getAs[BigDecimal](6).equals(bd))
+ assert(rows(0).getDouble(7) == 42.75)
+ assert(rows(0).getDouble(8) == 1.0000000000000002)
+ }
+
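+ // The deprecated java.sql.Date/Timestamp constructors used below take (year - 1900,
+ // zero-based month, day, ...), so new Date(91, 10, 9) is 1991-11-09 and
+ // new Timestamp(70, 0, 1, 13, 31, 24, 0) is the TIME value 13:31:24 on the epoch date.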
+ test("Date types") {
+ val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "dates")
+ val rows = rdd.collect
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 5)
+ assert(types(0).equals("class java.sql.Date"))
+ assert(types(1).equals("class java.sql.Timestamp"))
+ assert(types(2).equals("class java.sql.Timestamp"))
+ assert(types(3).equals("class java.sql.Timestamp"))
+ assert(types(4).equals("class java.sql.Date"))
+ assert(rows(0).getAs[Date](0).equals(new Date(91, 10, 9)))
+ assert(rows(0).getAs[Timestamp](1).equals(new Timestamp(70, 0, 1, 13, 31, 24, 0)))
+ assert(rows(0).getAs[Timestamp](2).equals(new Timestamp(96, 0, 1, 1, 23, 45, 0)))
+ assert(rows(0).getAs[Timestamp](3).equals(new Timestamp(109, 1, 13, 23, 31, 30, 0)))
+ assert(rows(0).getAs[Date](4).equals(new Date(101, 0, 1)))
+ }
+
+ test("String types") {
+ val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "strings")
+ val rows = rdd.collect
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 9)
+ assert(types(0).equals("class java.lang.String"))
+ assert(types(1).equals("class java.lang.String"))
+ assert(types(2).equals("class java.lang.String"))
+ assert(types(3).equals("class java.lang.String"))
+ assert(types(4).equals("class java.lang.String"))
+ assert(types(5).equals("class java.lang.String"))
+ assert(types(6).equals("class [B"))
+ assert(types(7).equals("class [B"))
+ assert(types(8).equals("class [B"))
+ assert(rows(0).getString(0).equals("the"))
+ assert(rows(0).getString(1).equals("quick"))
+ assert(rows(0).getString(2).equals("brown"))
+ assert(rows(0).getString(3).equals("fox"))
+ assert(rows(0).getString(4).equals("jumps"))
+ assert(rows(0).getString(5).equals("over"))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0)))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121)))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103)))
+ }
+
+ test("Basic write test") {
+ val rdd1 = TestSQLContext.jdbcRDD(url(ip, "foo"), "numbers")
+ val rdd2 = TestSQLContext.jdbcRDD(url(ip, "foo"), "dates")
+ val rdd3 = TestSQLContext.jdbcRDD(url(ip, "foo"), "strings")
+ rdd1.createJDBCTable(url(ip, "foo"), "numberscopy", false)
+ rdd2.createJDBCTable(url(ip, "foo"), "datescopy", false)
+ rdd3.createJDBCTable(url(ip, "foo"), "stringscopy", false)
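+ // No assertions here: this only checks that the three writes complete without throwing.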
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala
new file mode 100644
index 0000000000..c174d7adb7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.jdbc
+
+import java.math.BigDecimal
+import org.apache.spark.sql.test._
+import org.scalatest.{FunSuite, BeforeAndAfterAll, Ignore}
+import java.sql.DriverManager
+import TestSQLContext._
+import com.spotify.docker.client.{DefaultDockerClient, DockerClient}
+import com.spotify.docker.client.messages.ContainerConfig
+
+class PostgresDatabase {
+ val docker: DockerClient = DockerClientFactory.get()
+ val containerId = {
+ println("Pulling postgres")
+ docker.pull("postgres")
+ println("Configuring container")
+ val config = (ContainerConfig.builder().image("postgres")
+ .env("POSTGRES_PASSWORD=rootpass")
+ .build())
+ println("Creating container")
+ val id = docker.createContainer(config).id
+ println("Starting container " + id)
+ docker.startContainer(id)
+ id
+ }
+ val ip = docker.inspectContainer(containerId).networkSettings.ipAddress
+
+ def close() {
+ try {
+ println("Killing container " + containerId)
+ docker.killContainer(containerId)
+ println("Removing container " + containerId)
+ docker.removeContainer(containerId)
+ println("Closing docker client")
+ DockerClientFactory.close(docker)
+ } catch {
+ case e: Exception => {
+ println(e)
+ println("You may need to clean this up manually.")
+ throw e
+ }
+ }
+ }
+}
+
+@Ignore class PostgresIntegration extends FunSuite with BeforeAndAfterAll {
+ lazy val db = new PostgresDatabase()
+
+ def url(ip: String) = s"jdbc:postgresql://$ip:5432/postgres?user=postgres&password=rootpass"
+
+ def waitForDatabase(ip: String, maxMillis: Long) {
+ val before = System.currentTimeMillis()
+ var lastException: java.sql.SQLException = null
+ while (true) {
+ if (System.currentTimeMillis() > before + maxMillis) {
+ throw new java.sql.SQLException(s"Database not up after $maxMillis ms.",
+ lastException)
+ }
+ try {
+ val conn = java.sql.DriverManager.getConnection(url(ip))
+ conn.close()
+ println("Database is up.")
+ return
+ } catch {
+ case e: java.sql.SQLException => {
+ lastException = e
+ java.lang.Thread.sleep(250)
+ }
+ }
+ }
+ }
+
+ def setupDatabase(ip: String) {
+ val conn = DriverManager.getConnection(url(ip))
+ try {
+ conn.prepareStatement("CREATE DATABASE foo").executeUpdate()
+ conn.setCatalog("foo")
+ conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, "
+ + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate()
+ conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', "
+ + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate()
+ } finally {
+ conn.close()
+ }
+ }
+
+ override def beforeAll() {
+ println("Waiting for database to start up.")
+ waitForDatabase(db.ip, 60000)
+ println("Setting up database.")
+ setupDatabase(db.ip)
+ }
+
+ override def afterAll() {
+ db.close()
+ }
+
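+ // inet and cidr have no standard JDBC type; the assertions below expect them to surface
+ // as plain strings, presumably via the Postgres-specific type mapping added in this patch.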
+ test("Type mapping for various types") {
+ val rdd = TestSQLContext.jdbcRDD(url(db.ip), "public.bar")
+ val rows = rdd.collect
+ assert(rows.length == 1)
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types.length == 10)
+ assert(types(0).equals("class java.lang.String"))
+ assert(types(1).equals("class java.lang.Integer"))
+ assert(types(2).equals("class java.lang.Double"))
+ assert(types(3).equals("class java.lang.Long"))
+ assert(types(4).equals("class java.lang.Boolean"))
+ assert(types(5).equals("class [B"))
+ assert(types(6).equals("class [B"))
+ assert(types(7).equals("class java.lang.Boolean"))
+ assert(types(8).equals("class java.lang.String"))
+ assert(types(9).equals("class java.lang.String"))
+ assert(rows(0).getString(0).equals("hello"))
+ assert(rows(0).getInt(1) == 42)
+ assert(rows(0).getDouble(2) == 1.25)
+ assert(rows(0).getLong(3) == 123456789012345L)
+ assert(rows(0).getBoolean(4) == false)
+ // BIT(10) values come back as byte arrays of ten ASCII '0' and '1' characters...
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), Array[Byte](49,48,48,48,49,48,48,49,48,49)))
+ assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte)))
+ assert(rows(0).getBoolean(7) == true)
+ assert(rows(0).getString(8) == "172.16.0.42")
+ assert(rows(0).getString(9) == "192.168.0.0/16")
+ }
+
+ test("Basic write test") {
+ val rdd = TestSQLContext.jdbcRDD(url(db.ip), "public.bar")
+ rdd.createJDBCTable(url(db.ip), "public.barcopy", false)
+ // Test only that it doesn't bomb out.
+ }
+}