commit    54d23599df7c28a7685416ced6ad8fcde047e534
tree      5c8177482f13ef0c70e08ef14249a71cdbc363c3
parent    21fde57f15db974b710e7b00e72c744da7c1ac3c
author    Wenchen Fan <wenchen@databricks.com>  2017-02-16 21:09:14 -0800
committer Wenchen Fan <wenchen@databricks.com>  2017-02-16 21:09:14 -0800
[SPARK-18120][SPARK-19557][SQL] Call QueryExecutionListener callback methods for DataFrameWriter methods
## What changes were proposed in this pull request?

We only notify `QueryExecutionListener` for several `Dataset` operations, e.g. `collect`, `take`, etc. We should also send the notification for `DataFrameWriter` operations.

## How was this patch tested?

A new regression test.

Closes https://github.com/apache/spark/pull/16664

Author: Wenchen Fan <wenchen@databricks.com>

Closes #16962 from cloud-fan/insert.
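For context, here is a minimal sketch of the listener API this patch hooks into. The `QueryExecutionListener` trait, its `onSuccess`/`onFailure` signatures, and `spark.listenerManager.register`/`unregister` are the real Spark 2.x API, as exercised by the test below; the session setup, output path, and println bodies are illustrative assumptions, not part of the patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

object WriterCallbackDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("writer-callback-demo")  // hypothetical app name
      .getOrCreate()

    // Listener that reports the API entry point and the logical plan class
    // for every tracked query execution.
    val listener = new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
        println(s"$funcName succeeded, plan: ${qe.logical.getClass.getSimpleName}")
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
        println(s"$funcName failed: ${exception.getMessage}")
    }
    spark.listenerManager.register(listener)

    // Before this patch only Dataset actions (collect, take, ...) fired the
    // callbacks; with it, DataFrameWriter operations do too. The path below
    // is an illustrative placeholder.
    spark.range(10).write.format("json").save("/tmp/writer-callback-demo")

    spark.listenerManager.unregister(listener)
    spark.stop()
  }
}

With the patch applied, the `save` call above reaches `onSuccess` with `funcName == "save"`, matching the first assertion block in the new test.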
Diffstat (limited to 'sql/core/src/test')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala | 57
1 file changed, 55 insertions(+), 2 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 3ae5ce610d..9f27d06dcb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.util
import scala.collection.mutable.ArrayBuffer
import org.apache.spark._
-import org.apache.spark.sql.{functions, QueryTest}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.{functions, AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project}
import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.datasources.{CreateTable, SaveIntoDataSourceCommand}
import org.apache.spark.sql.test.SharedSQLContext
class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
@@ -159,4 +161,55 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
spark.listenerManager.unregister(listener)
}
+
+ test("execute callback functions for DataFrameWriter") {
+ val commands = ArrayBuffer.empty[(String, LogicalPlan)]
+ val exceptions = ArrayBuffer.empty[(String, Exception)]
+ val listener = new QueryExecutionListener {
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+ exceptions += funcName -> exception
+ }
+
+ override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+ commands += funcName -> qe.logical
+ }
+ }
+ spark.listenerManager.register(listener)
+
+ withTempPath { path =>
+ spark.range(10).write.format("json").save(path.getCanonicalPath)
+ assert(commands.length == 1)
+ assert(commands.head._1 == "save")
+ assert(commands.head._2.isInstanceOf[SaveIntoDataSourceCommand])
+ assert(commands.head._2.asInstanceOf[SaveIntoDataSourceCommand].provider == "json")
+ }
+
+ withTable("tab") {
+ sql("CREATE TABLE tab(i long) using parquet")
+ spark.range(10).write.insertInto("tab")
+ assert(commands.length == 2)
+ assert(commands(1)._1 == "insertInto")
+ assert(commands(1)._2.isInstanceOf[InsertIntoTable])
+ assert(commands(1)._2.asInstanceOf[InsertIntoTable].table
+ .asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab")
+ }
+
+ withTable("tab") {
+ spark.range(10).select($"id", $"id" % 5 as "p").write.partitionBy("p").saveAsTable("tab")
+ assert(commands.length == 3)
+ assert(commands(2)._1 == "saveAsTable")
+ assert(commands(2)._2.isInstanceOf[CreateTable])
+ assert(commands(2)._2.asInstanceOf[CreateTable].tableDesc.partitionColumnNames == Seq("p"))
+ }
+
+ withTable("tab") {
+ sql("CREATE TABLE tab(i long) using parquet")
+ val e = intercept[AnalysisException] {
+ spark.range(10).select($"id", $"id").write.insertInto("tab")
+ }
+ assert(exceptions.length == 1)
+ assert(exceptions.head._1 == "insertInto")
+ assert(exceptions.head._2 == e)
+ }
+ }
}
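One use this change enables, as a hedged sketch rather than anything from the patch: auditing data source writes by pattern matching on the logical plan the listener receives. The `SaveIntoDataSourceCommand` shape (with a String `provider`) matches what the test above asserts; the listener name and println output are illustrative, and a SparkSession `spark` is assumed in scope.

import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand
import org.apache.spark.sql.util.QueryExecutionListener

// Hypothetical audit listener: records every DataFrameWriter.save that went
// through a data source, using the plan node asserted in the test above.
val auditListener = new QueryExecutionListener {
  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit =
    qe.logical match {
      case cmd: SaveIntoDataSourceCommand =>
        println(s"$funcName wrote via provider '${cmd.provider}'")
      case _ => // not a DataFrameWriter save; ignore
    }
  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit =
    println(s"$funcName failed: ${exception.getMessage}")
}
spark.listenerManager.register(auditListener)

Because `onFailure` receives the same exception the caller sees (the test checks `exceptions.head._2 == e`), such a listener can correlate failed writes with application-level error handling.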