diff options
Diffstat (limited to 'sql/catalyst')
3 files changed, 28 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e9dbded3d4..c8ed4190a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -142,7 +142,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) object OptimizeSubqueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s: SubqueryExpression => - s.withNewPlan(Optimizer.this.execute(s.plan)) + val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan)) + s.withNewPlan(newPlan) } } } @@ -187,7 +188,10 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { // If the alias name is different from attribute name, we can't strip it either, or we // may accidentally change the output schema name of the root plan. case a @ Alias(attr: Attribute, name) - if a.metadata == Metadata.empty && name == attr.name && !blacklist.contains(attr) => + if a.metadata == Metadata.empty && + name == attr.name && + !blacklist.contains(attr) && + !blacklist.contains(a) => attr case a => a } @@ -195,10 +199,15 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { /** * Remove redundant alias expression from a LogicalPlan and its subtree. A blacklist is used to * prevent the removal of seemingly redundant aliases used to deduplicate the input for a (self) - * join. + * join or to prevent the removal of top-level subquery attributes. */ private def removeRedundantAliases(plan: LogicalPlan, blacklist: AttributeSet): LogicalPlan = { plan match { + // We want to keep the same output attributes for subqueries. This means we cannot remove + // the aliases that produce these attributes + case Subquery(child) => + Subquery(removeRedundantAliases(child, blacklist ++ child.outputSet)) + // A join has to be treated differently, because the left and the right side of the join are // not allowed to use the same attributes. We use a blacklist to prevent us from creating a // situation in which this happens; the rule will only remove an alias if its child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 31b6ed48a2..5cbf263d1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -38,6 +38,14 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * This node is inserted at the top of a subquery when it is optimized. This makes sure we can + * recognize a subquery as such, and it allows us to write subquery aware transformations. + */ +case class Subquery(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala index c01ea01ec6..1973b5abb4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala @@ -116,4 +116,12 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper val expected = relation.window(Seq('b), Seq('a), Seq()).analyze comparePlans(optimized, expected) } + + test("do not remove output attributes from a subquery") { + val relation = LocalRelation('a.int, 'b.int) + val query = Subquery(relation.select('a as "a", 'b as "b").where('b < 10).select('a).analyze) + val optimized = Optimize.execute(query) + val expected = Subquery(relation.select('a as "a", 'b).where('b < 10).select('a).analyze) + comparePlans(optimized, expected) + } } |