aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 133
-rw-r--r-- sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 2
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 15
3 files changed, 67 insertions, 83 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c0fa79612a..26c3d286b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
@@ -598,98 +597,69 @@ class Analyzer(
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
- case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
- val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)
-
- if (missingResolvableAttrs.isEmpty) {
- val unresolvableAttrs = s.order.filterNot(_.resolved)
- logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
- s // Nothing we can do here. Return original plan.
- } else {
- // Add the missing attributes into projectList of Project/Window or
- // aggregateExpressions of Aggregate, if they are in the inputSet
- // but not in the outputSet of the plan.
- val newChild = child transformUp {
- case p: Project =>
- p.copy(projectList = p.projectList ++
- missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
- case w: Window =>
- w.copy(projectList = w.projectList ++
- missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
- case a: Aggregate =>
- val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
- val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
- val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
- a.copy(aggregateExpressions = newAggregateExpressions)
- case o => o
- }
-
+ case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+ val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
+ val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
+ val missingAttrs = requiredAttrs -- child.outputSet
+ if (missingAttrs.nonEmpty) {
// Add missing attributes and then project them away after the sort.
Project(child.output,
- Sort(newOrdering, s.global, newChild))
+ Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
+ } else if (newOrder != order) {
+ s.copy(order = newOrder)
+ } else {
+ s
}
}
/**
- * Traverse the tree until resolving the sorting attributes
- * Return all the resolvable missing sorting attributes
- */
- @tailrec
- private def collectResolvableMissingAttrs(
- ordering: Seq[SortOrder],
- plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+ * Add the missing attributes into projectList of Project/Window or aggregateExpressions of
+ * Aggregate.
+ */
+ private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
+ if (missingAttrs.isEmpty) {
+ return plan
+ }
plan match {
- // Only Windows and Project have projectList-like attribute.
- case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
- val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
- // If missingAttrs is non empty, that means we got it and return it;
- // Otherwise, continue to traverse the tree.
- if (missingAttrs.nonEmpty) {
- (newOrdering, missingAttrs)
- } else {
- collectResolvableMissingAttrs(ordering, un.child)
- }
+ case p: Project =>
+ val missing = missingAttrs -- p.child.outputSet
+ Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
+ case w: Window =>
+ val missing = missingAttrs -- w.child.outputSet
+ w.copy(projectList = w.projectList ++ missingAttrs,
+ child = addMissingAttr(w.child, missing))
case a: Aggregate =>
- val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
- // For Aggregate, all the order by columns must be specified in group by clauses
- if (missingAttrs.nonEmpty &&
- missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
- (newOrdering, missingAttrs)
- } else {
- // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes
- (Seq.empty[SortOrder], Seq.empty[Attribute])
+ // all the missing attributes should be grouping expressions
+ // TODO: push down AggregateExpression
+ missingAttrs.foreach { attr =>
+ if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
+ throw new AnalysisException(s"Can't add $attr to ${a.simpleString}")
+ }
}
- // Jump over the following UnaryNode types
- // The output of these types is the same as their child's output
- case _: Distinct |
- _: Filter |
- _: RepartitionByExpression =>
- collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
- // If hitting the other unsupported operators, we are unable to resolve it.
- case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
+ val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
+ a.copy(aggregateExpressions = newAggregateExpressions)
+ case u: UnaryNode =>
+ u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
+ case other =>
+ throw new AnalysisException(s"Can't add $missingAttrs to $other")
}
}
/**
- * Try to resolve the sort ordering and returns it with a list of attributes that are missing
- * from the plan but are present in the child.
- */
- private def resolveAndFindMissing(
- ordering: Seq[SortOrder],
- plan: LogicalPlan,
- child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
- val newOrdering =
- ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
- // Construct a set that contains all of the attributes that we need to evaluate the
- // ordering.
- val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
- // Figure out which ones are missing from the projection, so that we can add them and
- // remove them after the sort.
- val missingInProject = requiredAttributes -- plan.outputSet
- // It is important to return the new SortOrders here, instead of waiting for the standard
- // resolving process as adding attributes to the project below can actually introduce
- // ambiguity that was not present before.
- (newOrdering, missingInProject.toSeq)
+ * Resolve the expression on a specified logical plan and its child (recursively), until
+ * the expression is resolved or a non-unary node or a Subquery is reached.
+ */
+ private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
+ val resolved = resolveExpression(expr, plan)
+ if (resolved.resolved) {
+ resolved
+ } else {
+ plan match {
+ case u: UnaryNode if !u.isInstanceOf[Subquery] =>
+ resolveExpressionRecursively(resolved, u.child)
+ case other => resolved
+ }
+ }
}
}
@@ -782,8 +752,7 @@ class Analyzer(
filter
}
- case sort @ Sort(sortOrder, global, aggregate: Aggregate)
- if aggregate.resolved =>
+ case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
// Try resolving the ordering as though it is in the aggregate clause.
try {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ebf885a8fe..f85ae24e04 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -90,7 +90,7 @@ class AnalysisSuite extends AnalysisTest {
.where(a > "str").select(a, b, c)
.where(b > "str").select(a, b, c)
.sortBy(b.asc, c.desc)
- .select(a, b).select(a)
+ .select(a)
checkAnalysis(plan1, expected1)
// Case 2: all the missing attributes are in the leaf node
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 6048b8f5a3..be864f79d6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -978,6 +978,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
("d", 1),
("c", 2)
).map(i => Row(i._1, i._2)))
+
+ checkAnswer(
+ sql(
+ """
+ |select area, sum(product) / sum(sum(product)) over (partition by area) as c1
+ |from windowData group by area, month order by month, c1
+ """.stripMargin),
+ Seq(
+ ("d", 1.0),
+ ("a", 1.0),
+ ("b", 0.4666666666666667),
+ ("b", 0.5333333333333333),
+ ("c", 0.45),
+ ("c", 0.55)
+ ).map(i => Row(i._1, i._2)))
}
// todo: fix this test case by reimplementing the function ResolveAggregateFunctions