 sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala          | 28 +++++++++++++++++++++++++++++---
 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala |  7 +++----
 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala               | 28 ++++++++++++++++++++++++++++----
 3 files changed, 52 insertions(+), 11 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 52d8dc22a2..58f5071193 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -86,18 +86,18 @@ case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan {
* }}}
*
* @param logicalPlan plan to explain
- * @param output output schema
* @param extended whether to do extended explain or not
* @param codegen whether to output generated code from whole-stage codegen or not
*/
case class ExplainCommand(
logicalPlan: LogicalPlan,
- override val output: Seq[Attribute] =
- Seq(AttributeReference("plan", StringType, nullable = true)()),
extended: Boolean = false,
codegen: Boolean = false)
extends RunnableCommand {
+ override val output: Seq[Attribute] =
+ Seq(AttributeReference("plan", StringType, nullable = true)())
+
// Run through the optimizer to generate the physical plan.
override def run(sparkSession: SparkSession): Seq[Row] = try {
val queryExecution =
@@ -121,3 +121,25 @@ case class ExplainCommand(
("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_))
}
}
+
+/** An explain command for users to see how a streaming batch is executed. */
+case class StreamingExplainCommand(
+ queryExecution: IncrementalExecution,
+ extended: Boolean) extends RunnableCommand {
+
+ override val output: Seq[Attribute] =
+ Seq(AttributeReference("plan", StringType, nullable = true)())
+
+ // Run through the optimizer to generate the physical plan.
+ override def run(sparkSession: SparkSession): Seq[Row] = try {
+ val outputString =
+ if (extended) {
+ queryExecution.toString
+ } else {
+ queryExecution.simpleString
+ }
+ Seq(Row(outputString))
+ } catch { case cause: TreeNodeException[_] =>
+ ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_))
+ }
+}
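
For context, the new command is executed exactly the way StreamExecution.explainInternal drives it in the next file: wrap the current IncrementalExecution, plan the command, and collect the string rows. A minimal sketch, assuming `sparkSession: SparkSession` and `lastExecution: IncrementalExecution` are in scope (as they are inside StreamExecution):

    // Render the last batch's plans through the new command.
    val explain = StreamingExplainCommand(lastExecution, extended = false)
    val planString = sparkSession.sessionState.executePlan(explain)
      .executedPlan
      .executeCollect()
      .map(_.getString(0))   // each Row holds a chunk of the plan text
      .mkString("\n")
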
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 239d49b08a..e1af420a69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.streaming
-import java.io.IOException
import java.util.UUID
import java.util.concurrent.{CountDownLatch, TimeUnit}
import java.util.concurrent.locks.ReentrantLock
@@ -33,7 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.execution.command.ExplainCommand
+import org.apache.spark.sql.execution.command.StreamingExplainCommand
import org.apache.spark.sql.streaming._
import org.apache.spark.util.{Clock, UninterruptibleThread, Utils}
@@ -162,7 +161,7 @@ class StreamExecution(
private var state: State = INITIALIZING
@volatile
- var lastExecution: QueryExecution = _
+ var lastExecution: IncrementalExecution = _
/** Holds the most recent input data for each source. */
protected var newData: Map[Source, DataFrame] = _
@@ -673,7 +672,7 @@ class StreamExecution(
if (lastExecution == null) {
"No physical plan. Waiting for data."
} else {
- val explain = ExplainCommand(lastExecution.logical, extended = extended)
+ val explain = StreamingExplainCommand(lastExecution, extended = extended)
sparkSession.sessionState.executePlan(explain).executedPlan.executeCollect()
.map(_.getString(0)).mkString("\n")
}
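
The user-visible entry point is unchanged: StreamingQuery.explain() still funnels into explainInternal above, but the output now comes from the streaming-aware IncrementalExecution of the last batch rather than a re-planned batch query. A usage sketch, assuming a streaming Dataset `df` (the query name and sink here are illustrative):

    val query = df.writeStream
      .queryName("demo")
      .outputMode("complete")
      .format("memory")
      .start()

    query.explain()                  // physical plan of the last executed batch
    query.explain(extended = true)   // parsed/analyzed/optimized + physical plans
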
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index f31dc8add4..0296a2ade3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -22,7 +22,9 @@ import scala.util.control.ControlThrowable
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
@@ -277,10 +279,24 @@ class StreamSuite extends StreamTest {
test("explain") {
val inputData = MemoryStream[String]
- val df = inputData.toDS().map(_ + "foo")
- // Test `explain` not throwing errors
- df.explain()
- val q = df.writeStream.queryName("memory_explain").format("memory").start()
+ val df = inputData.toDS().map(_ + "foo").groupBy("value").agg(count("*"))
+
+ // Test `df.explain`
+ val explain = ExplainCommand(df.queryExecution.logical, extended = false)
+ val explainString =
+ spark.sessionState
+ .executePlan(explain)
+ .executedPlan
+ .executeCollect()
+ .map(_.getString(0))
+ .mkString("\n")
+ assert(explainString.contains("StateStoreRestore"))
+ assert(explainString.contains("StreamingRelation"))
+ assert(!explainString.contains("LocalTableScan"))
+
+ // Test StreamingQuery.explainInternal
+ val q = df.writeStream.queryName("memory_explain").outputMode("complete").format("memory")
+ .start()
.asInstanceOf[StreamingQueryWrapper]
.streamingQuery
try {
@@ -294,12 +310,16 @@ class StreamSuite extends StreamTest {
// `extended = false` only displays the physical plan.
assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0)
assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1)
+ // Use "StateStoreRestore" to verify that it does output a streaming physical plan
+ assert(explainWithoutExtended.contains("StateStoreRestore"))
val explainWithExtended = q.explainInternal(true)
// `extended = true` displays 3 logical plans (Parsed/Analyzed/Optimized) and 1 physical
// plan.
assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3)
assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1)
+ // Use "StateStoreRestore" to verify that it does output a streaming physical plan
+ assert(explainWithExtended.contains("StateStoreRestore"))
} finally {
q.stop()
}
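
Why the test now aggregates: a plain map produces no stateful operators, so there would be nothing streaming-specific to assert on. The groupBy/count puts StateStoreRestore and StateStoreSave into the physical plan. A condensed sketch of the assertion setup, assuming the StreamTest harness (which supplies the implicit SQLContext for MemoryStream) and the imports added above:

    val inputData = MemoryStream[String]
    val df = inputData.toDS().map(_ + "foo").groupBy("value").agg(count("*"))

    // Explain the logical plan without starting the query; the streaming
    // aggregation must already surface state-store operators.
    val explain = ExplainCommand(df.queryExecution.logical, extended = false)
    val explainString = spark.sessionState.executePlan(explain)
      .executedPlan.executeCollect().map(_.getString(0)).mkString("\n")
    assert(explainString.contains("StateStoreRestore"))
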