From b373a888621ba6f0dd499f47093d4e2e42086dfc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 3 Mar 2016 17:36:48 -0800 Subject: [SPARK-13415][SQL] Visualize subquery in SQL web UI ## What changes were proposed in this pull request? This PR support visualization for subquery in SQL web UI, also improve the explain of subquery, especially when it's used together with whole stage codegen. For example: ```python >>> sqlContext.range(100).registerTempTable("range") >>> sqlContext.sql("select id / (select sum(id) from range) from range where id > (select id from range limit 1)").explain(True) == Parsed Logical Plan == 'Project [unresolvedalias(('id / subquery#9), None)] : +- 'SubqueryAlias subquery#9 : +- 'Project [unresolvedalias('sum('id), None)] : +- 'UnresolvedRelation `range`, None +- 'Filter ('id > subquery#8) : +- 'SubqueryAlias subquery#8 : +- 'GlobalLimit 1 : +- 'LocalLimit 1 : +- 'Project [unresolvedalias('id, None)] : +- 'UnresolvedRelation `range`, None +- 'UnresolvedRelation `range`, None == Analyzed Logical Plan == (id / scalarsubquery()): double Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11] : +- SubqueryAlias subquery#9 : +- Aggregate [(sum(id#0L),mode=Complete,isDistinct=false) AS sum(id)#10L] : +- SubqueryAlias range : +- Range 0, 100, 1, 4, [id#0L] +- Filter (id#0L > subquery#8) : +- SubqueryAlias subquery#8 : +- GlobalLimit 1 : +- LocalLimit 1 : +- Project [id#0L] : +- SubqueryAlias range : +- Range 0, 100, 1, 4, [id#0L] +- SubqueryAlias range +- Range 0, 100, 1, 4, [id#0L] == Optimized Logical Plan == Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11] : +- SubqueryAlias subquery#9 : +- Aggregate [(sum(id#0L),mode=Complete,isDistinct=false) AS sum(id)#10L] : +- Range 0, 100, 1, 4, [id#0L] +- Filter (id#0L > subquery#8) : +- SubqueryAlias subquery#8 : +- GlobalLimit 1 : +- LocalLimit 1 : +- Project [id#0L] : +- Range 0, 100, 1, 4, [id#0L] +- Range 0, 100, 1, 4, [id#0L] == Physical Plan == WholeStageCodegen : +- Project [(cast(id#0L as double) / cast(subquery#9 as double)) AS (id / scalarsubquery())#11] : : +- Subquery subquery#9 : : +- WholeStageCodegen : : : +- TungstenAggregate(key=[], functions=[(sum(id#0L),mode=Final,isDistinct=false)], output=[sum(id)#10L]) : : : +- INPUT : : +- Exchange SinglePartition, None : : +- WholeStageCodegen : : : +- TungstenAggregate(key=[], functions=[(sum(id#0L),mode=Partial,isDistinct=false)], output=[sum#14L]) : : : +- Range 0, 1, 4, 100, [id#0L] : +- Filter (id#0L > subquery#8) : : +- Subquery subquery#8 : : +- CollectLimit 1 : : +- WholeStageCodegen : : : +- Project [id#0L] : : : +- Range 0, 1, 4, 100, [id#0L] : +- Range 0, 1, 4, 100, [id#0L] ``` The web UI looks like: ![subquery](https://cloud.githubusercontent.com/assets/40902/13377963/932bcbae-dda7-11e5-82f7-03c9be85d77c.png) This PR also change the tree structure of WholeStageCodegen to make it consistent than others. Before this change, Both WholeStageCodegen and InputAdapter hold a references to the same plans, those could be updated without notify another, causing problems, this is discovered by #11403 . ## How was this patch tested? Existing tests, also manual tests with the example query, check the explain and web UI. Author: Davies Liu Closes #11417 from davies/viz_subquery. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 10 +- .../apache/spark/sql/catalyst/trees/TreeNode.scala | 49 +++++++++ .../apache/spark/sql/execution/SparkPlanInfo.scala | 7 +- .../spark/sql/execution/WholeStageCodegen.scala | 113 ++++++++------------- .../apache/spark/sql/execution/debug/package.scala | 23 ++++- .../spark/sql/execution/ui/SparkPlanGraph.scala | 66 ++++++------ .../sql/execution/WholeStageCodegenSuite.scala | 17 +--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 6 +- .../spark/sql/util/DataFrameCallbackSuite.scala | 2 +- 9 files changed, 166 insertions(+), 127 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 3ff37fffbd..0e0453b517 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -229,8 +229,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy override def simpleString: String = statePrefix + super.simpleString - override def treeChildren: Seq[PlanType] = { - val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e}) - children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType]) + /** + * All the subqueries of current plan. + */ + def subqueries: Seq[PlanType] = { + expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]}) } + + override def innerChildren: Seq[PlanType] = subqueries } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 2d0bf6b375..6b7997e903 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -447,9 +447,52 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * All the nodes that will be used to generate tree string. + * + * For example: + * + * WholeStageCodegen + * +-- SortMergeJoin + * |-- InputAdapter + * | +-- Sort + * +-- InputAdapter + * +-- Sort + * + * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), it will generate a tree string + * like this: + * + * WholeStageCodegen + * : +- SortMergeJoin + * : :- INPUT + * : :- INPUT + * :- Sort + * :- Sort */ protected def treeChildren: Seq[BaseType] = children + /** + * All the nodes that are parts of this node. + * + * For example: + * + * WholeStageCodegen + * +- SortMergeJoin + * |-- InputAdapter + * | +-- Sort + * +-- InputAdapter + * +-- Sort + * + * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), it will generate a tree + * string like this: + * + * WholeStageCodegen + * : +- SortMergeJoin + * : :- INPUT + * : :- INPUT + * :- Sort + * :- Sort + */ + protected def innerChildren: Seq[BaseType] = Nil + /** * Appends the string represent of this node and its children to the given StringBuilder. * @@ -472,6 +515,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { builder.append(simpleString) builder.append("\n") + if (innerChildren.nonEmpty) { + innerChildren.init.foreach(_.generateTreeString( + depth + 2, lastChildren :+ false :+ false, builder)) + innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) + } + if (treeChildren.nonEmpty) { treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 4dd9928244..9019e5dfd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -36,11 +36,8 @@ class SparkPlanInfo( private[sql] object SparkPlanInfo { def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = { - val children = plan match { - case WholeStageCodegen(child, _) => child :: Nil - case InputAdapter(child) => child :: Nil - case plan => plan.children - } + + val children = plan.children ++ plan.subqueries val metrics = plan.metrics.toSeq.map { case (key, metric) => new SQLMetricInfo(metric.name.getOrElse(key), metric.id, Utils.getFormattedClassName(metric.param)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index cb68ca6ada..6d231bf74a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext @@ -29,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.aggregate.TungstenAggregate -import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.execution.metric.LongSQLMetricValue /** @@ -163,16 +161,12 @@ trait CodegenSupport extends SparkPlan { * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes * an RDD iterator of InternalRow. */ -case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { +case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def doPrepare(): Unit = { - child.prepare() - } - override def doExecute(): RDD[InternalRow] = { child.execute() } @@ -181,8 +175,6 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { child.doExecuteBroadcast() } - override def supportCodegen: Boolean = false - override def upstreams(): Seq[RDD[InternalRow]] = { child.execute() :: Nil } @@ -210,6 +202,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { } override def simpleString: String = "INPUT" + + override def treeChildren: Seq[SparkPlan] = Nil } /** @@ -243,22 +237,15 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport { * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input, * used to generated code for BoundReference. */ -case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) - extends SparkPlan with CodegenSupport { - - override def supportCodegen: Boolean = false - - override def output: Seq[Attribute] = plan.output - override def outputPartitioning: Partitioning = plan.outputPartitioning - override def outputOrdering: Seq[SortOrder] = plan.outputOrdering +case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport { - override def doPrepare(): Unit = { - plan.prepare() - } + override def output: Seq[Attribute] = child.output + override def outputPartitioning: Partitioning = child.outputPartitioning + override def outputOrdering: Seq[SortOrder] = child.outputOrdering override def doExecute(): RDD[InternalRow] = { val ctx = new CodegenContext - val code = plan.produce(ctx, this) + val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { @@ -266,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } /** Codegened pipeline for: - * ${toCommentSafeString(plan.treeString.trim)} + * ${toCommentSafeString(child.treeString.trim)} */ class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { @@ -294,7 +281,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) // println(s"${CodeFormatter.format(cleanedSource)}") CodeGenerator.compile(cleanedSource) - val rdds = plan.upstreams() + val rdds = child.asInstanceOf[CodegenSupport].upstreams() assert(rdds.size <= 2, "Up to two upstream RDDs can be supported") if (rdds.length == 1) { rdds.head.mapPartitions { iter => @@ -361,34 +348,17 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) } } - private[sql] override def resetMetrics(): Unit = { - plan.foreach(_.resetMetrics()) + override def innerChildren: Seq[SparkPlan] = { + child :: Nil } - override def generateTreeString( - depth: Int, - lastChildren: Seq[Boolean], - builder: StringBuilder): StringBuilder = { - if (depth > 0) { - lastChildren.init.foreach { isLast => - val prefixFragment = if (isLast) " " else ": " - builder.append(prefixFragment) - } - - val branch = if (lastChildren.last) "+- " else ":- " - builder.append(branch) - } - - builder.append(simpleString) - builder.append("\n") - - plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder) - if (children.nonEmpty) { - children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder)) - children.last.generateTreeString(depth + 1, lastChildren :+ true, builder) - } + private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match { + case InputAdapter(c) => c :: Nil + case other => other.children.flatMap(collectInputs) + } - builder + override def treeChildren: Seq[SparkPlan] = { + collectInputs(child) } override def simpleString: String = "WholeStageCodegen" @@ -416,27 +386,34 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru case _ => false } + /** + * Inserts a InputAdapter on top of those that do not support codegen. + */ + private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match { + case j @ SortMergeJoin(_, _, _, left, right) => + // The children of SortMergeJoin should do codegen separately. + j.copy(left = InputAdapter(insertWholeStageCodegen(left)), + right = InputAdapter(insertWholeStageCodegen(right))) + case p if !supportCodegen(p) => + // collapse them recursively + InputAdapter(insertWholeStageCodegen(p)) + case p => + p.withNewChildren(p.children.map(insertInputAdapter)) + } + + /** + * Inserts a WholeStageCodegen on top of those that support codegen. + */ + private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match { + case plan: CodegenSupport if supportCodegen(plan) => + WholeStageCodegen(insertInputAdapter(plan)) + case other => + other.withNewChildren(other.children.map(insertWholeStageCodegen)) + } + def apply(plan: SparkPlan): SparkPlan = { if (sqlContext.conf.wholeStageEnabled) { - plan.transform { - case plan: CodegenSupport if supportCodegen(plan) => - var inputs = ArrayBuffer[SparkPlan]() - val combined = plan.transform { - // The build side can't be compiled together - case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) => - b.copy(left = apply(left)) - case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) => - b.copy(right = apply(right)) - case j @ SortMergeJoin(_, _, _, left, right) => - // The children of SortMergeJoin should do codegen separately. - j.copy(left = apply(left), right = apply(right)) - case p if !supportCodegen(p) => - val input = apply(p) // collapse them recursively - inputs += input - InputAdapter(input) - }.asInstanceOf[CodegenSupport] - WholeStageCodegen(combined, inputs) - } + insertWholeStageCodegen(plan) } else { plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 95d033bc57..fed88b8c0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.internal.SQLConf @@ -68,7 +69,7 @@ package object debug { } } - private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { + private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport { def output: Seq[Attribute] = child.output implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { @@ -86,10 +87,11 @@ package object debug { /** * A collection of metrics for each column of output. * @param elementTypes the actual runtime types for the output. Useful when there are bugs - * causing the wrong data to be projected. + * causing the wrong data to be projected. */ case class ColumnMetrics( - elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) + elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) + val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0) val numColumns: Int = child.output.size @@ -98,7 +100,7 @@ package object debug { def dumpStats(): Unit = { logDebug(s"== ${child.simpleString} ==") logDebug(s"Tuples output: ${tupleCount.value}") - child.output.zip(columnStats).foreach { case(attr, metric) => + child.output.zip(columnStats).foreach { case (attr, metric) => val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } @@ -108,6 +110,7 @@ package object debug { child.execute().mapPartitions { iter => new Iterator[InternalRow] { def hasNext: Boolean = iter.hasNext + def next(): InternalRow = { val currentRow = iter.next() tupleCount += 1 @@ -124,5 +127,17 @@ package object debug { } } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + consume(ctx, input) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 4eb248569b..12e586ada5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import org.apache.spark.sql.execution.{InputAdapter, SparkPlanInfo, WholeStageCodegen} +import org.apache.spark.sql.execution.SparkPlanInfo import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -73,36 +73,40 @@ private[sql] object SparkPlanGraph { edges: mutable.ArrayBuffer[SparkPlanGraphEdge], parent: SparkPlanGraphNode, subgraph: SparkPlanGraphCluster): Unit = { - if (planInfo.nodeName == classOf[WholeStageCodegen].getSimpleName) { - val cluster = new SparkPlanGraphCluster( - nodeIdGenerator.getAndIncrement(), - planInfo.nodeName, - planInfo.simpleString, - mutable.ArrayBuffer[SparkPlanGraphNode]()) - nodes += cluster - buildSparkPlanGraphNode( - planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster) - } else if (planInfo.nodeName == classOf[InputAdapter].getSimpleName) { - buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null) - } else { - val metrics = planInfo.metrics.map { metric => - SQLPlanMetric(metric.name, metric.accumulatorId, - SQLMetrics.getMetricParam(metric.metricParam)) - } - val node = new SparkPlanGraphNode( - nodeIdGenerator.getAndIncrement(), planInfo.nodeName, - planInfo.simpleString, planInfo.metadata, metrics) - if (subgraph == null) { - nodes += node - } else { - subgraph.nodes += node - } - - if (parent != null) { - edges += SparkPlanGraphEdge(node.id, parent.id) - } - planInfo.children.foreach( - buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph)) + planInfo.nodeName match { + case "WholeStageCodegen" => + val cluster = new SparkPlanGraphCluster( + nodeIdGenerator.getAndIncrement(), + planInfo.nodeName, + planInfo.simpleString, + mutable.ArrayBuffer[SparkPlanGraphNode]()) + nodes += cluster + buildSparkPlanGraphNode( + planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster) + case "InputAdapter" => + buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null) + case "Subquery" if subgraph != null => + // Subquery should not be included in WholeStageCodegen + buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null) + case _ => + val metrics = planInfo.metrics.map { metric => + SQLPlanMetric(metric.name, metric.accumulatorId, + SQLMetrics.getMetricParam(metric.metricParam)) + } + val node = new SparkPlanGraphNode( + nodeIdGenerator.getAndIncrement(), planInfo.nodeName, + planInfo.simpleString, planInfo.metadata, metrics) + if (subgraph == null) { + nodes += node + } else { + subgraph.nodes += node + } + + if (parent != null) { + edges += SparkPlanGraphEdge(node.id, parent.id) + } + planInfo.children.foreach( + buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index de371d85d9..e00c762c67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -31,14 +31,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1") val plan = df.queryExecution.executedPlan assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined) - - checkThatPlansAgree( - sqlContext.range(100), - (p: SparkPlan) => - WholeStageCodegen(Filter('a == 1, InputAdapter(p)), Seq()), - (p: SparkPlan) => Filter('a == 1, p), - sortAnswers = false - ) + assert(df.collect() === Array(Row(2))) } test("Aggregate should be included in WholeStageCodegen") { @@ -46,7 +39,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegen] && - p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(9, 4.5))) } @@ -55,7 +48,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegen] && - p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) } @@ -66,7 +59,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id")) assert(df.queryExecution.executedPlan.find(p => p.isInstanceOf[WholeStageCodegen] && - p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined) + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[BroadcastHashJoin]).isDefined) assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2"))) } @@ -75,7 +68,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = df.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegen] && - p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined) + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined) assert(df.collect() === Array(Row(1), Row(2), Row(3))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 4358c7c76d..b0d64aa7bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -210,8 +210,8 @@ class JDBCSuite extends SparkFunSuite // the plan only has PhysicalRDD to scan JDBCRelation. assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]) val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen] - assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD]) - assert(node.plan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation")) + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD]) + assert(node.child.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation")) df } assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0) @@ -248,7 +248,7 @@ class JDBCSuite extends SparkFunSuite // cannot compile given predicates. assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]) val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen] - assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.Filter]) + assert(node.child.isInstanceOf[org.apache.spark.sql.execution.Filter]) df } assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 15a95623d1..e7d2b5ad96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -93,7 +93,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { val metric = qe.executedPlan match { - case w: WholeStageCodegen => w.plan.longMetric("numOutputRows") + case w: WholeStageCodegen => w.child.longMetric("numOutputRows") case other => other.longMetric("numOutputRows") } metrics += metric.value.value -- cgit v1.2.3