aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2014-12-16 15:31:19 -0800
committerMichael Armbrust <michael@databricks.com>2014-12-16 15:31:19 -0800
commita66c23e134a0b1ad9540626fb7436d70d577d929 (patch)
tree784eea866272bab34f2f82155b69c0a1b7cfe5d5 /sql
parent30f6b85c816d1ef611a7be071af0053d64b6fe9e (diff)
downloadspark-a66c23e134a0b1ad9540626fb7436d70d577d929.tar.gz
spark-a66c23e134a0b1ad9540626fb7436d70d577d929.tar.bz2
spark-a66c23e134a0b1ad9540626fb7436d70d577d929.zip
[SPARK-4827][SQL] Fix resolution of deeply nested Project(attr, Project(Star,...)).
Since `AttributeReference` resolution and `*` expansion are currently in separate rules, each pair requires a full iteration instead of being able to resolve in a single pass. Since its pretty easy to construct queries that have many of these in a row, I combine them into a single rule in this PR. Author: Michael Armbrust <michael@databricks.com> Closes #3674 from marmbrus/projectStars and squashes the following commits: d83d6a1 [Michael Armbrust] Fix resolution of deeply nested Project(attr, Project(Star,...)).
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala75
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala12
2 files changed, 45 insertions, 42 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 04639219a3..ea9bb39786 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -58,7 +58,6 @@ class Analyzer(catalog: Catalog,
ResolveSortReferences ::
NewRelationInstances ::
ImplicitGenerate ::
- StarExpansion ::
ResolveFunctions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
@@ -153,7 +152,34 @@ class Analyzer(catalog: Catalog,
*/
object ResolveReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case q: LogicalPlan if q.childrenResolved =>
+ case p: LogicalPlan if !p.childrenResolved => p
+
+ // If the projection list contains Stars, expand it.
+ case p@Project(projectList, child) if containsStar(projectList) =>
+ Project(
+ projectList.flatMap {
+ case s: Star => s.expand(child.output, resolver)
+ case o => o :: Nil
+ },
+ child)
+ case t: ScriptTransformation if containsStar(t.input) =>
+ t.copy(
+ input = t.input.flatMap {
+ case s: Star => s.expand(t.child.output, resolver)
+ case o => o :: Nil
+ }
+ )
+
+ // If the aggregate function argument contains Stars, expand it.
+ case a: Aggregate if containsStar(a.aggregateExpressions) =>
+ a.copy(
+ aggregateExpressions = a.aggregateExpressions.flatMap {
+ case s: Star => s.expand(a.child.output, resolver)
+ case o => o :: Nil
+ }
+ )
+
+ case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressions {
case u @ UnresolvedAttribute(name) =>
@@ -163,6 +189,12 @@ class Analyzer(catalog: Catalog,
result
}
}
+
+ /**
+ * Returns true if `exprs` contains a [[Star]].
+ */
+ protected def containsStar(exprs: Seq[Expression]): Boolean =
+ exprs.collect { case _: Star => true}.nonEmpty
}
/**
@@ -277,45 +309,6 @@ class Analyzer(catalog: Catalog,
Generate(g, join = false, outer = false, None, child)
}
}
-
- /**
- * Expands any references to [[Star]] (*) in project operators.
- */
- object StarExpansion extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- // Wait until children are resolved
- case p: LogicalPlan if !p.childrenResolved => p
- // If the projection list contains Stars, expand it.
- case p @ Project(projectList, child) if containsStar(projectList) =>
- Project(
- projectList.flatMap {
- case s: Star => s.expand(child.output, resolver)
- case o => o :: Nil
- },
- child)
- case t: ScriptTransformation if containsStar(t.input) =>
- t.copy(
- input = t.input.flatMap {
- case s: Star => s.expand(t.child.output, resolver)
- case o => o :: Nil
- }
- )
- // If the aggregate function argument contains Stars, expand it.
- case a: Aggregate if containsStar(a.aggregateExpressions) =>
- a.copy(
- aggregateExpressions = a.aggregateExpressions.flatMap {
- case s: Star => s.expand(a.child.output, resolver)
- case o => o :: Nil
- }
- )
- }
-
- /**
- * Returns true if `exprs` contains a [[Star]].
- */
- protected def containsStar(exprs: Seq[Expression]): Boolean =
- exprs.collect { case _: Star => true }.nonEmpty
- }
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 33a3cba3d4..82f2101d8c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -19,12 +19,14 @@ package org.apache.spark.sql.catalyst.analysis
import org.scalatest.{BeforeAndAfter, FunSuite}
-import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+
class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseSensitiveCatalog = new SimpleCatalog(true)
val caseInsensitiveCatalog = new SimpleCatalog(false)
@@ -46,6 +48,14 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
caseInsensitiveCatalog.registerTable(None, "TaBlE", testRelation)
}
+ test("union project *") {
+ val plan = (1 to 100)
+ .map(_ => testRelation)
+ .fold[LogicalPlan](testRelation)((a,b) => a.select(Star(None)).select('a).unionAll(b.select(Star(None))))
+
+ assert(caseInsensitiveAnalyze(plan).resolved)
+ }
+
test("analyze project") {
assert(
caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===