From 50b89d05b7bffc212cc9b9ae6e0bca7cb90b9c77 Mon Sep 17 00:00:00 2001 From: Justin Pihony Date: Mon, 26 Sep 2016 09:54:22 +0100 Subject: [SPARK-14525][SQL] Make DataFrameWrite.save work for jdbc ## What changes were proposed in this pull request? This change modifies the implementation of DataFrameWriter.save such that it works with jdbc, and the call to jdbc merely delegates to save. ## How was this patch tested? This was tested via unit tests in the JDBCWriteSuite, of which I added one new test to cover this scenario. ## Additional details rxin This seems to have been most recently touched by you and was also commented on in the JIRA. This contribution is my original work and I license the work to the project under the project's open source license. Author: Justin Pihony Author: Justin Pihony Closes #12601 from JustinPihony/jdbc_reconciliation. --- .../org/apache/spark/sql/jdbc/JDBCWriteSuite.scala | 82 ++++++++++++++++++++++ 1 file changed, 82 insertions(+) (limited to 'sql/core/src/test/scala') diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index ff3309874f..506971362f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.jdbc import java.sql.DriverManager import java.util.Properties +import scala.collection.JavaConverters.propertiesAsScalaMapConverter + import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException @@ -208,4 +210,84 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count()) assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } + + test("save works for format(\"jdbc\") if url and dbtable are set") { + val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + df.write.format("jdbc") + .options(Map("url" -> url, "dbtable" -> "TEST.SAVETEST")) + .save() + + assert(2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).count) + assert( + 2 === sqlContext.read.jdbc(url, "TEST.SAVETEST", new Properties).collect()(0).length) + } + + test("save API with SaveMode.Overwrite") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + + df.write.format("jdbc") + .option("url", url1) + .option("dbtable", "TEST.SAVETEST") + .options(properties.asScala) + .save() + df2.write.mode(SaveMode.Overwrite).format("jdbc") + .option("url", url1) + .option("dbtable", "TEST.SAVETEST") + .options(properties.asScala) + .save() + assert(1 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).count()) + assert(2 === spark.read.jdbc(url1, "TEST.SAVETEST", properties).collect()(0).length) + } + + test("save errors if url is not specified") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[RuntimeException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .options(properties.asScala) + .save() + }.getMessage + assert(e.contains("Option 'url' is required")) + } + + test("save errors if dbtable is not specified") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[RuntimeException] { + df.write.format("jdbc") + .option("url", url1) + .options(properties.asScala) + .save() + }.getMessage + assert(e.contains("Option 'dbtable' is required")) + } + + test("save errors if wrong user/password combination") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[org.h2.jdbc.JdbcSQLException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .option("url", url1) + .save() + }.getMessage + assert(e.contains("Wrong user name or password")) + } + + test("save errors if partitionColumn and numPartitions and bounds not set") { + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + + val e = intercept[java.lang.IllegalArgumentException] { + df.write.format("jdbc") + .option("dbtable", "TEST.SAVETEST") + .option("url", url1) + .option("partitionColumn", "foo") + .save() + }.getMessage + assert(e.contains("If 'partitionColumn' is specified then 'lowerBound', 'upperBound'," + + " and 'numPartitions' are required.")) + } } -- cgit v1.2.3