diff options
Diffstat (limited to 'sql/core/src')
4 files changed, 148 insertions, 31 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 1f85dac682..01fd432cc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.io.CharArrayWriter import java.sql.DriverManager +import java.util.Properties import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -1582,7 +1583,24 @@ class DataFrame private[sql]( * @group output */ def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { - val conn = DriverManager.getConnection(url) + createJDBCTable(url, table, allowExisting, new Properties()) + } + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table` + * using connection properties defined in `properties`. + * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. + * If you pass `true` for `allowExisting`, it will drop any table with the + * given name; if you pass `false`, it will throw if the table already + * exists. + * @group output + */ + def createJDBCTable( + url: String, + table: String, + allowExisting: Boolean, + properties: Properties): Unit = { + val conn = DriverManager.getConnection(url, properties) try { if (allowExisting) { val sql = s"DROP TABLE IF EXISTS $table" @@ -1594,7 +1612,7 @@ class DataFrame private[sql]( } finally { conn.close() } - JDBCWriteDetails.saveTable(this, url, table) + JDBCWriteDetails.saveTable(this, url, table, properties) } /** @@ -1610,8 +1628,29 @@ class DataFrame private[sql]( * @group output */ def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { + insertIntoJDBC(url, table, overwrite, new Properties()) + } + + /** + * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table` + * using connection properties defined in `properties`. 
+ * Assumes the table already exists and has a compatible schema. If you + * pass `true` for `overwrite`, it will `TRUNCATE` the table before + * performing the `INSERT`s. + * + * The table must already exist on the database. It must have a schema + * that is compatible with the schema of this RDD; inserting the rows of + * the RDD in order via the simple statement + * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. + * @group output + */ + def insertIntoJDBC( + url: String, + table: String, + overwrite: Boolean, + properties: Properties): Unit = { if (overwrite) { - val conn = DriverManager.getConnection(url) + val conn = DriverManager.getConnection(url, properties) try { val sql = s"TRUNCATE TABLE $table" conn.prepareStatement(sql).executeUpdate() @@ -1619,9 +1658,8 @@ class DataFrame private[sql]( conn.close() } } - JDBCWriteDetails.saveTable(this, url, table) + JDBCWriteDetails.saveTable(this, url, table, properties) } - //////////////////////////////////////////////////////////////////////////// // for Python API //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index afee09adaa..70ba8985d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -949,9 +949,21 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jdbc(url: String, table: String): DataFrame = { - jdbc(url, table, JDBCRelation.columnPartition(null)) + jdbc(url, table, JDBCRelation.columnPartition(null), new Properties()) } - + + /** + * :: Experimental :: + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table and connection properties. 
+ * + * @group specificdata + */ + @Experimental + def jdbc(url: String, table: String, properties: Properties): DataFrame = { + jdbc(url, table, JDBCRelation.columnPartition(null), properties) + } + /** * :: Experimental :: * Construct a [[DataFrame]] representing the database table accessible via JDBC URL @@ -963,7 +975,31 @@ class SQLContext(@transient val sparkContext: SparkContext) * @param upperBound the maximum value of `columnName` used to decide partition stride * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split * evenly into this many partitions + * @group specificdata + */ + @Experimental + def jdbc( + url: String, + table: String, + columnName: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int): DataFrame = { + jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties()) + } + + /** + * :: Experimental :: + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table. Partitions of the table will be retrieved in parallel based on the parameters + * passed to this function. * + * @param columnName the name of a column of integral type that will be used for partitioning. + * @param lowerBound the minimum value of `columnName` used to decide partition stride + * @param upperBound the maximum value of `columnName` used to decide partition stride + * @param numPartitions the number of partitions. 
the range `lowerBound`-`upperBound` will be split + * evenly into this many partitions + * @param properties connection properties * @group specificdata */ @Experimental @@ -973,16 +1009,17 @@ class SQLContext(@transient val sparkContext: SparkContext) columnName: String, lowerBound: Long, upperBound: Long, - numPartitions: Int): DataFrame = { + numPartitions: Int, + properties: Properties): DataFrame = { val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions) val parts = JDBCRelation.columnPartition(partitioning) - jdbc(url, table, parts) + jdbc(url, table, parts, properties) } - + /** * :: Experimental :: * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. The theParts parameter gives a list expressions + * url named table. The theParts parameter gives a list of expressions * suitable for inclusion in WHERE clauses; each one defines one partition * of the [[DataFrame]]. * @@ -990,14 +1027,36 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { + jdbc(url, table, theParts, new Properties()) + } + + /** + * :: Experimental :: + * Construct a [[DataFrame]] representing the database table accessible via JDBC URL + * url named table using connection properties. The theParts parameter gives a list of expressions + * suitable for inclusion in WHERE clauses; each one defines one partition + * of the [[DataFrame]]. 
+ * + * @group specificdata + */ + @Experimental + def jdbc( + url: String, + table: String, + theParts: Array[String], + properties: Properties): DataFrame = { val parts: Array[Partition] = theParts.zipWithIndex.map { case (part, i) => JDBCPartition(part, i) : Partition } - jdbc(url, table, parts) + jdbc(url, table, parts, properties) } - - private def jdbc(url: String, table: String, parts: Array[Partition]): DataFrame = { - val relation = JDBCRelation(url, table, parts)(this) + + private def jdbc( + url: String, + table: String, + parts: Array[Partition], + properties: Properties): DataFrame = { + val relation = JDBCRelation(url, table, parts, properties)(this) baseRelationToDataFrame(relation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala index 3a6c2c1e91..c099881a01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala @@ -57,9 +57,14 @@ package object jdbc { * non-Serializable. Instead, we explicitly close over all variables that * are used. */ - def savePartition(url: String, table: String, iterator: Iterator[Row], - rddSchema: StructType, nullTypes: Array[Int]): Iterator[Byte] = { - val conn = DriverManager.getConnection(url) + def savePartition( + url: String, + table: String, + iterator: Iterator[Row], + rddSchema: StructType, + nullTypes: Array[Int], + properties: Properties): Iterator[Byte] = { + val conn = DriverManager.getConnection(url, properties) var committed = false try { conn.setAutoCommit(false) // Everything in the same db transaction. @@ -152,7 +157,11 @@ package object jdbc { /** * Saves the RDD to the database in a single transaction. 
*/ - def saveTable(df: DataFrame, url: String, table: String) { + def saveTable( + df: DataFrame, + url: String, + table: String, + properties: Properties = new Properties()) { val quirks = DriverQuirks.get(url) var nullTypes: Array[Int] = df.schema.fields.map(field => { var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2 @@ -178,7 +187,7 @@ package object jdbc { val rddSchema = df.schema df.foreachPartition { iterator => - JDBCWriteDetails.savePartition(url, table, iterator, rddSchema, nullTypes) + JDBCWriteDetails.savePartition(url, table, iterator, rddSchema, nullTypes, properties) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index ee5c7620d1..f3ce8e6646 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager +import java.util.Properties import org.scalatest.{BeforeAndAfter, FunSuite} @@ -28,15 +29,25 @@ import org.apache.spark.sql.types._ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { val url = "jdbc:h2:mem:testdb2" var conn: java.sql.Connection = null - + val url1 = "jdbc:h2:mem:testdb3" + var conn1: java.sql.Connection = null + val properties = new Properties() + properties.setProperty("user", "testUser") + properties.setProperty("password", "testPass") + properties.setProperty("rowId", "false") + before { Class.forName("org.h2.Driver") conn = DriverManager.getConnection(url) conn.prepareStatement("create schema test").executeUpdate() + + conn1 = DriverManager.getConnection(url1, properties) + conn1.prepareStatement("create schema test").executeUpdate() } after { conn.close() + conn1.close() } val sc = TestSQLContext.sparkContext @@ -65,13 +76,13 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { val df = 
TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3) val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url, "TEST.DROPTEST", false) - assert(2 == TestSQLContext.jdbc(url, "TEST.DROPTEST").count) - assert(3 == TestSQLContext.jdbc(url, "TEST.DROPTEST").collect()(0).length) + df.createJDBCTable(url1, "TEST.DROPTEST", false, properties) + assert(2 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) - df2.createJDBCTable(url, "TEST.DROPTEST", true) - assert(1 == TestSQLContext.jdbc(url, "TEST.DROPTEST").count) - assert(2 == TestSQLContext.jdbc(url, "TEST.DROPTEST").collect()(0).length) + df2.createJDBCTable(url1, "TEST.DROPTEST", true, properties) + assert(1 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { @@ -88,10 +99,10 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter { val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2) val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2) - df.createJDBCTable(url, "TEST.TRUNCATETEST", false) - df2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true) - assert(1 == TestSQLContext.jdbc(url, "TEST.TRUNCATETEST").count) - assert(2 == TestSQLContext.jdbc(url, "TEST.TRUNCATETEST").collect()(0).length) + df.createJDBCTable(url1, "TEST.TRUNCATETEST", false, properties) + df2.insertIntoJDBC(url1, "TEST.TRUNCATETEST", true, properties) + assert(1 == TestSQLContext.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 == TestSQLContext.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { |