Diffstat (limited to 'sql')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala     |  10
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala      |  49
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala          |   7
 sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala      | 113
 sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala          |  23
 sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala      |  66
 sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala |  17
 sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala                   |   6
 sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala      |   2
 9 files changed, 166 insertions(+), 127 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 3ff37fffbd..0e0453b517 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -229,8 +229,12 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
override def simpleString: String = statePrefix + super.simpleString
- override def treeChildren: Seq[PlanType] = {
- val subqueries = expressions.flatMap(_.collect {case e: SubqueryExpression => e})
- children ++ subqueries.map(e => e.plan.asInstanceOf[PlanType])
+ /**
+ * All the subqueries of the current plan.
+ */
+ def subqueries: Seq[PlanType] = {
+ expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
}
+
+ override def innerChildren: Seq[PlanType] = subqueries
}
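
A rough usage sketch of the new hook (illustrative only, not part of this patch; it assumes a registered table `t` and a build where uncorrelated scalar subqueries are supported):

  // Illustrative sketch: collect every subquery plan reachable from the executed plan.
  val plan = sqlContext.sql(
    "SELECT * FROM t WHERE id = (SELECT max(id) FROM t)").queryExecution.executedPlan
  // subqueries returns the plan of each SubqueryExpression held in a node's expressions;
  // since innerChildren now delegates to it, those plans also show up in plan.treeString.
  val subqueryPlans = plan.collect { case p => p.subqueries }.flatten
  subqueryPlans.foreach(sub => println(sub.treeString))
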
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 2d0bf6b375..6b7997e903 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -447,10 +447,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
/**
 * All the nodes that will be used to generate the tree string.
+ *
+ * For example:
+ *
+ *   WholeStageCodegen
+ *   +-- SortMergeJoin
+ *       |-- InputAdapter
+ *       |   +-- Sort
+ *       +-- InputAdapter
+ *           +-- Sort
+ *
+ * the treeChildren of WholeStageCodegen will be Seq(Sort, Sort), so the generated tree
+ * string looks like this:
+ *
+ *   WholeStageCodegen
+ *   :  +- SortMergeJoin
+ *   :     :- INPUT
+ *   :     +- INPUT
+ *   :- Sort
+ *   +- Sort
*/
protected def treeChildren: Seq[BaseType] = children
/**
+ * All the nodes that are part of this node.
+ *
+ * For example:
+ *
+ *   WholeStageCodegen
+ *   +-- SortMergeJoin
+ *       |-- InputAdapter
+ *       |   +-- Sort
+ *       +-- InputAdapter
+ *           +-- Sort
+ *
+ * the innerChildren of WholeStageCodegen will be Seq(SortMergeJoin), so the generated tree
+ * string looks like this:
+ *
+ *   WholeStageCodegen
+ *   :  +- SortMergeJoin
+ *   :     :- INPUT
+ *   :     +- INPUT
+ *   :- Sort
+ *   +- Sort
+ */
+ protected def innerChildren: Seq[BaseType] = Nil
+
+ /**
* Appends the string representation of this node and its children to the given StringBuilder.
*
* The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at
@@ -472,6 +515,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
builder.append(simpleString)
builder.append("\n")
+ if (innerChildren.nonEmpty) {
+ innerChildren.init.foreach(_.generateTreeString(
+ depth + 2, lastChildren :+ false :+ false, builder))
+ innerChildren.last.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
+ }
+
if (treeChildren.nonEmpty) {
treeChildren.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
treeChildren.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
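
The interplay between the two hooks is easiest to see in isolation. The following self-contained sketch (toy code, not the Spark implementation; all names are made up) mirrors how generateTreeString indents innerChildren at depth + 2 and treeChildren at depth + 1:

  // Toy model of generateTreeString (illustrative only).
  case class Node(
      name: String,
      tree: Seq[Node] = Nil,     // plays the role of treeChildren
      inner: Seq[Node] = Nil) {  // plays the role of innerChildren

    def generate(depth: Int, lastChildren: Seq[Boolean], sb: StringBuilder): StringBuilder = {
      if (depth > 0) {
        // one prefix column per ancestor, then the branch for this node
        lastChildren.init.foreach(isLast => sb.append(if (isLast) "   " else ":  "))
        sb.append(if (lastChildren.last) "+- " else ":- ")
      }
      sb.append(name).append("\n")
      if (inner.nonEmpty) {       // inner children are pushed two levels deep
        inner.init.foreach(_.generate(depth + 2, lastChildren :+ false :+ false, sb))
        inner.last.generate(depth + 2, lastChildren :+ false :+ true, sb)
      }
      if (tree.nonEmpty) {        // tree children stay one level deep
        tree.init.foreach(_.generate(depth + 1, lastChildren :+ false, sb))
        tree.last.generate(depth + 1, lastChildren :+ true, sb)
      }
      sb
    }
  }

  val join = Node("SortMergeJoin", tree = Seq(Node("INPUT"), Node("INPUT")))
  val stage = Node("WholeStageCodegen", tree = Seq(Node("Sort"), Node("Sort")), inner = Seq(join))
  print(stage.generate(0, Nil, new StringBuilder))
  // WholeStageCodegen
  // :  +- SortMergeJoin
  // :     :- INPUT
  // :     +- INPUT
  // :- Sort
  // +- Sort
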
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 4dd9928244..9019e5dfd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -36,11 +36,8 @@ class SparkPlanInfo(
private[sql] object SparkPlanInfo {
def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
- val children = plan match {
- case WholeStageCodegen(child, _) => child :: Nil
- case InputAdapter(child) => child :: Nil
- case plan => plan.children
- }
+
+ val children = plan.children ++ plan.subqueries
val metrics = plan.metrics.toSeq.map { case (key, metric) =>
new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
Utils.getFormattedClassName(metric.param))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index cb68ca6ada..6d231bf74a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.execution
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
@@ -29,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
/**
@@ -163,16 +161,12 @@ trait CodegenSupport extends SparkPlan {
* This is the leaf node of a tree with WholeStageCodegen, and is used to generate code that consumes
* an RDD iterator of InternalRow.
*/
-case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
+case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
- override def doPrepare(): Unit = {
- child.prepare()
- }
-
override def doExecute(): RDD[InternalRow] = {
child.execute()
}
@@ -181,8 +175,6 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
child.doExecuteBroadcast()
}
- override def supportCodegen: Boolean = false
-
override def upstreams(): Seq[RDD[InternalRow]] = {
child.execute() :: Nil
}
@@ -210,6 +202,8 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
}
override def simpleString: String = "INPUT"
+
+ override def treeChildren: Seq[SparkPlan] = Nil
}
/**
@@ -243,22 +237,15 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
* doCodeGen() will create a CodeGenContext, which will hold a list of variables for input,
* used to generated code for BoundReference.
*/
-case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
- extends SparkPlan with CodegenSupport {
-
- override def supportCodegen: Boolean = false
-
- override def output: Seq[Attribute] = plan.output
- override def outputPartitioning: Partitioning = plan.outputPartitioning
- override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
+case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport {
- override def doPrepare(): Unit = {
- plan.prepare()
- }
+ override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
override def doExecute(): RDD[InternalRow] = {
val ctx = new CodegenContext
- val code = plan.produce(ctx, this)
+ val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
val references = ctx.references.toArray
val source = s"""
public Object generate(Object[] references) {
@@ -266,7 +253,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
/** Codegened pipeline for:
- * ${toCommentSafeString(plan.treeString.trim)}
+ * ${toCommentSafeString(child.treeString.trim)}
*/
class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
@@ -294,7 +281,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
// println(s"${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
- val rdds = plan.upstreams()
+ val rdds = child.asInstanceOf[CodegenSupport].upstreams()
assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
if (rdds.length == 1) {
rdds.head.mapPartitions { iter =>
@@ -361,34 +348,17 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
}
}
- private[sql] override def resetMetrics(): Unit = {
- plan.foreach(_.resetMetrics())
+ override def innerChildren: Seq[SparkPlan] = {
+ child :: Nil
}
- override def generateTreeString(
- depth: Int,
- lastChildren: Seq[Boolean],
- builder: StringBuilder): StringBuilder = {
- if (depth > 0) {
- lastChildren.init.foreach { isLast =>
- val prefixFragment = if (isLast) " " else ": "
- builder.append(prefixFragment)
- }
-
- val branch = if (lastChildren.last) "+- " else ":- "
- builder.append(branch)
- }
-
- builder.append(simpleString)
- builder.append("\n")
-
- plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
- if (children.nonEmpty) {
- children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ false, builder))
- children.last.generateTreeString(depth + 1, lastChildren :+ true, builder)
- }
+ private def collectInputs(plan: SparkPlan): Seq[SparkPlan] = plan match {
+ case InputAdapter(c) => c :: Nil
+ case other => other.children.flatMap(collectInputs)
+ }
- builder
+ override def treeChildren: Seq[SparkPlan] = {
+ collectInputs(child)
}
override def simpleString: String = "WholeStageCodegen"
@@ -416,27 +386,34 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
case _ => false
}
+ /**
+ * Inserts an InputAdapter on top of those that do not support codegen.
+ */
+ private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
+ case j @ SortMergeJoin(_, _, _, left, right) =>
+ // The children of SortMergeJoin should do codegen separately.
+ j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
+ right = InputAdapter(insertWholeStageCodegen(right)))
+ case p if !supportCodegen(p) =>
+ // collapse them recursively
+ InputAdapter(insertWholeStageCodegen(p))
+ case p =>
+ p.withNewChildren(p.children.map(insertInputAdapter))
+ }
+
+ /**
+ * Inserts a WholeStageCodegen on top of those that support codegen.
+ */
+ private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
+ case plan: CodegenSupport if supportCodegen(plan) =>
+ WholeStageCodegen(insertInputAdapter(plan))
+ case other =>
+ other.withNewChildren(other.children.map(insertWholeStageCodegen))
+ }
+
def apply(plan: SparkPlan): SparkPlan = {
if (sqlContext.conf.wholeStageEnabled) {
- plan.transform {
- case plan: CodegenSupport if supportCodegen(plan) =>
- var inputs = ArrayBuffer[SparkPlan]()
- val combined = plan.transform {
- // The build side can't be compiled together
- case b @ BroadcastHashJoin(_, _, _, BuildLeft, _, left, right) =>
- b.copy(left = apply(left))
- case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
- b.copy(right = apply(right))
- case j @ SortMergeJoin(_, _, _, left, right) =>
- // The children of SortMergeJoin should do codegen separately.
- j.copy(left = apply(left), right = apply(right))
- case p if !supportCodegen(p) =>
- val input = apply(p) // collapse them recursively
- inputs += input
- InputAdapter(input)
- }.asInstanceOf[CodegenSupport]
- WholeStageCodegen(combined, inputs)
- }
+ insertWholeStageCodegen(plan)
} else {
plan
}
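
The plan shape this two-pass rewrite produces can be sketched with a toy model (illustrative only; Op, Stage and Input are made-up stand-ins for SparkPlan, WholeStageCodegen and InputAdapter):

  // Toy model of the two-pass insertion (not Spark code).
  sealed trait Op { def codegen: Boolean }
  case class Leaf(name: String, codegen: Boolean) extends Op
  case class Unary(name: String, codegen: Boolean, child: Op) extends Op
  case class Input(child: Op) extends Op { def codegen = true }   // stands in for InputAdapter
  case class Stage(child: Op) extends Op { def codegen = false }  // stands in for WholeStageCodegen

  object CollapseToy {
    // Mirrors insertInputAdapter: wrap every codegen break, recurse through supported nodes.
    def insertInput(p: Op): Op = p match {
      case p if !p.codegen => Input(insertStage(p))
      case Unary(n, cg, c) => Unary(n, cg, insertInput(c))
      case other           => other
    }
    // Mirrors insertWholeStageCodegen: wrap each maximal codegen-capable subtree in a stage.
    def insertStage(p: Op): Op = p match {
      case p if p.codegen  => Stage(insertInput(p))
      case Unary(n, cg, c) => Unary(n, cg, insertStage(c))
      case other           => other
    }
  }

  // Project (codegen) over Exchange (no codegen) over Scan (codegen) becomes
  // Stage(Project(Input(Exchange(Stage(Scan))))): codegen restarts below the break.
  val toyPlan = Unary("Project", codegen = true,
    Unary("Exchange", codegen = false, Leaf("Scan", codegen = true)))
  println(CollapseToy.insertStage(toyPlan))
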
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 95d033bc57..fed88b8c0a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.internal.SQLConf
@@ -68,7 +69,7 @@ package object debug {
}
}
- private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
+ private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
def output: Seq[Attribute] = child.output
implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
@@ -86,10 +87,11 @@ package object debug {
/**
* A collection of metrics for each column of output.
* @param elementTypes the actual runtime types for the output. Useful when there are bugs
- * causing the wrong data to be projected.
+ * causing the wrong data to be projected.
*/
case class ColumnMetrics(
- elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+ elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+
val tupleCount: Accumulator[Int] = sparkContext.accumulator[Int](0)
val numColumns: Int = child.output.size
@@ -98,7 +100,7 @@ package object debug {
def dumpStats(): Unit = {
logDebug(s"== ${child.simpleString} ==")
logDebug(s"Tuples output: ${tupleCount.value}")
- child.output.zip(columnStats).foreach { case(attr, metric) =>
+ child.output.zip(columnStats).foreach { case (attr, metric) =>
val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}")
logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
}
@@ -108,6 +110,7 @@ package object debug {
child.execute().mapPartitions { iter =>
new Iterator[InternalRow] {
def hasNext: Boolean = iter.hasNext
+
def next(): InternalRow = {
val currentRow = iter.next()
tupleCount += 1
@@ -124,5 +127,17 @@ package object debug {
}
}
}
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ consume(ctx, input)
+ }
}
}
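
With DebugNode now implementing CodegenSupport as a pass-through to its child, debug() can sit inside a whole-stage pipeline instead of forcing a codegen break. A rough usage sketch (illustrative):

  // debug() comes from the implicit class in the debug package object.
  import org.apache.spark.sql.execution.debug._

  val df = sqlContext.range(10).filter("id > 5")
  df.debug()  // wraps each operator in DebugNode, runs the query, and logs per-column stats
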
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 4eb248569b..12e586ada5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable
-import org.apache.spark.sql.execution.{InputAdapter, SparkPlanInfo, WholeStageCodegen}
+import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
@@ -73,36 +73,40 @@ private[sql] object SparkPlanGraph {
edges: mutable.ArrayBuffer[SparkPlanGraphEdge],
parent: SparkPlanGraphNode,
subgraph: SparkPlanGraphCluster): Unit = {
- if (planInfo.nodeName == classOf[WholeStageCodegen].getSimpleName) {
- val cluster = new SparkPlanGraphCluster(
- nodeIdGenerator.getAndIncrement(),
- planInfo.nodeName,
- planInfo.simpleString,
- mutable.ArrayBuffer[SparkPlanGraphNode]())
- nodes += cluster
- buildSparkPlanGraphNode(
- planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
- } else if (planInfo.nodeName == classOf[InputAdapter].getSimpleName) {
- buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
- } else {
- val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId,
- SQLMetrics.getMetricParam(metric.metricParam))
- }
- val node = new SparkPlanGraphNode(
- nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
- planInfo.simpleString, planInfo.metadata, metrics)
- if (subgraph == null) {
- nodes += node
- } else {
- subgraph.nodes += node
- }
-
- if (parent != null) {
- edges += SparkPlanGraphEdge(node.id, parent.id)
- }
- planInfo.children.foreach(
- buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
+ planInfo.nodeName match {
+ case "WholeStageCodegen" =>
+ val cluster = new SparkPlanGraphCluster(
+ nodeIdGenerator.getAndIncrement(),
+ planInfo.nodeName,
+ planInfo.simpleString,
+ mutable.ArrayBuffer[SparkPlanGraphNode]())
+ nodes += cluster
+ buildSparkPlanGraphNode(
+ planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
+ case "InputAdapter" =>
+ buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
+ case "Subquery" if subgraph != null =>
+ // Subquery should not be included in WholeStageCodegen
+ buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null)
+ case _ =>
+ val metrics = planInfo.metrics.map { metric =>
+ SQLPlanMetric(metric.name, metric.accumulatorId,
+ SQLMetrics.getMetricParam(metric.metricParam))
+ }
+ val node = new SparkPlanGraphNode(
+ nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
+ planInfo.simpleString, planInfo.metadata, metrics)
+ if (subgraph == null) {
+ nodes += node
+ } else {
+ subgraph.nodes += node
+ }
+
+ if (parent != null) {
+ edges += SparkPlanGraphEdge(node.id, parent.id)
+ }
+ planInfo.children.foreach(
+ buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index de371d85d9..e00c762c67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -31,14 +31,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val df = sqlContext.range(10).filter("id = 1").selectExpr("id + 1")
val plan = df.queryExecution.executedPlan
assert(plan.find(_.isInstanceOf[WholeStageCodegen]).isDefined)
-
- checkThatPlansAgree(
- sqlContext.range(100),
- (p: SparkPlan) =>
- WholeStageCodegen(Filter('a == 1, InputAdapter(p)), Seq()),
- (p: SparkPlan) => Filter('a == 1, p),
- sortAnswers = false
- )
+ assert(df.collect() === Array(Row(2)))
}
test("Aggregate should be included in WholeStageCodegen") {
@@ -46,7 +39,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
- p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
assert(df.collect() === Array(Row(9, 4.5)))
}
@@ -55,7 +48,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
- p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined)
assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
}
@@ -66,7 +59,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
assert(df.queryExecution.executedPlan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
- p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[BroadcastHashJoin]).isDefined)
assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
}
@@ -75,7 +68,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
- p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined)
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 4358c7c76d..b0d64aa7bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -210,8 +210,8 @@ class JDBCSuite extends SparkFunSuite
// the plan only has PhysicalRDD to scan JDBCRelation.
assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
- assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD])
- assert(node.plan.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
+ assert(node.child.isInstanceOf[org.apache.spark.sql.execution.PhysicalRDD])
+ assert(node.child.asInstanceOf[PhysicalRDD].nodeName.contains("JDBCRelation"))
df
}
assert(checkPushdown(sql("SELECT * FROM foobar WHERE THEID < 1")).collect().size == 0)
@@ -248,7 +248,7 @@ class JDBCSuite extends SparkFunSuite
// cannot compile given predicates.
assert(parentPlan.isInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen])
val node = parentPlan.asInstanceOf[org.apache.spark.sql.execution.WholeStageCodegen]
- assert(node.plan.isInstanceOf[org.apache.spark.sql.execution.Filter])
+ assert(node.child.isInstanceOf[org.apache.spark.sql.execution.Filter])
df
}
assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index 15a95623d1..e7d2b5ad96 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -93,7 +93,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
val metric = qe.executedPlan match {
- case w: WholeStageCodegen => w.plan.longMetric("numOutputRows")
+ case w: WholeStageCodegen => w.child.longMetric("numOutputRows")
case other => other.longMetric("numOutputRows")
}
metrics += metric.value.value