Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala            | 48
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala           | 79
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala            | 19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala  | 33
4 files changed, 148 insertions, 31 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 1f85dac682..01fd432cc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.io.CharArrayWriter
import java.sql.DriverManager
+import java.util.Properties
import scala.collection.JavaConversions._
import scala.language.implicitConversions
@@ -1582,7 +1583,24 @@ class DataFrame private[sql](
* @group output
*/
def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = {
- val conn = DriverManager.getConnection(url)
+ createJDBCTable(url, table, allowExisting, new Properties())
+ }
+
+ /**
+ * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`
+ * using connection properties defined in `properties`.
+ * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements.
+ * If you pass `true` for `allowExisting`, it will drop any table with the
+ * given name; if you pass `false`, it will throw if the table already
+ * exists.
+ * @group output
+ */
+ def createJDBCTable(
+ url: String,
+ table: String,
+ allowExisting: Boolean,
+ properties: Properties): Unit = {
+ val conn = DriverManager.getConnection(url, properties)
try {
if (allowExisting) {
val sql = s"DROP TABLE IF EXISTS $table"
@@ -1594,7 +1612,7 @@ class DataFrame private[sql](
} finally {
conn.close()
}
- JDBCWriteDetails.saveTable(this, url, table)
+ JDBCWriteDetails.saveTable(this, url, table, properties)
}
/**
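For orientation, here is a minimal sketch of how the new `Properties`-accepting `createJDBCTable` might be called once this patch is in. The H2 URL, credentials, and table name are illustrative placeholders, not part of the change:

```scala
import java.util.Properties

import org.apache.spark.sql.DataFrame

def createExample(df: DataFrame): Unit = {
  // Hypothetical connection settings; "user" and "password" are standard JDBC keys.
  val connectionProperties = new Properties()
  connectionProperties.setProperty("user", "testUser")
  connectionProperties.setProperty("password", "testPass")

  // allowExisting = false: throw if TEST.EXAMPLE already exists;
  // otherwise run CREATE TABLE and INSERT the DataFrame's rows.
  df.createJDBCTable("jdbc:h2:mem:exampledb", "TEST.EXAMPLE", false, connectionProperties)
}
```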
@@ -1610,8 +1628,29 @@ class DataFrame private[sql](
* @group output
*/
def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = {
+ insertIntoJDBC(url, table, overwrite, new Properties())
+ }
+
+ /**
+ * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`
+ * using connection properties defined in `properties`.
+ * Assumes the table already exists and has a compatible schema. If you
+ * pass `true` for `overwrite`, it will `TRUNCATE` the table before
+ * performing the `INSERT`s.
+ *
+ * The table must already exist on the database. It must have a schema
+ * that is compatible with the schema of this RDD; inserting the rows of
+ * the RDD in order via the simple statement
+ * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail.
+ * @group output
+ */
+ def insertIntoJDBC(
+ url: String,
+ table: String,
+ overwrite: Boolean,
+ properties: Properties): Unit = {
if (overwrite) {
- val conn = DriverManager.getConnection(url)
+ val conn = DriverManager.getConnection(url, properties)
try {
val sql = s"TRUNCATE TABLE $table"
conn.prepareStatement(sql).executeUpdate()
@@ -1619,9 +1658,8 @@ class DataFrame private[sql](
conn.close()
}
}
- JDBCWriteDetails.saveTable(this, url, table)
+ JDBCWriteDetails.saveTable(this, url, table, properties)
}
-
////////////////////////////////////////////////////////////////////////////
// for Python API
////////////////////////////////////////////////////////////////////////////
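Similarly, a sketch of the new `insertIntoJDBC` overload. The target table is assumed to already exist with a compatible schema, and all names are placeholders:

```scala
import java.util.Properties

import org.apache.spark.sql.DataFrame

def overwriteExample(df: DataFrame): Unit = {
  val connectionProperties = new Properties()
  connectionProperties.setProperty("user", "testUser")
  connectionProperties.setProperty("password", "testPass")

  // overwrite = true: TRUNCATE TEST.EXAMPLE first, then INSERT the rows;
  // with overwrite = false the rows would simply be appended.
  df.insertIntoJDBC("jdbc:h2:mem:exampledb", "TEST.EXAMPLE", true, connectionProperties)
}
```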
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index afee09adaa..70ba8985d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -949,9 +949,21 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jdbc(url: String, table: String): DataFrame = {
- jdbc(url, table, JDBCRelation.columnPartition(null))
+ jdbc(url, table, JDBCRelation.columnPartition(null), new Properties())
}
-
+
+ /**
+ * :: Experimental ::
+ * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
+ * url named table using connection properties.
+ *
+ * @group specificdata
+ */
+ @Experimental
+ def jdbc(url: String, table: String, properties: Properties): DataFrame = {
+ jdbc(url, table, JDBCRelation.columnPartition(null), properties)
+ }
+
/**
* :: Experimental ::
* Construct a [[DataFrame]] representing the database table accessible via JDBC URL
@@ -963,7 +975,31 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @param upperBound the maximum value of `columnName` used to decide partition stride
* @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split
* evenly into this many partitions
+ * @group specificdata
+ */
+ @Experimental
+ def jdbc(
+ url: String,
+ table: String,
+ columnName: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int): DataFrame = {
+ jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties())
+ }
+
+ /**
+ * :: Experimental ::
+ * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
+ * url named table. Partitions of the table will be retrieved in parallel based on the parameters
+ * passed to this function.
*
+ * @param columnName the name of a column of integral type that will be used for partitioning.
+ * @param lowerBound the minimum value of `columnName` used to decide partition stride
+ * @param upperBound the maximum value of `columnName` used to decide partition stride
+ * @param numPartitions the number of partitions. The range `lowerBound`-`upperBound` will be split
+ * evenly into this many partitions
+ * @param properties connection properties
* @group specificdata
*/
@Experimental
@@ -973,16 +1009,17 @@ class SQLContext(@transient val sparkContext: SparkContext)
columnName: String,
lowerBound: Long,
upperBound: Long,
- numPartitions: Int): DataFrame = {
+ numPartitions: Int,
+ properties: Properties): DataFrame = {
val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions)
val parts = JDBCRelation.columnPartition(partitioning)
- jdbc(url, table, parts)
+ jdbc(url, table, parts, properties)
}
-
+
/**
* :: Experimental ::
* Construct a [[DataFrame]] representing the database table accessible via JDBC URL
- * url named table. The theParts parameter gives a list expressions
+ * url named table. The theParts parameter gives a list of expressions
* suitable for inclusion in WHERE clauses; each one defines one partition
* of the [[DataFrame]].
*
@@ -990,14 +1027,36 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = {
+ jdbc(url, table, theParts, new Properties())
+ }
+
+ /**
+ * :: Experimental ::
+ * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
+ * url named table using connection properties. The theParts parameter gives a list of expressions
+ * suitable for inclusion in WHERE clauses; each one defines one partition
+ * of the [[DataFrame]].
+ *
+ * @group specificdata
+ */
+ @Experimental
+ def jdbc(
+ url: String,
+ table: String,
+ theParts: Array[String],
+ properties: Properties): DataFrame = {
val parts: Array[Partition] = theParts.zipWithIndex.map { case (part, i) =>
JDBCPartition(part, i) : Partition
}
- jdbc(url, table, parts)
+ jdbc(url, table, parts, properties)
}
-
- private def jdbc(url: String, table: String, parts: Array[Partition]): DataFrame = {
- val relation = JDBCRelation(url, table, parts)(this)
+
+ private def jdbc(
+ url: String,
+ table: String,
+ parts: Array[Partition],
+ properties: Properties): DataFrame = {
+ val relation = JDBCRelation(url, table, parts, properties)(this)
baseRelationToDataFrame(relation)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index 3a6c2c1e91..c099881a01 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -57,9 +57,14 @@ package object jdbc {
* non-Serializable. Instead, we explicitly close over all variables that
* are used.
*/
- def savePartition(url: String, table: String, iterator: Iterator[Row],
- rddSchema: StructType, nullTypes: Array[Int]): Iterator[Byte] = {
- val conn = DriverManager.getConnection(url)
+ def savePartition(
+ url: String,
+ table: String,
+ iterator: Iterator[Row],
+ rddSchema: StructType,
+ nullTypes: Array[Int],
+ properties: Properties): Iterator[Byte] = {
+ val conn = DriverManager.getConnection(url, properties)
var committed = false
try {
conn.setAutoCommit(false) // Everything in the same db transaction.
@@ -152,7 +157,11 @@ package object jdbc {
/**
* Saves the RDD to the database in a single transaction.
*/
- def saveTable(df: DataFrame, url: String, table: String) {
+ def saveTable(
+ df: DataFrame,
+ url: String,
+ table: String,
+ properties: Properties = new Properties()) {
val quirks = DriverQuirks.get(url)
var nullTypes: Array[Int] = df.schema.fields.map(field => {
var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2
@@ -178,7 +187,7 @@ package object jdbc {
val rddSchema = df.schema
df.foreachPartition { iterator =>
- JDBCWriteDetails.savePartition(url, table, iterator, rddSchema, nullTypes)
+ JDBCWriteDetails.savePartition(url, table, iterator, rddSchema, nullTypes, properties)
}
}
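Internally, the `Properties` object is handed straight to `DriverManager.getConnection(url, properties)`, both on the driver (for the DDL and TRUNCATE statements) and inside `savePartition` on each executor. A small sketch of the kind of object it expects, using only standard `java.sql` and `java.util` APIs; the key/value pairs are illustrative:

```scala
import java.sql.{Connection, DriverManager}
import java.util.Properties

// "user" and "password" are the standard JDBC keys; any other entries are
// passed to the driver unchanged, so driver-specific options can ride along.
val props = new Properties()
props.setProperty("user", "testUser")
props.setProperty("password", "testPass")

// Equivalent to the connection call savePartition now issues per partition.
def openConnection(url: String): Connection =
  DriverManager.getConnection(url, props)
```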
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index ee5c7620d1..f3ce8e6646 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.jdbc
import java.sql.DriverManager
+import java.util.Properties
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -28,15 +29,25 @@ import org.apache.spark.sql.types._
class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb2"
var conn: java.sql.Connection = null
-
+ val url1 = "jdbc:h2:mem:testdb3"
+ var conn1: java.sql.Connection = null
+ val properties = new Properties()
+ properties.setProperty("user", "testUser")
+ properties.setProperty("password", "testPass")
+ properties.setProperty("rowId", "false")
+
before {
Class.forName("org.h2.Driver")
conn = DriverManager.getConnection(url)
conn.prepareStatement("create schema test").executeUpdate()
+
+ conn1 = DriverManager.getConnection(url1, properties)
+ conn1.prepareStatement("create schema test").executeUpdate()
}
after {
conn.close()
+ conn1.close()
}
val sc = TestSQLContext.sparkContext
@@ -65,13 +76,13 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
- df.createJDBCTable(url, "TEST.DROPTEST", false)
- assert(2 == TestSQLContext.jdbc(url, "TEST.DROPTEST").count)
- assert(3 == TestSQLContext.jdbc(url, "TEST.DROPTEST").collect()(0).length)
+ df.createJDBCTable(url1, "TEST.DROPTEST", false, properties)
+ assert(2 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(3 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
- df2.createJDBCTable(url, "TEST.DROPTEST", true)
- assert(1 == TestSQLContext.jdbc(url, "TEST.DROPTEST").count)
- assert(2 == TestSQLContext.jdbc(url, "TEST.DROPTEST").collect()(0).length)
+ df2.createJDBCTable(url1, "TEST.DROPTEST", true, properties)
+ assert(1 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(2 == TestSQLContext.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
}
test("CREATE then INSERT to append") {
@@ -88,10 +99,10 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
- df.createJDBCTable(url, "TEST.TRUNCATETEST", false)
- df2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true)
- assert(1 == TestSQLContext.jdbc(url, "TEST.TRUNCATETEST").count)
- assert(2 == TestSQLContext.jdbc(url, "TEST.TRUNCATETEST").collect()(0).length)
+ df.createJDBCTable(url1, "TEST.TRUNCATETEST", false, properties)
+ df2.insertIntoJDBC(url1, "TEST.TRUNCATETEST", true, properties)
+ assert(1 == TestSQLContext.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
+ assert(2 == TestSQLContext.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
}
test("Incompatible INSERT to append") {