aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-03-03 17:36:48 -0800
committerYin Huai <yhuai@databricks.com>2016-03-03 17:36:48 -0800
commitb373a888621ba6f0dd499f47093d4e2e42086dfc (patch)
treee68780effb3df46a612a076fe425920def81015b /sql
parentad0de99f3d3167990d501297f1df069fe15e0678 (diff)
downloadspark-b373a888621ba6f0dd499f47093d4e2e42086dfc.tar.gz
spark-b373a888621ba6f0dd499f47093d4e2e42086dfc.tar.bz2
spark-b373a888621ba6f0dd499f47093d4e2e42086dfc.zip
[SPARK-13415][SQL] Visualize subquery in SQL web UI
## What changes were proposed in this pull request? This PR support visualization for subquery in SQL web UI, also improve the explain of subquery, especially when it's used together with whole stage codegen. For example: ```python >>> sqlContext.range(100).registerTempTable("range") >>> sqlContext.sql("select id / (select sum(id) from range) from range where id > (select id from range limit 1)").explain(True) == Parsed Logical Plan == 'Project [unresolvedalias(('id / subquery#9), None)] : +- 'SubqueryAlias subquery#9 : +- 'Project [unresolvedalias('sum('id), None)] : +- 'UnresolvedRelation `range`, None +- 'Filter ('id > subquery#8) : +- 'SubqueryAlias subquery#8 : +- 'GlobalLimit 1 : +- 'LocalLimit 1 : +- 'Project [unresolvedalias('id, None)] : +- 'UnresolvedRelation `range`, None +- 'UnresolvedRelation `range`, None == Analyzed Logical Plan == (id / scalarsubquery()): double Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11] : +- SubqueryAlias subquery#9 : +- Aggregate [(sum(id#0L),mode=Complete,isDistinct=false) AS sum(id)#10L] : +- SubqueryAlias range : +- Range 0, 100, 1, 4, [id#0L] +- Filter (id#0L > subquery#8) : +- SubqueryAlias subquery#8 : +- GlobalLimit 1 : +- LocalLimit 1 : +- Project [id#0L] : +- SubqueryAlias range : +- Range 0, 100, 1, 4, [id#0L] +- SubqueryAlias range +- Range 0, 100, 1, 4, [id#0L] == Optimized Logical Plan == Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11] : +- SubqueryAlias subquery#9 : +- Aggregate [(sum(id#0L),mode=Complete,isDistinct=false) AS sum(id)#10L] : +- Range 0, 100, 1, 4, [id#0L] +- Filter (id#0L > subquery#8) : +- SubqueryAlias subquery#8 : +- GlobalLimit 1 : +- LocalLimit 1 : +- Project [id#0L] : +- Range 0, 100, 1, 4, [id#0L] +- Range 0, 100, 1, 4, [id#0L] == Physical Plan == WholeStageCodegen : +- Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11] : : +- Subquery subquery#9 : : +- WholeStageCodegen : : : +- TungstenAggregate(key=[], functions=[(sum(id#0L),mode=Final,isDistinct=false)], output=[sum(id)#10L]) : : : +- INPUT : : +- Exchange SinglePartition, None : : +- WholeStageCodegen : : : +- TungstenAggregate(key=[], functions=[(sum(id#0L),mode=Partial,isDistinct=false)], output=[sum#14L]) : : : +- Range 0, 1, 4, 100, [id#0L] : +- Filter (id#0L > subquery#8) : : +- Subquery subquery#8 : : +- CollectLimit 1 : : +- WholeStageCodegen : : : +- Project [id#0L] : : : +- Range 0, 1, 4, 100, [id#0L] : +- Range 0, 1, 4, 100, [id#0L] ``` The web UI looks like: ![subquery](https://cloud.githubusercontent.com/assets/40902/13377963/932bcbae-dda7-11e5-82f7-03c9be85d77c.png) This PR also change the tree structure of WholeStageCodegen to make it consistent than others. Before this change, Both WholeStageCodegen and InputAdapter hold a references to the same plans, those could be updated without notify another, causing problems, this is discovered by #11403 . ## How was this patch tested? Existing tests, also manual tests with the example query, check the explain and web UI. Author: Davies Liu <davies@databricks.com> Closes #11417 from davies/viz_subquery.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala49
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala7
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala113
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala23
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala66
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala17
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala2
9 files changed, 166 insertions, 127 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 3ff37fffbd..0e0453b517 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -229,8 +229,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
override def simpleString: String = statePrefix + super.simpleString
- override def treeChildren: Seq[PlanType] = {
- val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e})
- children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType])
+ /**
+ * All the subqueries of current plan.
+ */
+ def subqueries: Seq[PlanType] = {
+ expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
}
+
+ override def innerChildren: Seq[PlanType] = subqueries
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 2d0bf6b375..6b7997e903 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -447,10 +447,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
/**
* All the nodes that will be used to generate tree string.
+ *
+ * For example:
+ *
+ * WholeStageCodegen
+ * +-- SortMergeJoin
+ * |-- InputAdapter
+ * | +-- Sort
+ * +-- InputAdapter
+ * +-- Sort
+ *
+ * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), it will generate a tree string
+ * like this:
+ *
+ * WholeStageCodegen
+ * : +- SortMergeJoin
+ * : :- INPUT
+ * : :- INPUT
+ * :- Sort
+ * :- Sort
*/
protected def treeChildren: Seq[BaseType] = children
/**
+ * All the nodes that are parts of this node.
+ *
+ * For example:
+ *
+ * WholeStageCodegen
+ * +- SortMergeJoin
+ * |-- InputAdapter
+ * | +-- Sort
+ * +-- InputAdapter
+ * +-- Sort
+ *
+ * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), it will generate a tree
+ * string like this:
+ *
+ * WholeStageCodegen
+ * : +- SortMergeJoin
+ * : :- INPUT
+ * : :- INPUT
+ * :- Sort
+ * :- Sort
+ */
+ protected def innerChildren: Seq[BaseType] = Nil
+
+ /**
* Appends the string represent of this node and its children to the given StringBuilder.
*
* The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at
@@ -472,6 +515,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
builder.append(simpleString)
builder.append("\n")
+ if (innerChildren.nonEmpty) {
+ innerChildren.init.foreach(_.generateTreeString(
+ depth + 2, lastChildren :+ false :+ false, builder))
+ innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
+ }
+
if (treeChildren.nonEmpty) {
treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 4dd9928244..9019e5dfd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -36,11 +36,8 @@ class SparkPlanInfo(
private[sql] object SparkPlanInfo {
def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
- val children = plan match {
- case WholeStageCodegen(child, _) => child :: Nil
- case InputAdapter(child) => child :: Nil
- case plan => plan.children
- }
+
+ val children = plan.children ++ plan.subqueries
val metrics = plan.metrics.toSeq.map { case (key, metric) =>
new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
Utils.getFormattedClassName(metric.param))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index cb68ca6ada..6d231bf74a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
@@ -29,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
/**
@@ -163,16 +161,12 @@ trait CodegenSupport extends SparkPlan {
* This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes
* an RDD iterator of InternalRow.
*/
-case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
+case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
- override def doPrepare(): Unit = {
- child.prepare()
- }
-
override def doExecute(): RDD[InternalRow] = {
child.execute()
}
@@ -181,8 +175,6 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
child.doExecuteBroadcast()
}
- override def supportCodegen: Boolean = false
-
override def upstreams(): Seq[RDD[InternalRow]] = {
child.execute() :: Nil
}
@@ -210,6 +202,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
}
override def simpleString: String = "INPUT"
+
+ override def treeChildren: Seq[SparkPlan] = Nil
}
/**
@@ -243,22 +237,15 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
* doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
* used to generated code for BoundReference.
*/
-case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
- extends SparkPlan with CodegenSupport {
-
- override def supportCodegen: Boolean = false
-
- override def output: Seq[Attribute] = plan.output
- override def outputPartitioning: Partitioning = plan.outputPartitioning
- override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {
- override def doPrepare(): Unit = {
- plan.prepare()
- }
+ override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override def doExecute(): RDD[InternalRow] = {
val ctx = new CodegenContext
- val code = plan.produce(ctx, this)
+ val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
@@ -266,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
/** Codegened pipeline for:
- * ${toCommentSafeString(plan.treeString.trim)}
+ * ${toCommentSafeString(child.treeString.trim)}
*/
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
@@ -294,7 +281,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
// println(s"${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
- val rdds = plan.upstreams()
+ val rdds = child.asInstanceOf[CodegenSupport].upstreams()
assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
if (rdds.length == 1) {
rdds.head.mapPartitions { iter =>
@@ -361,34 +348,17 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
}
- private[sql] override def resetMetrics(): Unit = {
- plan.foreach(_.resetMetrics())
+ override def innerChildren: Seq[SparkPlan] = {
+ child :: Nil
}
- override def generateTreeString(
- depth: Int,
- lastChildren: Seq[Boolean],
- builder: StringBuilder): StringBuilder = {
- if (depth > 0) {
- lastChildren.init.foreach { isLast =>
- val prefixFragment = if (isLast) " " else ": "
- builder.append(prefixFragment)
- }
-
- val branch = if (lastChildren.last) "+- " else ":- "
- builder.append(branch)
- }
-
- builder.append(simpleString)
- builder.append("\n")
-
- plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
- if (children.nonEmpty) {
- children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
- children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
- }
+ private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match {
+ case InputAdapter(c) => c :: Nil
+ case other => other.children.flatMap(collectInputs)
+ }
- builder
+ override def treeChildren: Seq[SparkPlan] = {
+ collectInputs(child)
}
override def simpleString: String = "WholeStageCodegen"
@@ -416,27 +386,34 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
case _ => false
}
+ /**
+ * Inserts a InputAdapter on top of those that do not support codegen.
+ */
+ private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
+ case j @ SortMergeJoin(_, _, _, left, right) =>
+ // The children of SortMergeJoin should do codegen separately.
+ j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
+ right = InputAdapter(insertWholeStageCodegen(right)))
+ case p if !supportCodegen(p) =>
+ // collapse them recursively
+ InputAdapter(insertWholeStageCodegen(p))
+ case p =>
+ p.withNewChildren(p.children.map(insertInputAdapter))
+ }
+
+ /**
+ * Inserts a WholeStageCodegen on top of those that support codegen.
+ */
+ private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
+ case plan: CodegenSupport if supportCodegen(plan) =>
+ WholeStageCodegen(insertInputAdapter(plan))
+ case other =>
+ other.withNewChildren(other.children.map(insertWholeStageCodegen))
+ }
+
def apply(plan: SparkPlan): SparkPlan = {
if (sqlContext.conf.wholeStageEnabled) {
- plan.transform {
- case plan: CodegenSupport if supportCodegen(plan) =>
- var inputs = ArrayBuffer[SparkPlan]()
- val combined = plan.transform {
- // The build side can't be compiled together
- case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
- b.copy(left = apply(left))
- case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
- b.copy(right = apply(right))
- case j @ SortMergeJoin(_, _, _, left, right) =>
- // The children of SortMergeJoin should do codegen separately.
- j.copy(left = apply(left), right = apply(right))
- case p if !supportCodegen(p) =>
- val input = apply(p) // collapse them recursively
- inputs += input
- InputAdapter(input)
- }.asInstanceOf[CodegenSupport]
- WholeStageCodegen(combined, inputs)
- }
+ insertWholeStageCodegen(plan)
} else {
plan
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 95d033bc57..fed88b8c0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.internal.SQLConf
@@ -68,7 +69,7 @@ package object debug {
}
}
- private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
+ private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
def output: Seq[Attribute] = child.output
implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
@@ -86,10 +87,11 @@ package object debug {
/**
* A collection of metrics for each column of output.
* @param elementTypes the actual runtime types for the output. Useful when there are bugs
- * causing the wrong data to be projected.
+ * causing the wrong data to be projected.
*/
case class ColumnMetrics(
- elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+ elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+
val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)
val numColumns: Int = child.output.size
@@ -98,7 +100,7 @@ package object debug {
def dumpStats(): Unit = {
logDebug(s"== ${child.simpleString} ==")
logDebug(s"Tuples output: ${tupleCount.value}")
- child.output.zip(columnStats).foreach { case(attr, metric) =>
+ child.output.zip(columnStats).foreach { case (attr, metric) =>
val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
}
@@ -108,6 +110,7 @@ package object debug {
child.execute().mapPartitions { iter =>
new Iterator[InternalRow] {
def hasNext: Boolean = iter.hasNext
+
def next(): InternalRow = {
val currentRow = iter.next()
tupleCount += 1
@@ -124,5 +127,17 @@ package object debug {
}
}
}
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ consume(ctx, input)
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 4eb248569b..12e586ada5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable
-import org.apache.spark.sql.execution.{InputAdapter, SparkPlanInfo, WholeStageCodegen}
+import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
@@ -73,36 +73,40 @@ private[sql] object SparkPlanGraph {
edges: mutable.ArrayBuffer[SparkPlanGraphEdge],
parent: SparkPlanGraphNode,
subgraph: SparkPlanGraphCluster): Unit = {
- if (planInfo.nodeName == classOf[WholeStageCodegen].getSimpleName) {
- val cluster = new SparkPlanGraphCluster(
- nodeIdGenerator.getAndIncrement(),
- planInfo.nodeName,
- planInfo.simpleString,
- mutable.ArrayBuffer[SparkPlanGraphNode]())
- nodes += cluster
- buildSparkPlanGraphNode(
- planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
- } else if (planInfo.nodeName == classOf[InputAdapter].getSimpleName) {
- buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
- } else {
- val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId,
- SQLMetrics.getMetricParam(metric.metricParam))
- }
- val node = new SparkPlanGraphNode(
- nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
- planInfo.simpleString, planInfo.metadata, metrics)
- if (subgraph == null) {
- nodes += node
- } else {
- subgraph.nodes += node
- }
-
- if (parent != null) {
- edges += SparkPlanGraphEdge(node.id, parent.id)
- }
- planInfo.children.foreach(
- buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
+ planInfo.nodeName match {
+ case "WholeStageCodegen" =>
+ val cluster = new SparkPlanGraphCluster(
+ nodeIdGenerator.getAndIncrement(),
+ planInfo.nodeName,
+ planInfo.simpleString,
+ mutable.ArrayBuffer[SparkPlanGraphNode]())
+ nodes += cluster
+ buildSparkPlanGraphNode(
+ planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
+ case "InputAdapter" =>
+ buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
+ case "Subquery" if subgraph != null =>
+ // Subquery should not be included in WholeStageCodegen
+ buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null)
+ case _ =>
+ val metrics = planInfo.metrics.map { metric =>
+ SQLPlanMetric(metric.name, metric.accumulatorId,
+ SQLMetrics.getMetricParam(metric.metricParam))
+ }
+ val node = new SparkPlanGraphNode(
+ nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
+ planInfo.simpleString, planInfo.metadata, metrics)
+ if (subgraph == null) {
+ nodes += node
+ } else {
+ subgraph.nodes += node
+ }
+
+ if (parent != null) {
+ edges += SparkPlanGraphEdge(node.id, parent.id)
+ }
+ planInfo.children.foreach(
+ buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index de371d85d9..e00c762c67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -31,14 +31,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1")
val plan = df.queryExecution.executedPlan
assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined)
-
- checkThatPlansAgree(
- sqlContext.range(100),
- (p: SparkPlan) =>
- WholeStageCodegen(Filter('a == 1, InputAdapter(p)), Seq()),
- (p: SparkPlan) => Filter('a == 1, p),
- sortAnswers = false
- )
+ assert(df.collect() === Array(Row(2)))
}
test("Aggregate should be included in WholeStageCodegen") {
@@ -46,7 +39,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
- p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
assert(df.collect() === Array(Row(9, 4.5)))
}
@@ -55,7 +48,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
- p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
}
@@ -66,7 +59,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
assert(df.queryExecution.executedPlan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
- p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[BroadcastHashJoin]).isDefined)
assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
}
@@ -75,7 +68,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
- p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 4358c7c76d..b0d64aa7bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -210,8 +210,8 @@ class JDBCSuite extends SparkFunSuite
// the plan only has PhysicalRDD to scan JDBCRelation.
assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
- assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD])
- assert(node.plan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
+ assert(node.child.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD])
+ assert(node.child.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
df
}
assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0)
@@ -248,7 +248,7 @@ class JDBCSuite extends SparkFunSuite
// cannot compile given predicates.
assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
- assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.Filter])
+ assert(node.child.isInstanceOf[org.apache.spark.sql.execution.Filter])
df
}
assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 15a95623d1..e7d2b5ad96 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -93,7 +93,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
val metric = qe.executedPlan match {
- case w: WholeStageCodegen => w.plan.longMetric("numOutputRows")
+ case w: WholeStageCodegen => w.child.longMetric("numOutputRows")
case other => other.longMetric("numOutputRows")
}
metrics += metric.value.value