aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-12-14 16:12:14 -0800
committerReynold Xin <rxin@databricks.com>2016-12-14 16:12:14 -0800
commitffdd1fcd1e8f4f6453d5b0517c0ce82766b8e75f (patch)
tree130adab0dec73b757c63949aa87bba747e320554 /sql
parent78627425708a0afbe113efdf449e8622b43b652d (diff)
downloadspark-ffdd1fcd1e8f4f6453d5b0517c0ce82766b8e75f.tar.gz
spark-ffdd1fcd1e8f4f6453d5b0517c0ce82766b8e75f.tar.bz2
spark-ffdd1fcd1e8f4f6453d5b0517c0ce82766b8e75f.zip
[SPARK-18854][SQL] numberedTreeString and apply(i) inconsistent for subqueries
## What changes were proposed in this pull request? This is a bug introduced by subquery handling. numberedTreeString (which uses generateTreeString under the hood) numbers trees including innerChildren (used to print subqueries), but apply (which uses getNodeNumbered) ignores innerChildren. As a result, apply(i) would return the wrong plan node if there are subqueries. This patch fixes the bug. ## How was this patch tested? Added a test case in SubquerySuite.scala to test both the depth-first traversal of numbering as well as making sure the two methods are consistent. Author: Reynold Xin <rxin@databricks.com> Closes #16277 from rxin/SPARK-18854.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala9
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala46
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala18
5 files changed, 55 insertions, 23 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index b108017c4c..e67f2be6d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -24,6 +24,15 @@ import org.apache.spark.sql.types.{DataType, StructType}
abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
self: PlanType =>
+ /**
+ * Override [[TreeNode.apply]] to so we can return a more narrow type.
+ *
+ * Note that this cannot return BaseType because logical plan's plan node might return
+ * physical plan for innerChildren, e.g. in-memory relation logical plan node has a reference
+ * to the physical plan node it is referencing.
+ */
+ override def apply(number: Int): QueryPlan[_] = super.apply(number).asInstanceOf[QueryPlan[_]]
+
def output: Seq[Attribute]
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index b9bdd53dd1..0de5aa8a93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -393,7 +393,7 @@ case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)])
s"CTE $cteAliases"
}
- override def innerChildren: Seq[QueryPlan[_]] = cteRelations.map(_._2)
+ override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)
}
case class WithWindowDefinition(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index ea8d8fef7b..670fa2bc8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.trees
import java.util.UUID
import scala.collection.Map
-import scala.collection.mutable.Stack
import scala.reflect.ClassTag
import org.apache.commons.lang3.ClassUtils
@@ -28,12 +27,9 @@ import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.SparkContext
-import org.apache.spark.rdd.{EmptyRDD, RDD}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.ScalaReflection._
-import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
@@ -493,7 +489,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
/**
* Returns a string representation of the nodes in this tree, where each operator is numbered.
- * The numbers can be used with [[trees.TreeNode.apply apply]] to easily access specific subtrees.
+ * The numbers can be used with [[TreeNode.apply]] to easily access specific subtrees.
+ *
+ * The numbers are based on depth-first traversal of the tree (with innerChildren traversed first
+ * before children).
*/
def numberedTreeString: String =
treeString.split("\n").zipWithIndex.map { case (line, i) => f"$i%02d $line" }.mkString("\n")
@@ -501,17 +500,24 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
/**
* Returns the tree node at the specified number.
* Numbers for each node can be found in the [[numberedTreeString]].
+ *
+ * Note that this cannot return BaseType because logical plan's plan node might return
+ * physical plan for innerChildren, e.g. in-memory relation logical plan node has a reference
+ * to the physical plan node it is referencing.
*/
- def apply(number: Int): BaseType = getNodeNumbered(new MutableInt(number))
+ def apply(number: Int): TreeNode[_] = getNodeNumbered(new MutableInt(number)).orNull
- protected def getNodeNumbered(number: MutableInt): BaseType = {
+ private def getNodeNumbered(number: MutableInt): Option[TreeNode[_]] = {
if (number.i < 0) {
- null.asInstanceOf[BaseType]
+ None
} else if (number.i == 0) {
- this
+ Some(this)
} else {
number.i -= 1
- children.map(_.getNodeNumbered(number)).find(_ != null).getOrElse(null.asInstanceOf[BaseType])
+ // Note that this traversal order must be the same as numberedTreeString.
+ innerChildren.map(_.getNodeNumbered(number)).find(_ != None).getOrElse {
+ children.map(_.getNodeNumbered(number)).find(_ != None).flatten
+ }
}
}
@@ -527,6 +533,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* The `i`-th element in `lastChildren` indicates whether the ancestor of the current node at
* depth `i + 1` is the last child of its own parent node. The depth of the root node is 0, and
* `lastChildren` for the root node should be empty.
+ *
+ * Note that this traversal (numbering) order must be the same as [[getNodeNumbered]].
*/
def generateTreeString(
depth: Int,
@@ -534,19 +542,16 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
builder: StringBuilder,
verbose: Boolean,
prefix: String = ""): StringBuilder = {
+
if (depth > 0) {
lastChildren.init.foreach { isLast =>
- val prefixFragment = if (isLast) " " else ": "
- builder.append(prefixFragment)
+ builder.append(if (isLast) " " else ": ")
}
-
- val branch = if (lastChildren.last) "+- " else ":- "
- builder.append(branch)
+ builder.append(if (lastChildren.last) "+- " else ":- ")
}
builder.append(prefix)
- val headline = if (verbose) verboseString else simpleString
- builder.append(headline)
+ builder.append(if (verbose) verboseString else simpleString)
builder.append("\n")
if (innerChildren.nonEmpty) {
@@ -557,9 +562,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
}
if (children.nonEmpty) {
- children.init.foreach(
- _.generateTreeString(depth + 1, lastChildren :+ false, builder, verbose, prefix))
- children.last.generateTreeString(depth + 1, lastChildren :+ true, builder, verbose, prefix)
+ children.init.foreach(_.generateTreeString(
+ depth + 1, lastChildren :+ false, builder, verbose, prefix))
+ children.last.generateTreeString(
+ depth + 1, lastChildren :+ true, builder, verbose, prefix)
}
builder
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index 56bd5c1891..03cc04659b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.execution.SparkPlan
@@ -64,7 +63,7 @@ case class InMemoryRelation(
val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator)
extends logical.LeafNode with MultiInstanceRelation {
- override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child)
+ override protected def innerChildren: Seq[SparkPlan] = Seq(child)
override def producedAttributes: AttributeSet = outputSet
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 5a4b1cfe95..2ef8b18c04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -54,6 +54,24 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
t.createOrReplaceTempView("t")
}
+ test("SPARK-18854 numberedTreeString for subquery") {
+ val df = sql("select * from range(10) where id not in " +
+ "(select id from range(2) union all select id from range(2))")
+
+ // The depth first traversal of the plan tree
+ val dfs = Seq("Project", "Filter", "Union", "Project", "Range", "Project", "Range", "Range")
+ val numbered = df.queryExecution.analyzed.numberedTreeString.split("\n")
+
+ // There should be 8 plan nodes in total
+ assert(numbered.size == dfs.size)
+
+ for (i <- dfs.indices) {
+ val node = df.queryExecution.analyzed(i)
+ assert(node.nodeName == dfs(i))
+ assert(numbered(i).contains(node.nodeName))
+ }
+ }
+
test("rdd deserialization does not crash [SPARK-15791]") {
sql("select (select 1 as b) as b").rdd.count()
}