author    Wenchen Fan <cloud0fan@163.com>    2015-10-13 17:59:32 -0700
committer Davies Liu <davies.liu@gmail.com>  2015-10-13 17:59:32 -0700
commit    15ff85b3163acbe8052d4489a00bcf1d2332fcf0 (patch)
tree      7ccd434650750c478a4d373951f7138f9ce830b1
parent    e170c22160bb452f98c340489ebf8390116a8cbb (diff)
[SPARK-11068] [SQL] add callback to query execution
With this feature, Spark users can track the query plan, time cost, and any exception raised during query execution.

Author: Wenchen Fan <cloud0fan@163.com>

Closes #9078 from cloud-fan/callback.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala              |  46
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala | 136
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala             |   3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala |  82
4 files changed, 261 insertions(+), 6 deletions(-)
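
For context, a minimal sketch of how an application would use the new API once this patch is applied (assuming an existing `sqlContext`; the printed messages are illustrative):

    import org.apache.spark.sql.QueryExecutionListener
    import org.apache.spark.sql.execution.QueryExecution

    // A listener that logs the action name and wall-clock time of every successful
    // query, and the error message of every failed one.
    val timingListener = new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
        // `duration` is reported in nanoseconds (see the scaladoc in this patch).
        println(s"$funcName succeeded in ${duration / 1e6} ms")
      }
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
        println(s"$funcName failed: ${exception.getMessage}")
      }
    }

    sqlContext.listenerManager.register(timingListener)
    sqlContext.range(10).count()   // prints e.g. "count succeeded in 12.3 ms"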
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 01f60aba87..bfe8d3c8ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1344,7 +1344,9 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def head(n: Int): Array[Row] = limit(n).collect()
+ def head(n: Int): Array[Row] = withCallback("head", limit(n)) { df =>
+ df.collect(needCallback = false)
+ }
/**
* Returns the first row.
@@ -1414,8 +1416,18 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def collect(): Array[Row] = withNewExecutionId {
- queryExecution.executedPlan.executeCollectPublic()
+ def collect(): Array[Row] = collect(needCallback = true)
+
+ private def collect(needCallback: Boolean): Array[Row] = {
+ def execute(): Array[Row] = withNewExecutionId {
+ queryExecution.executedPlan.executeCollectPublic()
+ }
+
+ if (needCallback) {
+ withCallback("collect", this)(_ => execute())
+ } else {
+ execute()
+ }
}
/**
@@ -1423,8 +1435,10 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def collectAsList(): java.util.List[Row] = withNewExecutionId {
- java.util.Arrays.asList(rdd.collect() : _*)
+ def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ =>
+ withNewExecutionId {
+ java.util.Arrays.asList(rdd.collect() : _*)
+ }
}
/**
@@ -1432,7 +1446,9 @@ class DataFrame private[sql](
* @group action
* @since 1.3.0
*/
- def count(): Long = groupBy().count().collect().head.getLong(0)
+ def count(): Long = withCallback("count", groupBy().count()) { df =>
+ df.collect(needCallback = false).head.getLong(0)
+ }
/**
* Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
@@ -1936,6 +1952,24 @@ class DataFrame private[sql](
SQLExecution.withNewExecutionId(sqlContext, queryExecution)(body)
}
+ /**
+ * Wraps a DataFrame action to track its QueryExecution and elapsed time, then reports
+ * them to all user-registered callback functions.
+ */
+ private def withCallback[T](name: String, df: DataFrame)(action: DataFrame => T): T = {
+ try {
+ val start = System.nanoTime()
+ val result = action(df)
+ val end = System.nanoTime()
+ sqlContext.listenerManager.onSuccess(name, df.queryExecution, end - start)
+ result
+ } catch {
+ case e: Exception =>
+ sqlContext.listenerManager.onFailure(name, df.queryExecution, e)
+ throw e
+ }
+ }
+
////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////
// End of deprecated methods
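
Note how `head` and `count` delegate to `collect(needCallback = false)`: a composite action reports exactly once, under its own name, instead of also firing an extra "collect" event for the inner call. A sketch of the observable behavior (assuming `sqlContext` and its implicits are in scope):

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.sql.QueryExecutionListener
    import org.apache.spark.sql.execution.QueryExecution
    import sqlContext.implicits._

    val seen = ArrayBuffer.empty[String]
    sqlContext.listenerManager.register(new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit =
        seen += funcName
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
    })

    Seq(1 -> "a").toDF("i", "j").head(1)
    assert(seen == Seq("head"))   // no separate "collect" event for the inner call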
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala
new file mode 100644
index 0000000000..14fbebb45f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/QueryExecutionListener.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.concurrent.locks.ReentrantReadWriteLock
+import scala.collection.mutable.ListBuffer
+
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.Logging
+import org.apache.spark.sql.execution.QueryExecution
+
+
+/**
+ * An interface for query execution listeners that can be used to analyze execution metrics.
+ *
+ * Note that implementations should guarantee thread-safety, as they can be invoked by
+ * multiple different threads.
+ */
+@Experimental
+trait QueryExecutionListener {
+
+ /**
+ * A callback function that will be called when a query is executed successfully.
+ * Implementations should be thread-safe.
+ *
+ * @param funcName the name of the action that triggered this query.
+ * @param qe the QueryExecution object that carries detailed information such as the
+ * logical plan, physical plan, etc.
+ * @param duration the execution time for this query in nanoseconds.
+ */
+ @DeveloperApi
+ def onSuccess(funcName: String, qe: QueryExecution, duration: Long)
+
+ /**
+ * A callback function that will be called when a query execution fails.
+ * Implementations should be thread-safe.
+ *
+ * @param funcName the name of the action that triggered this query.
+ * @param qe the QueryExecution object that carries detailed information such as the
+ * logical plan, physical plan, etc.
+ * @param exception the exception that caused the query to fail.
+ */
+ @DeveloperApi
+ def onFailure(funcName: String, qe: QueryExecution, exception: Exception)
+}
+
+@Experimental
+class ExecutionListenerManager extends Logging {
+ private[this] val listeners = ListBuffer.empty[QueryExecutionListener]
+ private[this] val lock = new ReentrantReadWriteLock()
+
+ /** Acquires a read lock on the listener list for the duration of `f`. */
+ private def readLock[A](f: => A): A = {
+ val rl = lock.readLock()
+ rl.lock()
+ try f finally {
+ rl.unlock()
+ }
+ }
+
+ /** Acquires a write lock on the listener list for the duration of `f`. */
+ private def writeLock[A](f: => A): A = {
+ val wl = lock.writeLock()
+ wl.lock()
+ try f finally {
+ wl.unlock()
+ }
+ }
+
+ /**
+ * Registers the specified QueryExecutionListener.
+ */
+ @DeveloperApi
+ def register(listener: QueryExecutionListener): Unit = writeLock {
+ listeners += listener
+ }
+
+ /**
+ * Unregisters the specified QueryExecutionListener.
+ */
+ @DeveloperApi
+ def unregister(listener: QueryExecutionListener): Unit = writeLock {
+ listeners -= listener
+ }
+
+ /**
+ * Clears out all registered QueryExecutionListeners.
+ */
+ @DeveloperApi
+ def clear(): Unit = writeLock {
+ listeners.clear()
+ }
+
+ private[sql] def onSuccess(
+ funcName: String,
+ qe: QueryExecution,
+ duration: Long): Unit = readLock {
+ withErrorHandling { listener =>
+ listener.onSuccess(funcName, qe, duration)
+ }
+ }
+
+ private[sql] def onFailure(
+ funcName: String,
+ qe: QueryExecution,
+ exception: Exception): Unit = readLock {
+ withErrorHandling { listener =>
+ listener.onFailure(funcName, qe, exception)
+ }
+ }
+
+ private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = {
+ for (listener <- listeners) {
+ try {
+ f(listener)
+ } catch {
+ case e: Exception => logWarning("Error executing query execution listener", e)
+ }
+ }
+ }
+}
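
Note the failure isolation in `withErrorHandling`: an exception thrown by one listener is logged and swallowed, so it neither fails the query nor prevents the remaining listeners from running. A sketch of the observable behavior (here `goodListener` stands for any hypothetical well-behaved listener):

    sqlContext.listenerManager.register(new QueryExecutionListener {
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit =
        throw new RuntimeException("broken listener")   // logged as a warning, not rethrown
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
    })
    sqlContext.listenerManager.register(goodListener)

    // The action still succeeds and `goodListener.onSuccess` is still invoked;
    // only a warning is written to the log for the broken listener.
    sqlContext.range(10).count()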
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index cd937257d3..a835408f8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -178,6 +178,9 @@ class SQLContext private[sql](
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
@transient
+ lazy val listenerManager: ExecutionListenerManager = new ExecutionListenerManager
+
+ @transient
protected[sql] lazy val catalog: Catalog = new SimpleCatalog(conf)
@transient
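
Since `listenerManager` is a transient lazy field, each SQLContext instance owns its own ExecutionListenerManager, so listeners are scoped per context. A sketch, with `ctx1` and `ctx2` as two hypothetical, independently created contexts:

    ctx1.listenerManager.register(timingListener)   // `timingListener` as sketched above
    ctx1.range(10).count()   // notifies timingListener
    ctx2.range(10).count()   // does not: ctx2 has its own, initially empty manager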
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala
new file mode 100644
index 0000000000..4e286a0076
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameCallbackSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+ import functions._
+
+ test("execute callback functions when a DataFrame action finished successfully") {
+ val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]
+ val listener = new QueryExecutionListener {
+ // Only test successful case here, so no need to implement `onFailure`
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+
+ override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
+ metrics += ((funcName, qe, duration))
+ }
+ }
+ sqlContext.listenerManager.register(listener)
+
+ val df = Seq(1 -> "a").toDF("i", "j")
+ df.select("i").collect()
+ df.filter($"i" > 0).count()
+
+ assert(metrics.length == 2)
+
+ assert(metrics(0)._1 == "collect")
+ assert(metrics(0)._2.analyzed.isInstanceOf[Project])
+ assert(metrics(0)._3 > 0)
+
+ assert(metrics(1)._1 == "count")
+ assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate])
+ assert(metrics(1)._3 > 0)
+ }
+
+ test("execute callback functions when a DataFrame action failed") {
+ val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)]
+ val listener = new QueryExecutionListener {
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+ metrics += ((funcName, qe, exception))
+ }
+
+ // Only test failed case here, so no need to implement `onSuccess`
+ override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {}
+ }
+ sqlContext.listenerManager.register(listener)
+
+ val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") }
+ val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j")
+
+ // Ignore the log when we are expecting an exception.
+ sparkContext.setLogLevel("FATAL")
+ val e = intercept[SparkException](df.select(errorUdf($"i")).collect())
+
+ assert(metrics.length == 1)
+ assert(metrics(0)._1 == "collect")
+ assert(metrics(0)._2.analyzed.isInstanceOf[Project])
+ assert(metrics(0)._3.getMessage == e.getMessage)
+ }
+}