diff options
author | Liwei Lin <lwlin7@gmail.com> | 2016-05-02 16:48:20 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2016-05-02 16:48:20 -0700 |
commit | 35d9c8aa69c650f33037813607dc939922c5fc27 (patch) | |
tree | 269f03fd949a43d8998bff9eb8c88c44fc61795e /sql | |
parent | f362363d148e2df4549fed5c3fd1cf20d0848fd0 (diff) | |
download | spark-35d9c8aa69c650f33037813607dc939922c5fc27.tar.gz spark-35d9c8aa69c650f33037813607dc939922c5fc27.tar.bz2 spark-35d9c8aa69c650f33037813607dc939922c5fc27.zip |
[SPARK-14747][SQL] Add assertStreaming/assertNoneStreaming checks in DataFrameWriter
## Problem
If an end user happens to write code mixed with continuous-query-oriented methods and non-continuous-query-oriented methods:
```scala
ctx.read
.format("text")
.stream("...") // continuous query
.write
.text("...") // non-continuous query; should be startStream() here
```
He/she would get this somehow confusing exception:
>
Exception in thread "main" java.lang.AssertionError: assertion failed: No plan for FileSource[./continuous_query_test_input]
at scala.Predef$.assert(Predef.scala:170)
at org.apache.spark.sql.catalyst.planning.QueryPlanner.plan(QueryPlanner.scala:59)
at org.apache.spark.sql.catalyst.planning.QueryPlanner.planLater(QueryPlanner.scala:54)
at ...
## What changes were proposed in this pull request?
This PR adds checks for continuous-query-oriented methods and non-continuous-query-oriented methods in `DataFrameWriter`:
<table>
<tr>
<td align="center"></td>
<td align="center"><strong>can be called on continuous query?</strong></td>
<td align="center"><strong>can be called on non-continuous query?</strong></td>
</tr>
<tr>
<td align="center">mode</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">trigger</td>
<td align="center">yes</td>
<td align="center"></td>
</tr>
<tr>
<td align="center">format</td>
<td align="center">yes</td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">option/options</td>
<td align="center">yes</td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">partitionBy</td>
<td align="center">yes</td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">bucketBy</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">sortBy</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">save</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">queryName</td>
<td align="center">yes</td>
<td align="center"></td>
</tr>
<tr>
<td align="center">startStream</td>
<td align="center">yes</td>
<td align="center"></td>
</tr>
<tr>
<td align="center">insertInto</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">saveAsTable</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">jdbc</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">json</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">parquet</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">orc</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">text</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
<tr>
<td align="center">csv</td>
<td align="center"></td>
<td align="center">yes</td>
</tr>
</table>
After this PR's change, the friendly exception would be:
>
Exception in thread "main" org.apache.spark.sql.AnalysisException: text() can only be called on non-continuous queries;
at org.apache.spark.sql.DataFrameWriter.assertNotStreaming(DataFrameWriter.scala:678)
at org.apache.spark.sql.DataFrameWriter.text(DataFrameWriter.scala:629)
at ss.SSDemo$.main(SSDemo.scala:47)
## How was this patch tested?
dedicated unit tests were added
Author: Liwei Lin <lwlin7@gmail.com>
Closes #12521 from lw-lin/dataframe-writer-check.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 59 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala | 156 |
2 files changed, 210 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index a57d47d28c..a8f96a9b45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -53,6 +53,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def mode(saveMode: SaveMode): DataFrameWriter = { + // mode() is used for non-continuous queries + // outputMode() is used for continuous queries + assertNotStreaming("mode() can only be called on non-continuous queries") this.mode = saveMode this } @@ -67,6 +70,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter = { + // mode() is used for non-continuous queries + // outputMode() is used for continuous queries + assertNotStreaming("mode() can only be called on non-continuous queries") this.mode = saveMode.toLowerCase match { case "overwrite" => SaveMode.Overwrite case "append" => SaveMode.Append @@ -103,6 +109,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { */ @Experimental def trigger(trigger: Trigger): DataFrameWriter = { + assertStreaming("trigger() can only be called on continuous queries") this.trigger = trigger this } @@ -236,6 +243,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { */ def save(): Unit = { assertNotBucketed() + assertNotStreaming("save() can only be called on non-continuous queries") val dataSource = DataSource( df.sparkSession, className = source, @@ -253,6 +261,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 2.0.0 */ def queryName(queryName: String): DataFrameWriter = { + assertStreaming("queryName() can only be called on continuous queries") this.extraOptions += ("queryName" -> queryName) this } @@ -276,6 +285,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 2.0.0 */ def startStream(): ContinuousQuery = { + assertNotBucketed + assertStreaming("startStream() can only be called on continuous queries") + if (source == "memory") { val queryName = extraOptions.getOrElse( @@ -348,6 +360,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { private def insertInto(tableIdent: TableIdentifier): Unit = { assertNotBucketed() + assertNotStreaming("insertInto() can only be called on non-continuous queries") val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite @@ -446,6 +459,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def saveAsTable(tableIdent: TableIdentifier): Unit = { + assertNotStreaming("saveAsTable() can only be called on non-continuous queries") + val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent) (tableExists, mode) match { @@ -486,6 +501,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def jdbc(url: String, table: String, connectionProperties: Properties): Unit = { + assertNotStreaming("jdbc() can only be called on non-continuous queries") + val props = new Properties() extraOptions.foreach { case (key, value) => props.put(key, value) @@ -542,7 +559,10 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * @since 1.4.0 */ - def json(path: String): Unit = format("json").save(path) + def json(path: String): Unit = { + assertNotStreaming("json() can only be called on non-continuous queries") + format("json").save(path) + } /** * Saves the content of the [[DataFrame]] in Parquet format at the specified path. @@ -558,7 +578,10 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * @since 1.4.0 */ - def parquet(path: String): Unit = format("parquet").save(path) + def parquet(path: String): Unit = { + assertNotStreaming("parquet() can only be called on non-continuous queries") + format("parquet").save(path) + } /** * Saves the content of the [[DataFrame]] in ORC format at the specified path. @@ -575,7 +598,10 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.5.0 * @note Currently, this method can only be used together with `HiveContext`. */ - def orc(path: String): Unit = format("orc").save(path) + def orc(path: String): Unit = { + assertNotStreaming("orc() can only be called on non-continuous queries") + format("orc").save(path) + } /** * Saves the content of the [[DataFrame]] in a text file at the specified path. @@ -596,7 +622,10 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * @since 1.6.0 */ - def text(path: String): Unit = format("text").save(path) + def text(path: String): Unit = { + assertNotStreaming("text() can only be called on non-continuous queries") + format("text").save(path) + } /** * Saves the content of the [[DataFrame]] in CSV format at the specified path. @@ -620,7 +649,10 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * @since 2.0.0 */ - def csv(path: String): Unit = format("csv").save(path) + def csv(path: String): Unit = { + assertNotStreaming("csv() can only be called on non-continuous queries") + format("csv").save(path) + } /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options @@ -641,4 +673,21 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var numBuckets: Option[Int] = None private var sortColumnNames: Option[Seq[String]] = None + + /////////////////////////////////////////////////////////////////////////////////////// + // Helper functions + /////////////////////////////////////////////////////////////////////////////////////// + + private def assertNotStreaming(errMsg: String): Unit = { + if (df.isStreaming) { + throw new AnalysisException(errMsg) + } + } + + private def assertStreaming(errMsg: String): Unit = { + if (!df.isStreaming) { + throw new AnalysisException(errMsg) + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index 00efe21d39..c7b2b99822 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -368,4 +368,160 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B "org.apache.spark.sql.streaming.test", Map.empty) } + + private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath + + test("check trigger() can only be called on continuous queries") { + val df = sqlContext.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds"))) + assert(e.getMessage == "trigger() can only be called on continuous queries;") + } + + test("check queryName() can only be called on continuous queries") { + val df = sqlContext.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.queryName("queryName")) + assert(e.getMessage == "queryName() can only be called on continuous queries;") + } + + test("check startStream() can only be called on continuous queries") { + val df = sqlContext.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.startStream()) + assert(e.getMessage == "startStream() can only be called on continuous queries;") + } + + test("check startStream(path) can only be called on continuous queries") { + val df = sqlContext.read.text(newTextInput) + val w = df.write.option("checkpointLocation", newMetadataDir) + val e = intercept[AnalysisException](w.startStream("non_exist_path")) + assert(e.getMessage == "startStream() can only be called on continuous queries;") + } + + test("check mode(SaveMode) can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.mode(SaveMode.Append)) + assert(e.getMessage == "mode() can only be called on non-continuous queries;") + } + + test("check mode(string) can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.mode("append")) + assert(e.getMessage == "mode() can only be called on non-continuous queries;") + } + + test("check bucketBy() can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[IllegalArgumentException](w.bucketBy(1, "text").startStream()) + assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.") + } + + test("check sortBy() can only be called on non-continuous queries;") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[IllegalArgumentException](w.sortBy("text").startStream()) + assert(e.getMessage == "Currently we don't support writing bucketed data to this data source.") + } + + test("check save(path) can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.save("non_exist_path")) + assert(e.getMessage == "save() can only be called on non-continuous queries;") + } + + test("check save() can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.save()) + assert(e.getMessage == "save() can only be called on non-continuous queries;") + } + + test("check insertInto() can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.insertInto("non_exsit_table")) + assert(e.getMessage == "insertInto() can only be called on non-continuous queries;") + } + + test("check saveAsTable() can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.saveAsTable("non_exsit_table")) + assert(e.getMessage == "saveAsTable() can only be called on non-continuous queries;") + } + + test("check jdbc() can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.jdbc(null, null, null)) + assert(e.getMessage == "jdbc() can only be called on non-continuous queries;") + } + + test("check json() can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.json("non_exist_path")) + assert(e.getMessage == "json() can only be called on non-continuous queries;") + } + + test("check parquet() can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.parquet("non_exist_path")) + assert(e.getMessage == "parquet() can only be called on non-continuous queries;") + } + + test("check orc() can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.orc("non_exist_path")) + assert(e.getMessage == "orc() can only be called on non-continuous queries;") + } + + test("check text() can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.text("non_exist_path")) + assert(e.getMessage == "text() can only be called on non-continuous queries;") + } + + test("check csv() can only be called on non-continuous queries") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + val w = df.write + val e = intercept[AnalysisException](w.csv("non_exist_path")) + assert(e.getMessage == "csv() can only be called on non-continuous queries;") + } } |