From a66c23e134a0b1ad9540626fb7436d70d577d929 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 16 Dec 2014 15:31:19 -0800 Subject: [SPARK-4827][SQL] Fix resolution of deeply nested Project(attr, Project(Star,...)). Since `AttributeReference` resolution and `*` expansion are currently in separate rules, each pair requires a full iteration instead of being able to resolve in a single pass. Since it's pretty easy to construct queries that have many of these in a row, I combine them into a single rule in this PR. Author: Michael Armbrust Closes #3674 from marmbrus/projectStars and squashes the following commits: d83d6a1 [Michael Armbrust] Fix resolution of deeply nested Project(attr, Project(Star,...)). --- .../spark/sql/catalyst/analysis/Analyzer.scala | 75 ++++++++++------------ .../sql/catalyst/analysis/AnalysisSuite.scala | 12 +++- 2 files changed, 45 insertions(+), 42 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 04639219a3..ea9bb39786 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -58,7 +58,6 @@ class Analyzer(catalog: Catalog, ResolveSortReferences :: NewRelationInstances :: ImplicitGenerate :: - StarExpansion :: ResolveFunctions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: @@ -153,7 +152,34 @@ class Analyzer(catalog: Catalog, */ object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case q: LogicalPlan if q.childrenResolved => + case p: LogicalPlan if !p.childrenResolved => p + + // If the projection list contains Stars, expand it. 
+ case p@Project(projectList, child) if containsStar(projectList) => + Project( + projectList.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + }, + child) + case t: ScriptTransformation if containsStar(t.input) => + t.copy( + input = t.input.flatMap { + case s: Star => s.expand(t.child.output, resolver) + case o => o :: Nil + } + ) + + // If the aggregate function argument contains Stars, expand it. + case a: Aggregate if containsStar(a.aggregateExpressions) => + a.copy( + aggregateExpressions = a.aggregateExpressions.flatMap { + case s: Star => s.expand(a.child.output, resolver) + case o => o :: Nil + } + ) + + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { case u @ UnresolvedAttribute(name) => @@ -163,6 +189,12 @@ class Analyzer(catalog: Catalog, result } } + + /** + * Returns true if `exprs` contains a [[Star]]. + */ + protected def containsStar(exprs: Seq[Expression]): Boolean = + exprs.collect { case _: Star => true}.nonEmpty } /** @@ -277,45 +309,6 @@ class Analyzer(catalog: Catalog, Generate(g, join = false, outer = false, None, child) } } - - /** - * Expands any references to [[Star]] (*) in project operators. - */ - object StarExpansion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Wait until children are resolved - case p: LogicalPlan if !p.childrenResolved => p - // If the projection list contains Stars, expand it. - case p @ Project(projectList, child) if containsStar(projectList) => - Project( - projectList.flatMap { - case s: Star => s.expand(child.output, resolver) - case o => o :: Nil - }, - child) - case t: ScriptTransformation if containsStar(t.input) => - t.copy( - input = t.input.flatMap { - case s: Star => s.expand(t.child.output, resolver) - case o => o :: Nil - } - ) - // If the aggregate function argument contains Stars, expand it. 
- case a: Aggregate if containsStar(a.aggregateExpressions) => - a.copy( - aggregateExpressions = a.aggregateExpressions.flatMap { - case s: Star => s.expand(a.child.output, resolver) - case o => o :: Nil - } - ) - } - - /** - * Returns true if `exprs` contains a [[Star]]. - */ - protected def containsStar(exprs: Seq[Expression]): Boolean = - exprs.collect { case _: Star => true }.nonEmpty - } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 33a3cba3d4..82f2101d8c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,12 +19,14 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.{BeforeAndAfter, FunSuite} -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ + class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) val caseInsensitiveCatalog = new SimpleCatalog(false) @@ -46,6 +48,14 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation) } + test("union project *") { + val plan = (1 to 100) + .map(_ => testRelation) + .fold[LogicalPlan](testRelation)((a,b) => a.select(Star(None)).select('a).unionAll(b.select(Star(None)))) + + assert(caseInsensitiveAnalyze(plan).resolved) + } + test("analyze project") { assert( caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) === -- 
cgit v1.2.3