aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-02-12 09:34:18 -0800
committerDavies Liu <davies.liu@gmail.com>2016-02-12 09:34:18 -0800
commit5b805df279d744543851f06e5a0d741354ef485b (patch)
treee23d9f7d5a4f5851e216accba242ec0e8ce73e4d /sql
parent64515e5fbfd694d06fdbc28040fce7baf90a32aa (diff)
downloadspark-5b805df279d744543851f06e5a0d741354ef485b.tar.gz
spark-5b805df279d744543851f06e5a0d741354ef485b.tar.bz2
spark-5b805df279d744543851f06e5a0d741354ef485b.zip
[SPARK-12705] [SQL] push missing attributes for Sort
The current implementation of ResolveSortReferences can only push one missing attribute into its child, so it failed to analyze TPCDS Q98, because there are two missing attributes in that query (one from Window, another from Aggregate). Author: Davies Liu <davies@databricks.com> Closes #11153 from davies/resolve_sort.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala133
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala2
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala15
3 files changed, 67 insertions, 83 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c0fa79612a..26c3d286b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
@@ -598,98 +597,69 @@ class Analyzer(
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
- case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
- val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)
-
- if (missingResolvableAttrs.isEmpty) {
- val unresolvableAttrs = s.order.filterNot(_.resolved)
- logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
- s // Nothing we can do here. Return original plan.
- } else {
- // Add the missing attributes into projectList of Project/Window or
- // aggregateExpressions of Aggregate, if they are in the inputSet
- // but not in the outputSet of the plan.
- val newChild = child transformUp {
- case p: Project =>
- p.copy(projectList = p.projectList ++
- missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
- case w: Window =>
- w.copy(projectList = w.projectList ++
- missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
- case a: Aggregate =>
- val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
- val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
- val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
- a.copy(aggregateExpressions = newAggregateExpressions)
- case o => o
- }
-
+ case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+ val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
+ val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
+ val missingAttrs = requiredAttrs -- child.outputSet
+ if (missingAttrs.nonEmpty) {
// Add missing attributes and then project them away after the sort.
Project(child.output,
- Sort(newOrdering, s.global, newChild))
+ Sort(newOrder, s.global, addMissingAttr(child, missingAttrs)))
+ } else if (newOrder != order) {
+ s.copy(order = newOrder)
+ } else {
+ s
}
}
/**
- * Traverse the tree until resolving the sorting attributes
- * Return all the resolvable missing sorting attributes
- */
- @tailrec
- private def collectResolvableMissingAttrs(
- ordering: Seq[SortOrder],
- plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+ * Add the missing attributes into projectList of Project/Window or aggregateExpressions of
+ * Aggregate.
+ */
+ private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
+ if (missingAttrs.isEmpty) {
+ return plan
+ }
plan match {
- // Only Windows and Project have projectList-like attribute.
- case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
- val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
- // If missingAttrs is non empty, that means we got it and return it;
- // Otherwise, continue to traverse the tree.
- if (missingAttrs.nonEmpty) {
- (newOrdering, missingAttrs)
- } else {
- collectResolvableMissingAttrs(ordering, un.child)
- }
+ case p: Project =>
+ val missing = missingAttrs -- p.child.outputSet
+ Project(p.projectList ++ missingAttrs, addMissingAttr(p.child, missing))
+ case w: Window =>
+ val missing = missingAttrs -- w.child.outputSet
+ w.copy(projectList = w.projectList ++ missingAttrs,
+ child = addMissingAttr(w.child, missing))
case a: Aggregate =>
- val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
- // For Aggregate, all the order by columns must be specified in group by clauses
- if (missingAttrs.nonEmpty &&
- missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
- (newOrdering, missingAttrs)
- } else {
- // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes
- (Seq.empty[SortOrder], Seq.empty[Attribute])
+ // all the missing attributes should be grouping expressions
+ // TODO: push down AggregateExpression
+ missingAttrs.foreach { attr =>
+ if (!a.groupingExpressions.exists(_.semanticEquals(attr))) {
+ throw new AnalysisException(s"Can't add $attr to ${a.simpleString}")
+ }
}
- // Jump over the following UnaryNode types
- // The output of these types is the same as their child's output
- case _: Distinct |
- _: Filter |
- _: RepartitionByExpression =>
- collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
- // If hitting the other unsupported operators, we are unable to resolve it.
- case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
+ val newAggregateExpressions = a.aggregateExpressions ++ missingAttrs
+ a.copy(aggregateExpressions = newAggregateExpressions)
+ case u: UnaryNode =>
+ u.withNewChildren(addMissingAttr(u.child, missingAttrs) :: Nil)
+ case other =>
+ throw new AnalysisException(s"Can't add $missingAttrs to $other")
}
}
/**
- * Try to resolve the sort ordering and returns it with a list of attributes that are missing
- * from the plan but are present in the child.
- */
- private def resolveAndFindMissing(
- ordering: Seq[SortOrder],
- plan: LogicalPlan,
- child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
- val newOrdering =
- ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
- // Construct a set that contains all of the attributes that we need to evaluate the
- // ordering.
- val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
- // Figure out which ones are missing from the projection, so that we can add them and
- // remove them after the sort.
- val missingInProject = requiredAttributes -- plan.outputSet
- // It is important to return the new SortOrders here, instead of waiting for the standard
- // resolving process as adding attributes to the project below can actually introduce
- // ambiguity that was not present before.
- (newOrdering, missingInProject.toSeq)
+ * Resolve the expression on a specified logical plan and its child (recursively), until
+ * the expression is resolved or a non-unary node or Subquery is met.
+ */
+ private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
+ val resolved = resolveExpression(expr, plan)
+ if (resolved.resolved) {
+ resolved
+ } else {
+ plan match {
+ case u: UnaryNode if !u.isInstanceOf[Subquery] =>
+ resolveExpressionRecursively(resolved, u.child)
+ case other => resolved
+ }
+ }
}
}
@@ -782,8 +752,7 @@ class Analyzer(
filter
}
- case sort @ Sort(sortOrder, global, aggregate: Aggregate)
- if aggregate.resolved =>
+ case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
// Try resolving the ordering as though it is in the aggregate clause.
try {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ebf885a8fe..f85ae24e04 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -90,7 +90,7 @@ class AnalysisSuite extends AnalysisTest {
.where(a > "str").select(a, b, c)
.where(b > "str").select(a, b, c)
.sortBy(b.asc, c.desc)
- .select(a, b).select(a)
+ .select(a)
checkAnalysis(plan1, expected1)
// Case 2: all the missing attributes are in the leaf node
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 6048b8f5a3..be864f79d6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -978,6 +978,21 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
("d", 1),
("c", 2)
).map(i => Row(i._1, i._2)))
+
+ checkAnswer(
+ sql(
+ """
+ |select area, sum(product) / sum(sum(product)) over (partition by area) as c1
+ |from windowData group by area, month order by month, c1
+ """.stripMargin),
+ Seq(
+ ("d", 1.0),
+ ("a", 1.0),
+ ("b", 0.4666666666666667),
+ ("b", 0.5333333333333333),
+ ("c", 0.45),
+ ("c", 0.55)
+ ).map(i => Row(i._1, i._2)))
}
// todo: fix this test case by reimplementing the function ResolveAggregateFunctions