author	Liwei Lin <lwlin7@gmail.com>	2016-05-02 16:48:20 -0700
committer	Michael Armbrust <michael@databricks.com>	2016-05-02 16:48:20 -0700
commit	35d9c8aa69c650f33037813607dc939922c5fc27 (patch)
tree	269f03fd949a43d8998bff9eb8c88c44fc61795e /sql
parent	f362363d148e2df4549fed5c3fd1cf20d0848fd0 (diff)
[SPARK-14747][SQL] Add assertStreaming/assertNoneStreaming checks in DataFrameWriter
## Problem

If an end user happens to write code that mixes continuous-query-oriented methods with non-continuous-query-oriented methods:

```scala
ctx.read
  .format("text")
  .stream("...")  // continuous query
  .write
  .text("...")    // non-continuous query; should be startStream() here
```

they would get this rather confusing exception:

> Exception in thread "main" java.lang.AssertionError: assertion failed: No plan for FileSource[./continuous_query_test_input]
> at scala.Predef$.assert(Predef.scala:170)
> at org.apache.spark.sql.catalyst.planning.QueryPlanner.plan(QueryPlanner.scala:59)
> at org.apache.spark.sql.catalyst.planning.QueryPlanner.planLater(QueryPlanner.scala:54)
> at ...

## What changes were proposed in this pull request?

This PR adds checks for continuous-query-oriented methods and non-continuous-query-oriented methods in `DataFrameWriter`:

| Method         | Can be called on continuous query? | Can be called on non-continuous query? |
|----------------|:----------------------------------:|:--------------------------------------:|
| mode           |                                    | yes                                    |
| trigger        | yes                                |                                        |
| format         | yes                                | yes                                    |
| option/options | yes                                | yes                                    |
| partitionBy    | yes                                | yes                                    |
| bucketBy       |                                    | yes                                    |
| sortBy         |                                    | yes                                    |
| save           |                                    | yes                                    |
| queryName      | yes                                |                                        |
| startStream    | yes                                |                                        |
| insertInto     |                                    | yes                                    |
| saveAsTable    |                                    | yes                                    |
| jdbc           |                                    | yes                                    |
| json           |                                    | yes                                    |
| parquet        |                                    | yes                                    |
| orc            |                                    | yes                                    |
| text           |                                    | yes                                    |
| csv            |                                    | yes                                    |

After this PR's change, the friendlier exception would be:

> Exception in thread "main" org.apache.spark.sql.AnalysisException: text() can only be called on non-continuous queries;
> at org.apache.spark.sql.DataFrameWriter.assertNotStreaming(DataFrameWriter.scala:678)
> at org.apache.spark.sql.DataFrameWriter.text(DataFrameWriter.scala:629)
> at ss.SSDemo$.main(SSDemo.scala:47)

## How was this patch tested?

Dedicated unit tests were added.

Author: Liwei Lin <lwlin7@gmail.com>

Closes #12521 from lw-lin/dataframe-writer-check.
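The core of the change in the diff below is a pair of small guard helpers keyed off the DataFrame's `isStreaming` flag, called at the top of each entry point. The following is a minimal, self-contained sketch of that pattern; `SimpleWriter`, its fields, and the local `AnalysisException` stand-in are illustrative only, not the actual Spark sources, where the checks live in `org.apache.spark.sql.DataFrameWriter`:

```scala
// Minimal sketch of the assertStreaming / assertNotStreaming guard pattern.
// AnalysisException here is a local stand-in for org.apache.spark.sql.AnalysisException.
class AnalysisException(msg: String) extends Exception(msg)

class SimpleWriter(isStreaming: Boolean) {

  /** Batch-only entry point: rejected when the source is a continuous query. */
  def text(path: String): Unit = {
    assertNotStreaming("text() can only be called on non-continuous queries")
    println(s"writing batch text output to $path")
  }

  /** Streaming-only entry point: rejected when the source is a batch query. */
  def startStream(): Unit = {
    assertStreaming("startStream() can only be called on continuous queries")
    println("starting continuous query")
  }

  private def assertNotStreaming(errMsg: String): Unit =
    if (isStreaming) throw new AnalysisException(errMsg)

  private def assertStreaming(errMsg: String): Unit =
    if (!isStreaming) throw new AnalysisException(errMsg)
}

object GuardDemo extends App {
  val streamingWriter = new SimpleWriter(isStreaming = true)
  // Fails fast with a clear message instead of a planner assertion error:
  try streamingWriter.text("/tmp/out") catch { case e: Exception => println(e.getMessage) }
  streamingWriter.startStream()
}
```

The design choice is simply to fail at the API boundary, where the user's mistake is visible, rather than deep inside query planning where the error message gives no hint about which writer call was wrong.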
Diffstat (limited to 'sql')
-rw-r--r--	sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala	59
-rw-r--r--	sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala	156
2 files changed, 210 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index a57d47d28c..a8f96a9b45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -53,6 +53,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def mode(saveMode: SaveMode): DataFrameWriter = {
+ // mode() is used for non-continuous queries
+ // outputMode() is used for continuous queries
+ assertNotStreaming("mode() can only be called on non-continuous queries")
this.mode = saveMode
this
}
@@ -67,6 +70,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def mode(saveMode: String): DataFrameWriter = {
+ // mode() is used for non-continuous queries
+ // outputMode() is used for continuous queries
+ assertNotStreaming("mode() can only be called on non-continuous queries")
this.mode = saveMode.toLowerCase match {
case "overwrite" => SaveMode.Overwrite
case "append" => SaveMode.Append
@@ -103,6 +109,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*/
@Experimental
def trigger(trigger: Trigger): DataFrameWriter = {
+ assertStreaming("trigger() can only be called on continuous queries")
this.trigger = trigger
this
}
@@ -236,6 +243,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*/
def save(): Unit = {
assertNotBucketed()
+ assertNotStreaming("save() can only be called on non-continuous queries")
val dataSource = DataSource(
df.sparkSession,
className = source,
@@ -253,6 +261,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
def queryName(queryName: String): DataFrameWriter = {
+ assertStreaming("queryName() can only be called on continuous queries")
this.extraOptions += ("queryName" -> queryName)
this
}
@@ -276,6 +285,9 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
def startStream(): ContinuousQuery = {
+ assertNotBucketed
+ assertStreaming("startStream() can only be called on continuous queries")
+
if (source == "memory") {
val queryName =
extraOptions.getOrElse(
@@ -348,6 +360,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
private def insertInto(tableIdent: TableIdentifier): Unit = {
assertNotBucketed()
+ assertNotStreaming("insertInto() can only be called on non-continuous queries")
val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap)
val overwrite = mode == SaveMode.Overwrite
@@ -446,6 +459,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
}
private def saveAsTable(tableIdent: TableIdentifier): Unit = {
+ assertNotStreaming("saveAsTable() can only be called on non-continuous queries")
+
val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent)
(tableExists, mode) match {
@@ -486,6 +501,8 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
+ assertNotStreaming("jdbc() can only be called on non-continuous queries")
+
val props = new Properties()
extraOptions.foreach { case (key, value) =>
props.put(key, value)
@@ -542,7 +559,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
- def json(path: String): Unit = format("json").save(path)
+ def json(path: String): Unit = {
+ assertNotStreaming("json() can only be called on non-continuous queries")
+ format("json").save(path)
+ }
/**
* Saves the content of the [[DataFrame]] in Parquet format at the specified path.
@@ -558,7 +578,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.4.0
*/
- def parquet(path: String): Unit = format("parquet").save(path)
+ def parquet(path: String): Unit = {
+ assertNotStreaming("parquet() can only be called on non-continuous queries")
+ format("parquet").save(path)
+ }
/**
* Saves the content of the [[DataFrame]] in ORC format at the specified path.
@@ -575,7 +598,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.5.0
* @note Currently, this method can only be used together with `HiveContext`.
*/
- def orc(path: String): Unit = format("orc").save(path)
+ def orc(path: String): Unit = {
+ assertNotStreaming("orc() can only be called on non-continuous queries")
+ format("orc").save(path)
+ }
/**
* Saves the content of the [[DataFrame]] in a text file at the specified path.
@@ -596,7 +622,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 1.6.0
*/
- def text(path: String): Unit = format("text").save(path)
+ def text(path: String): Unit = {
+ assertNotStreaming("text() can only be called on non-continuous queries")
+ format("text").save(path)
+ }
/**
* Saves the content of the [[DataFrame]] in CSV format at the specified path.
@@ -620,7 +649,10 @@ final class DataFrameWriter private[sql](df: DataFrame) {
*
* @since 2.0.0
*/
- def csv(path: String): Unit = format("csv").save(path)
+ def csv(path: String): Unit = {
+ assertNotStreaming("csv() can only be called on non-continuous queries")
+ format("csv").save(path)
+ }
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
@@ -641,4 +673,21 @@ final class DataFrameWriter private[sql](df: DataFrame) {
private var numBuckets: Option[Int] = None
private var sortColumnNames: Option[Seq[String]] = None
+
+ ///////////////////////////////////////////////////////////////////////////////////////
+ // Helper functions
+ ///////////////////////////////////////////////////////////////////////////////////////
+
+ private def assertNotStreaming(errMsg: String): Unit = {
+ if (df.isStreaming) {
+ throw new AnalysisException(errMsg)
+ }
+ }
+
+ private def assertStreaming(errMsg: String): Unit = {
+ if (!df.isStreaming) {
+ throw new AnalysisException(errMsg)
+ }
+ }
+
}
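Taken together with the table in the commit message, the intended split of writer calls looks roughly like the sketch below. This is a hedged illustration, not code from the commit: it assumes a Spark 2.0-preview `SQLContext` named `ctx`, the `stream()`/`startStream()` API shown in this diff, and placeholder paths; imports (e.g. for `ProcessingTime`) are omitted.

```scala
// Continuous query: trigger()/queryName()/startStream() are allowed; mode()/save()/text() are not.
val continuousQuery = ctx.read
  .format("text")
  .stream("/path/to/stream/input")                      // placeholder input path
  .write
  .format("parquet")
  .option("checkpointLocation", "/path/to/checkpoint")  // placeholder checkpoint location
  .trigger(ProcessingTime("10 seconds"))
  .queryName("demo_query")
  .startStream("/path/to/stream/output")                // placeholder output path

// Non-continuous (batch) query: mode()/save()/text()/csv()/... are allowed;
// trigger()/queryName()/startStream() now raise AnalysisException.
ctx.read
  .text("/path/to/batch/input")                         // placeholder input path
  .write
  .mode("append")
  .text("/path/to/batch/output")                        // placeholder output path
```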
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index 00efe21d39..c7b2b99822 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -368,4 +368,160 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B
"org.apache.spark.sql.streaming.test",
Map.empty)
}
+
+ private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath
+
+ test("check trigger() can only be called on continuous queries") {
+ val df = sqlContext.read.text(newTextInput)
+ val w = df.write.option("checkpointLocation", newMetadataDir)
+ val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds")))
+ assert(e.getMessage == "trigger() can only be called on continuous queries;")
+ }
+
+ test("check queryName() can only be called on continuous queries") {
+ val df = sqlContext.read.text(newTextInput)
+ val w = df.write.option("checkpointLocation", newMetadataDir)
+ val e = intercept[AnalysisException](w.queryName("queryName"))
+ assert(e.getMessage == "queryName() can only be called on continuous queries;")
+ }
+
+ test("check startStream() can only be called on continuous queries") {
+ val df = sqlContext.read.text(newTextInput)
+ val w = df.write.option("checkpointLocation", newMetadataDir)
+ val e = intercept[AnalysisException](w.startStream())
+ assert(e.getMessage == "startStream() can only be called on continuous queries;")
+ }
+
+ test("check startStream(path) can only be called on continuous queries") {
+ val df = sqlContext.read.text(newTextInput)
+ val w = df.write.option("checkpointLocation", newMetadataDir)
+ val e = intercept[AnalysisException](w.startStream("non_exist_path"))
+ assert(e.getMessage == "startStream() can only be called on continuous queries;")
+ }
+
+ test("check mode(SaveMode) can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.mode(SaveMode.Append))
+ assert(e.getMessage == "mode() can only be called on non-continuous queries;")
+ }
+
+ test("check mode(string) can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.mode("append"))
+ assert(e.getMessage == "mode() can only be called on non-continuous queries;")
+ }
+
+ test("check bucketBy() can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[IllegalArgumentException](w.bucketBy(1, "text").startStream())
+ assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.")
+ }
+
+ test("check sortBy() can only be called on non-continuous queries;") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[IllegalArgumentException](w.sortBy("text").startStream())
+ assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.")
+ }
+
+ test("check save(path) can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.save("non_exist_path"))
+ assert(e.getMessage == "save() can only be called on non-continuous queries;")
+ }
+
+ test("check save() can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.save())
+ assert(e.getMessage == "save() can only be called on non-continuous queries;")
+ }
+
+ test("check insertInto() can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.insertInto("non_exsit_table"))
+ assert(e.getMessage == "insertInto() can only be called on non-continuous queries;")
+ }
+
+ test("check saveAsTable() can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.saveAsTable("non_exsit_table"))
+ assert(e.getMessage == "saveAsTable() can only be called on non-continuous queries;")
+ }
+
+ test("check jdbc() can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.jdbc(null, null, null))
+ assert(e.getMessage == "jdbc() can only be called on non-continuous queries;")
+ }
+
+ test("check json() can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.json("non_exist_path"))
+ assert(e.getMessage == "json() can only be called on non-continuous queries;")
+ }
+
+ test("check parquet() can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.parquet("non_exist_path"))
+ assert(e.getMessage == "parquet() can only be called on non-continuous queries;")
+ }
+
+ test("check orc() can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.orc("non_exist_path"))
+ assert(e.getMessage == "orc() can only be called on non-continuous queries;")
+ }
+
+ test("check text() can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.text("non_exist_path"))
+ assert(e.getMessage == "text() can only be called on non-continuous queries;")
+ }
+
+ test("check csv() can only be called on non-continuous queries") {
+ val df = sqlContext.read
+ .format("org.apache.spark.sql.streaming.test")
+ .stream()
+ val w = df.write
+ val e = intercept[AnalysisException](w.csv("non_exist_path"))
+ assert(e.getMessage == "csv() can only be called on non-continuous queries;")
+ }
}