aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2016-02-01 11:57:13 -0800
committerMichael Armbrust <michael@databricks.com>2016-02-01 11:57:13 -0800
commit8f26eb5ef6853a6666d7d9481b333de70bc501ed (patch)
tree887eb1d86baf1e8d7fbef56e31e91ff0d253d1f0
parent33c8a490f7f64320c53530a57bd8d34916e3607c (diff)
downloadspark-8f26eb5ef6853a6666d7d9481b333de70bc501ed.tar.gz
spark-8f26eb5ef6853a6666d7d9481b333de70bc501ed.tar.bz2
spark-8f26eb5ef6853a6666d7d9481b333de70bc501ed.zip
[SPARK-12705][SPARK-10777][SQL] Analyzer Rule ResolveSortReferences
JIRA: https://issues.apache.org/jira/browse/SPARK-12705 **Scope:** This PR is a general fix for sorting reference resolution when the child's `outputSet` does not have the order-by attributes (called, *missing attributes*): - UnaryNode support is limited to `Project`, `Window`, `Aggregate`, `Distinct`, `Filter`, `RepartitionByExpression`. - We will not try to resolve the missing references inside a subquery, unless the outputSet of this subquery contains it. **General Reference Resolution Rules:** - Jump over the nodes with the following types: `Distinct`, `Filter`, `RepartitionByExpression`. Do not need to add missing attributes. The reason is their `outputSet` is decided by their `inputSet`, which is the `outputSet` of their children. - Group-by expressions in `Aggregate`: missing order-by attributes are not allowed to be added into group-by expressions since it will change the query result. Thus, in RDBMS, it is not allowed. - Aggregate expressions in `Aggregate`: if the group-by expressions in `Aggregate` contains the missing attributes but aggregate expressions do not have it, just add them into the aggregate expressions. This can resolve the analysisExceptions thrown by the three TCPDS queries. - `Project` and `Window` are special. We just need to add the missing attributes to their `projectList`. **Implementation:** 1. Traverse the whole tree in a pre-order manner to find all the resolvable missing order-by attributes. 2. Traverse the whole tree in a post-order manner to add the found missing order-by attributes to the node if their `inputSet` contains the attributes. 3. If the origins of the missing order-by attributes are different nodes, each pass only resolves the missing attributes that are from the same node. **Risk:** Low. This rule will be trigger iff ```!s.resolved && child.resolved``` is true. Thus, very few cases are affected. Author: gatorsmile <gatorsmile@gmail.com> Closes #10678 from gatorsmile/sortWindows.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala101
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala83
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala16
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala6
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala84
6 files changed, 274 insertions, 22 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ee60fca1ad..a983dc1cdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
+import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
@@ -452,7 +453,7 @@ class Analyzer(
i.copy(right = dedupRight(left, right))
// When resolve `SortOrder`s in Sort based on child, don't report errors as
- // we still have chance to resolve it based on grandchild
+ // we still have chance to resolve it based on its descendants
case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
val newOrdering = resolveSortOrders(ordering, child, throws = false)
Sort(newOrdering, global, child)
@@ -533,38 +534,96 @@ class Analyzer(
*/
object ResolveSortReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case s @ Sort(ordering, global, p @ Project(projectList, child))
- if !s.resolved && p.resolved =>
- val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child)
+ // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
+ case sa @ Sort(_, _, child: Aggregate) => sa
- // If this rule was not a no-op, return the transformed plan, otherwise return the original.
- if (missing.nonEmpty) {
- // Add missing attributes and then project them away after the sort.
- Project(p.output,
- Sort(newOrdering, global,
- Project(projectList ++ missing, child)))
- } else {
- logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
+ case s @ Sort(_, _, child) if !s.resolved && child.resolved =>
+ val (newOrdering, missingResolvableAttrs) = collectResolvableMissingAttrs(s.order, child)
+
+ if (missingResolvableAttrs.isEmpty) {
+ val unresolvableAttrs = s.order.filterNot(_.resolved)
+ logDebug(s"Failed to find $unresolvableAttrs in ${child.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
+ } else {
+ // Add the missing attributes into projectList of Project/Window or
+ // aggregateExpressions of Aggregate, if they are in the inputSet
+ // but not in the outputSet of the plan.
+ val newChild = child transformUp {
+ case p: Project =>
+ p.copy(projectList = p.projectList ++
+ missingResolvableAttrs.filter((p.inputSet -- p.outputSet).contains))
+ case w: Window =>
+ w.copy(projectList = w.projectList ++
+ missingResolvableAttrs.filter((w.inputSet -- w.outputSet).contains))
+ case a: Aggregate =>
+ val resolvableAttrs = missingResolvableAttrs.filter(a.groupingExpressions.contains)
+ val notResolvedAttrs = resolvableAttrs.filterNot(a.aggregateExpressions.contains)
+ val newAggregateExpressions = a.aggregateExpressions ++ notResolvedAttrs
+ a.copy(aggregateExpressions = newAggregateExpressions)
+ case o => o
+ }
+
+ // Add missing attributes and then project them away after the sort.
+ Project(child.output,
+ Sort(newOrdering, s.global, newChild))
}
}
/**
- * Given a child and a grandchild that are present beneath a sort operator, try to resolve
- * the sort ordering and returns it with a list of attributes that are missing from the
- * child but are present in the grandchild.
+ * Traverse the tree until resolving the sorting attributes
+ * Return all the resolvable missing sorting attributes
+ */
+ @tailrec
+ private def collectResolvableMissingAttrs(
+ ordering: Seq[SortOrder],
+ plan: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+ plan match {
+ // Only Windows and Project have projectList-like attribute.
+ case un: UnaryNode if un.isInstanceOf[Project] || un.isInstanceOf[Window] =>
+ val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, un, un.child)
+ // If missingAttrs is non empty, that means we got it and return it;
+ // Otherwise, continue to traverse the tree.
+ if (missingAttrs.nonEmpty) {
+ (newOrdering, missingAttrs)
+ } else {
+ collectResolvableMissingAttrs(ordering, un.child)
+ }
+ case a: Aggregate =>
+ val (newOrdering, missingAttrs) = resolveAndFindMissing(ordering, a, a.child)
+ // For Aggregate, all the order by columns must be specified in group by clauses
+ if (missingAttrs.nonEmpty &&
+ missingAttrs.forall(ar => a.groupingExpressions.exists(_.semanticEquals(ar)))) {
+ (newOrdering, missingAttrs)
+ } else {
+ // If missingAttrs is empty, we are unable to resolve any unresolved missing attributes
+ (Seq.empty[SortOrder], Seq.empty[Attribute])
+ }
+ // Jump over the following UnaryNode types
+ // The output of these types is the same as their child's output
+ case _: Distinct |
+ _: Filter |
+ _: RepartitionByExpression =>
+ collectResolvableMissingAttrs(ordering, plan.asInstanceOf[UnaryNode].child)
+ // If hitting the other unsupported operators, we are unable to resolve it.
+ case other => (Seq.empty[SortOrder], Seq.empty[Attribute])
+ }
+ }
+
+ /**
+ * Try to resolve the sort ordering and returns it with a list of attributes that are missing
+ * from the plan but are present in the child.
*/
- def resolveAndFindMissing(
+ private def resolveAndFindMissing(
ordering: Seq[SortOrder],
- child: LogicalPlan,
- grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
- val newOrdering = resolveSortOrders(ordering, grandchild, throws = true)
+ plan: LogicalPlan,
+ child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+ val newOrdering = resolveSortOrders(ordering, child, throws = false)
// Construct a set that contains all of the attributes that we need to evaluate the
// ordering.
val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
// Figure out which ones are missing from the projection, so that we can add them and
// remove them after the sort.
- val missingInProject = requiredAttributes -- child.output
+ val missingInProject = requiredAttributes -- plan.outputSet
// It is important to return the new SortOrders here, instead of waiting for the standard
// resolving process as adding attributes to the project below can actually introduce
// ambiguity that was not present before.
@@ -719,7 +778,7 @@ class Analyzer(
}
}
- protected def containsAggregate(condition: Expression): Boolean = {
+ def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 1938bce02a..ebf885a8fe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -76,6 +76,89 @@ class AnalysisSuite extends AnalysisTest {
caseSensitive = false)
}
+ test("resolve sort references - filter/limit") {
+ val a = testRelation2.output(0)
+ val b = testRelation2.output(1)
+ val c = testRelation2.output(2)
+
+ // Case 1: one missing attribute is in the leaf node and another is in the unary node
+ val plan1 = testRelation2
+ .where('a > "str").select('a, 'b)
+ .where('b > "str").select('a)
+ .sortBy('b.asc, 'c.desc)
+ val expected1 = testRelation2
+ .where(a > "str").select(a, b, c)
+ .where(b > "str").select(a, b, c)
+ .sortBy(b.asc, c.desc)
+ .select(a, b).select(a)
+ checkAnalysis(plan1, expected1)
+
+ // Case 2: all the missing attributes are in the leaf node
+ val plan2 = testRelation2
+ .where('a > "str").select('a)
+ .where('a > "str").select('a)
+ .sortBy('b.asc, 'c.desc)
+ val expected2 = testRelation2
+ .where(a > "str").select(a, b, c)
+ .where(a > "str").select(a, b, c)
+ .sortBy(b.asc, c.desc)
+ .select(a)
+ checkAnalysis(plan2, expected2)
+ }
+
+ test("resolve sort references - join") {
+ val a = testRelation2.output(0)
+ val b = testRelation2.output(1)
+ val c = testRelation2.output(2)
+ val h = testRelation3.output(3)
+
+ // Case: join itself can resolve all the missing attributes
+ val plan = testRelation2.join(testRelation3)
+ .where('a > "str").select('a, 'b)
+ .sortBy('c.desc, 'h.asc)
+ val expected = testRelation2.join(testRelation3)
+ .where(a > "str").select(a, b, c, h)
+ .sortBy(c.desc, h.asc)
+ .select(a, b)
+ checkAnalysis(plan, expected)
+ }
+
+ test("resolve sort references - aggregate") {
+ val a = testRelation2.output(0)
+ val b = testRelation2.output(1)
+ val c = testRelation2.output(2)
+ val alias_a3 = count(a).as("a3")
+ val alias_b = b.as("aggOrder")
+
+ // Case 1: when the child of Sort is not Aggregate,
+ // the sort reference is handled by the rule ResolveSortReferences
+ val plan1 = testRelation2
+ .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
+ .select('a, 'c, 'a3)
+ .orderBy('b.asc)
+
+ val expected1 = testRelation2
+ .groupBy(a, c, b)(a, c, alias_a3, b)
+ .select(a, c, alias_a3.toAttribute, b)
+ .orderBy(b.asc)
+ .select(a, c, alias_a3.toAttribute)
+
+ checkAnalysis(plan1, expected1)
+
+ // Case 2: when the child of Sort is Aggregate,
+ // the sort reference is handled by the rule ResolveAggregateFunctions
+ val plan2 = testRelation2
+ .groupBy('a, 'c, 'b)('a, 'c, count('a).as("a3"))
+ .orderBy('b.asc)
+
+ val expected2 = testRelation2
+ .groupBy(a, c, b)(a, c, alias_a3, alias_b)
+ .orderBy(alias_b.toAttribute.asc)
+ .select(a, c, alias_a3.toAttribute)
+
+ checkAnalysis(plan2, expected2)
+ }
+
test("resolve relations") {
assertAnalysisError(
UnresolvedRelation(TableIdentifier("tAbLe"), None), Seq("Table not found: tAbLe"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
index bc07b609a3..3741a6ba95 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
@@ -31,6 +31,12 @@ object TestRelations {
AttributeReference("d", DecimalType(10, 2))(),
AttributeReference("e", ShortType)())
+ val testRelation3 = LocalRelation(
+ AttributeReference("e", ShortType)(),
+ AttributeReference("f", StringType)(),
+ AttributeReference("g", DoubleType)(),
+ AttributeReference("h", DecimalType(10, 2))())
+
val nestedRelation = LocalRelation(
AttributeReference("top", StructType(
StructField("duplicateField", StringType) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index c17be8ace9..a5e5f15642 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -42,6 +42,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
Row(1, 2, "1", "2") :: Row(2, 3, "2", "3") :: Row(3, 4, "3", "4") :: Nil)
}
+ test("join - sorted columns not in join's outputSet") {
+ val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str_sort").as('df1)
+ val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df2)
+ val df3 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as('df3)
+
+ checkAnswer(
+ df.join(df2, $"df1.int" === $"df2.int", "outer").select($"df1.int", $"df2.int2")
+ .orderBy('str_sort.asc, 'str.asc),
+ Row(null, 6) :: Row(1, 3) :: Row(3, null) :: Nil)
+
+ checkAnswer(
+ df2.join(df3, $"df2.int" === $"df3.int", "inner")
+ .select($"df2.int", $"df3.int").orderBy($"df2.str".desc),
+ Row(5, 5) :: Row(1, 1) :: Nil)
+ }
+
test("join - join using multiple columns and specifying join type") {
val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str")
val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 4ff99bdf29..c02133ffc8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -954,6 +954,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(expected === actual)
}
+ test("Sorting columns are not in Filter and Project") {
+ checkAnswer(
+ upperCaseData.filter('N > 1).select('N).filter('N < 6).orderBy('L.asc),
+ Row(2) :: Row(3) :: Row(4) :: Row(5) :: Nil)
+ }
+
test("SPARK-9323: DataFrame.orderBy should support nested column name") {
val df = sqlContext.read.json(sparkContext.makeRDD(
"""{"a": {"b": 1}}""" :: Nil))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 1ada2e325b..6048b8f5a3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -736,7 +736,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
""".stripMargin), (2 to 6).map(i => Row(i)))
}
- test("window function: udaf with aggregate expressin") {
+ test("window function: udaf with aggregate expression") {
val data = Seq(
WindowData(1, "a", 5),
WindowData(2, "a", 6),
@@ -927,6 +927,88 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
).map(i => Row(i._1, i._2, i._3, i._4)))
}
+ test("window function: Sorting columns are not in Project") {
+ val data = Seq(
+ WindowData(1, "d", 10),
+ WindowData(2, "a", 6),
+ WindowData(3, "b", 7),
+ WindowData(4, "b", 8),
+ WindowData(5, "c", 9),
+ WindowData(6, "c", 11)
+ )
+ sparkContext.parallelize(data).toDF().registerTempTable("windowData")
+
+ checkAnswer(
+ sql("select month, product, sum(product + 1) over() from windowData order by area"),
+ Seq(
+ (2, 6, 57),
+ (3, 7, 57),
+ (4, 8, 57),
+ (5, 9, 57),
+ (6, 11, 57),
+ (1, 10, 57)
+ ).map(i => Row(i._1, i._2, i._3)))
+
+ checkAnswer(
+ sql(
+ """
+ |select area, rank() over (partition by area order by tmp.month) + tmp.tmp1 as c1
+ |from (select month, area, product as p, 1 as tmp1 from windowData) tmp order by p
+ """.stripMargin),
+ Seq(
+ ("a", 2),
+ ("b", 2),
+ ("b", 3),
+ ("c", 2),
+ ("d", 2),
+ ("c", 3)
+ ).map(i => Row(i._1, i._2)))
+
+ checkAnswer(
+ sql(
+ """
+ |select area, rank() over (partition by area order by month) as c1
+ |from windowData group by product, area, month order by product, area
+ """.stripMargin),
+ Seq(
+ ("a", 1),
+ ("b", 1),
+ ("b", 2),
+ ("c", 1),
+ ("d", 1),
+ ("c", 2)
+ ).map(i => Row(i._1, i._2)))
+ }
+
+ // todo: fix this test case by reimplementing the function ResolveAggregateFunctions
+ ignore("window function: Pushing aggregate Expressions in Sort to Aggregate") {
+ val data = Seq(
+ WindowData(1, "d", 10),
+ WindowData(2, "a", 6),
+ WindowData(3, "b", 7),
+ WindowData(4, "b", 8),
+ WindowData(5, "c", 9),
+ WindowData(6, "c", 11)
+ )
+ sparkContext.parallelize(data).toDF().registerTempTable("windowData")
+
+ checkAnswer(
+ sql(
+ """
+ |select area, sum(product) over () as c from windowData
+ |where product > 3 group by area, product
+ |having avg(month) > 0 order by avg(month), product
+ """.stripMargin),
+ Seq(
+ ("a", 51),
+ ("b", 51),
+ ("b", 51),
+ ("c", 51),
+ ("c", 51),
+ ("d", 51)
+ ).map(i => Row(i._1, i._2)))
+ }
+
test("window function: multiple window expressions in a single expression") {
val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
nums.registerTempTable("nums")