From ffdd1fcd1e8f4f6453d5b0517c0ce82766b8e75f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 14 Dec 2016 16:12:14 -0800 Subject: [SPARK-18854][SQL] numberedTreeString and apply(i) inconsistent for subqueries ## What changes were proposed in this pull request? This is a bug introduced by subquery handling. numberedTreeString (which uses generateTreeString under the hood) numbers trees including innerChildren (used to print subqueries), but apply (which uses getNodeNumbered) ignores innerChildren. As a result, apply(i) would return the wrong plan node if there are subqueries. This patch fixes the bug. ## How was this patch tested? Added a test case in SubquerySuite.scala to test both the depth-first traversal of numbering as well as making sure the two methods are consistent. Author: Reynold Xin Closes #16277 from rxin/SPARK-18854. --- .../spark/sql/catalyst/plans/QueryPlan.scala | 9 +++++ .../plans/logical/basicLogicalOperators.scala | 2 +- .../apache/spark/sql/catalyst/trees/TreeNode.scala | 46 ++++++++++++---------- 3 files changed, 36 insertions(+), 21 deletions(-) (limited to 'sql/catalyst/src') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b108017c4c..e67f2be6d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -24,6 +24,15 @@ import org.apache.spark.sql.types.{DataType, StructType} abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] { self: PlanType => + /** + * Override [[TreeNode.apply]] to so we can return a more narrow type. + * + * Note that this cannot return BaseType because logical plan's plan node might return + * physical plan for innerChildren, e.g. in-memory relation logical plan node has a reference + * to the physical plan node it is referencing. + */ + override def apply(number: Int): QueryPlan[_] = super.apply(number).asInstanceOf[QueryPlan[_]] + def output: Seq[Attribute] /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index b9bdd53dd1..0de5aa8a93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -393,7 +393,7 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) s"CTE $cteAliases" } - override def innerChildren: Seq[QueryPlan[_]] = cteRelations.map(_._2) + override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2) } case class WithWindowDefinition( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index ea8d8fef7b..670fa2bc8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.trees import java.util.UUID import scala.collection.Map -import scala.collection.mutable.Stack import scala.reflect.ClassTag import org.apache.commons.lang3.ClassUtils @@ -28,12 +27,9 @@ import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.SparkContext -import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.ScalaReflection._ -import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -493,7 +489,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * Returns a string representation of the nodes in this tree, where each operator is numbered. - * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees. + * The numbers can be used with [[TreeNode.apply]] to easily access specific subtrees. + * + * The numbers are based on depth-first traversal of the tree (with innerChildren traversed first + * before children). */ def numberedTreeString: String = treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n") @@ -501,17 +500,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { /** * Returns the tree node at the specified number. * Numbers for each node can be found in the [[numberedTreeString]]. + * + * Note that this cannot return BaseType because logical plan's plan node might return + * physical plan for innerChildren, e.g. in-memory relation logical plan node has a reference + * to the physical plan node it is referencing. */ - def apply(number: Int): BaseType = getNodeNumbered(new MutableInt(number)) + def apply(number: Int): TreeNode[_] = getNodeNumbered(new MutableInt(number)).orNull - protected def getNodeNumbered(number: MutableInt): BaseType = { + private def getNodeNumbered(number: MutableInt): Option[TreeNode[_]] = { if (number.i < 0) { - null.asInstanceOf[BaseType] + None } else if (number.i == 0) { - this + Some(this) } else { number.i -= 1 - children.map(_.getNodeNumbered(number)).find(_ != null).getOrElse(null.asInstanceOf[BaseType]) + // Note that this traversal order must be the same as numberedTreeString. + innerChildren.map(_.getNodeNumbered(number)).find(_ != None).getOrElse { + children.map(_.getNodeNumbered(number)).find(_ != None).flatten + } } } @@ -527,6 +533,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at * depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and * `lastChildren` for the root node should be empty. + * + * Note that this traversal (numbering) order must be the same as [[getNodeNumbered]]. */ def generateTreeString( depth: Int, @@ -534,19 +542,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { builder: StringBuilder, verbose: Boolean, prefix: String = ""): StringBuilder = { + if (depth > 0) { lastChildren.init.foreach { isLast => - val prefixFragment = if (isLast) " " else ": " - builder.append(prefixFragment) + builder.append(if (isLast) " " else ": ") } - - val branch = if (lastChildren.last) "+- " else ":- " - builder.append(branch) + builder.append(if (lastChildren.last) "+- " else ":- ") } builder.append(prefix) - val headline = if (verbose) verboseString else simpleString - builder.append(headline) + builder.append(if (verbose) verboseString else simpleString) builder.append("\n") if (innerChildren.nonEmpty) { @@ -557,9 +562,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } if (children.nonEmpty) { - children.init.foreach( - _.generateTreeString(depth + 1, lastChildren :+ false, builder, verbose, prefix)) - children.last.generateTreeString(depth + 1, lastChildren :+ true, builder, verbose, prefix) + children.init.foreach(_.generateTreeString( + depth + 1, lastChildren :+ false, builder, verbose, prefix)) + children.last.generateTreeString( + depth + 1, lastChildren :+ true, builder, verbose, prefix) } builder -- cgit v1.2.3