author    Wenchen Fan <cloud0fan@outlook.com>        2015-06-17 14:46:00 -0700
committer Michael Armbrust <michael@databricks.com>  2015-06-17 14:46:00 -0700
commit    7f05b1fe696daa28fee514c9aef805be5913cfcd (patch)
tree      ceeeb3e7afd2227f290cb51d3c28cc1f477e3cba /sql
parent    a411a40de2209c56e898e3fb4af955d7b55af11c (diff)
[SPARK-7067] [SQL] fix bug when using complex nested fields in ORDER BY
This PR is an improvement on https://github.com/apache/spark/pull/5189. The resolution rule for ORDER BY is: first resolve based on what comes from the SELECT clause, and fall back on the child only when that fails. There are two steps. First, try to resolve `Sort` in `ResolveReferences` based on the SELECT clause, ignoring exceptions. Second, try to resolve `Sort` in `ResolveSortReferences` and add the missing projection.

However, the way we resolve `SortOrder` is wrong: we resolve only the `UnresolvedAttribute` and use the result to indicate whether the `SortOrder` can be resolved. But an `UnresolvedAttribute` is only part of a `GetField` chain (broken by `GetItem`), so we need to go through the whole chain to decide whether the `SortOrder` can be resolved. With this change, we can also avoid re-throwing the `GetField` exception in `CheckAnalysis`, which was a little ugly.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #5659 from cloud-fan/order-by and squashes the following commits:

cfa79f8 [Wenchen Fan] update test
3245d28 [Wenchen Fan] minor improve
465ee07 [Wenchen Fan] address comment
1fc41a2 [Wenchen Fan] fix SPARK-7067
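An illustration of the chain problem described above (an editorial sketch, not part of the commit message; the names come from the regression test added at the bottom of this diff):

// Roughly what ORDER BY b[0].d parses to:
//   SortOrder(GetField(GetItem(UnresolvedAttribute("b"), 0), "d"), Ascending)
// Resolving only the UnresolvedAttribute leaf `b` says nothing about whether
// the surrounding GetItem/GetField chain resolves, so the whole chain must be
// walked before the SortOrder can be declared resolved.
sqlContext.sql("SELECT a.b FROM t ORDER BY b[0].d")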
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala         | 75
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala    |  8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 43
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala            |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                          |  8
5 files changed, 70 insertions, 66 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index badf903478..21b0576025 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -336,9 +336,15 @@ class Analyzer(
}
j.copy(right = newRight)
+ // When resolving `SortOrder`s in a Sort based on the child, don't report errors, as
+ // we still have a chance to resolve them based on the grandchild
+ case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
+ val newOrdering = resolveSortOrders(ordering, child, throws = false)
+ Sort(newOrdering, global, child)
+
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
- q transformExpressionsUp {
+ q transformExpressionsUp {
case u @ UnresolvedAttribute(nameParts) if nameParts.length == 1 &&
resolver(nameParts(0), VirtualColumn.groupingIdName) &&
q.isInstanceOf[GroupingAnalytics] =>
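A sketch (not from the commit) of the two-pass behavior that throws = false in the new Sort case enables, assuming a hypothetical table t with columns a, b, c:

// `b` is in the SELECT clause, so this pass resolves the Sort directly.
sqlContext.sql("SELECT a, b FROM t ORDER BY b")
// `c` is projected away; with throws = false this pass leaves the SortOrder
// unresolved instead of failing, and ResolveSortReferences later resolves `c`
// against the grandchild and temporarily adds it to the projection.
sqlContext.sql("SELECT a FROM t ORDER BY c")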
@@ -373,6 +379,26 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
+ private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
+ ordering.map { order =>
+ // Resolve the SortOrder in one round.
+ // If throws == false or the desired attribute doesn't exist
+ // (e.g. trying to resolve `a.b` when `a` doesn't exist), return the original one.
+ // Otherwise, throw the exception.
+ try {
+ val newOrder = order transformUp {
+ case u @ UnresolvedAttribute(nameParts) =>
+ plan.resolve(nameParts, resolver).getOrElse(u)
+ case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+ ExtractValue(child, fieldName, resolver)
+ }
+ newOrder.asInstanceOf[SortOrder]
+ } catch {
+ case a: AnalysisException if !throws => order
+ }
+ }
+ }
+
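A minimal standalone sketch of why the bottom-up rewrite in resolveSortOrders resolves whole chains, using hypothetical stand-ins for Catalyst's expression types (not Catalyst code):

sealed trait Expr { def resolved: Boolean }
case class Unresolved(name: String) extends Expr { val resolved = false }
case class Attr(name: String) extends Expr { val resolved = true }
case class Extract(child: Expr, field: String) extends Expr {
  def resolved: Boolean = child.resolved
}

// Leaves are rewritten first, so each Extract wrapper then sees a resolved
// child, mirroring how `b[0].d` is resolved as a whole chain, not as a leaf.
def resolveUp(e: Expr, lookup: String => Option[Expr]): Expr = e match {
  case Unresolved(n) => lookup(n).getOrElse(e)
  case Extract(c, f) => Extract(resolveUp(c, lookup), f)
  case other         => other
}

resolveUp(Extract(Extract(Unresolved("b"), "0"), "d"), Map("b" -> Attr("b")).get)
// => Extract(Extract(Attr("b"), "0"), "d"), which now reports resolved = true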
/**
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
* clause. This rule detects such queries and adds the required attributes to the original
@@ -383,13 +409,13 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ Sort(ordering, global, p @ Project(projectList, child))
if !s.resolved && p.resolved =>
- val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child)
+ val (newOrdering, missing) = resolveAndFindMissing(ordering, p, child)
// If this rule was not a no-op, return the transformed plan, otherwise return the original.
if (missing.nonEmpty) {
// Add missing attributes and then project them away after the sort.
Project(p.output,
- Sort(resolvedOrdering, global,
+ Sort(newOrdering, global,
Project(projectList ++ missing, child)))
} else {
logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
@@ -404,19 +430,19 @@ class Analyzer(
)
// Find sort attributes that are projected away so we can temporarily add them back in.
- val (resolvedOrdering, unresolved) = resolveAndFindMissing(ordering, a, groupingRelation)
+ val (newOrdering, missingAttr) = resolveAndFindMissing(ordering, a, groupingRelation)
// Find aggregate expressions and evaluate them early, since they can't be evaluated in a
// Sort.
- val (withAggsRemoved, aliasedAggregateList) = resolvedOrdering.map {
+ val (withAggsRemoved, aliasedAggregateList) = newOrdering.map {
case aggOrdering if aggOrdering.collect { case a: AggregateExpression => a }.nonEmpty =>
val aliased = Alias(aggOrdering.child, "_aggOrdering")()
- (aggOrdering.copy(child = aliased.toAttribute), aliased :: Nil)
+ (aggOrdering.copy(child = aliased.toAttribute), Some(aliased))
- case other => (other, Nil)
+ case other => (other, None)
}.unzip
- val missing = unresolved ++ aliasedAggregateList.flatten
+ val missing = missingAttr ++ aliasedAggregateList.flatten
if (missing.nonEmpty) {
// Add missing grouping exprs and then project them away after the sort.
@@ -429,40 +455,25 @@ class Analyzer(
}
/**
- * Given a child and a grandchild that are present beneath a sort operator, returns
- * a resolved sort ordering and a list of attributes that are missing from the child
- * but are present in the grandchild.
+ * Given a child and a grandchild that are present beneath a sort operator, tries to resolve
+ * the sort ordering and returns it together with a list of attributes that are missing from
+ * the child but are present in the grandchild.
*/
def resolveAndFindMissing(
ordering: Seq[SortOrder],
child: LogicalPlan,
grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
- // Find any attributes that remain unresolved in the sort.
- val unresolved: Seq[Seq[String]] =
- ordering.flatMap(_.collect { case UnresolvedAttribute(nameParts) => nameParts })
-
- // Create a map from name, to resolved attributes, when the desired name can be found
- // prior to the projection.
- val resolved: Map[Seq[String], NamedExpression] =
- unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap
-
+ val newOrdering = resolveSortOrders(ordering, grandchild, throws = true)
// Construct a set that contains all of the attributes that we need to evaluate the
// ordering.
- val requiredAttributes = AttributeSet(resolved.values)
-
+ val requiredAttributes = AttributeSet(newOrdering.filter(_.resolved))
// Figure out which ones are missing from the projection, so that we can add them and
// remove them after the sort.
val missingInProject = requiredAttributes -- child.output
-
- // Now that we have all the attributes we need, reconstruct a resolved ordering.
- // It is important to do it here, instead of waiting for the standard resolved as adding
- // attributes to the project below can actually introduce ambiquity that was not present
- // before.
- val resolvedOrdering = ordering.map(_ transform {
- case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u)
- }).asInstanceOf[Seq[SortOrder]]
-
- (resolvedOrdering, missingInProject.toSeq)
+ // It is important to return the new SortOrders here, instead of waiting for the standard
+ // resolving process as adding attributes to the project below can actually introduce
+ // ambiguity that was not present before.
+ (newOrdering, missingInProject.toSeq)
}
}
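What the ResolveSortReferences rewrite produces end to end, sketched as a hypothetical before/after plan for a query whose sort column is projected away (illustrative names, not actual plan output):

// SELECT a FROM t ORDER BY c
// Before:  Sort(c.asc, global, Project(a, t))        // `c` not in child output
// After:   Project(a,                                // project `c` away again
//            Sort(c.asc, global,
//              Project(Seq(a, c), t)))               // temporarily add `c`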
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c0695ae369..7fabd2bfc8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -51,14 +51,6 @@ trait CheckAnalysis {
case operator: LogicalPlan =>
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
- if (operator.childrenResolved) {
- a match {
- case UnresolvedAttribute(nameParts) =>
- // Throw errors for specific problems with get field.
- operator.resolveChildren(nameParts, resolver, throwErrors = true)
- }
- }
-
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
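With the special case removed, a bad nested field reference should now fail while the ExtractValue chain is being built during resolution, rather than being re-thrown here. A hypothetical example, assuming t has a struct column a with no field nonexistent:

// Expected to raise an AnalysisException from ExtractValue during resolution,
// instead of falling through to the generic "cannot resolve" message above.
sqlContext.sql("SELECT a.nonexistent FROM t")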
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index c8c6676f24..a853e27c12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -50,19 +50,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]]
* should return `false`).
*/
- lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved
+ lazy val resolved: Boolean = expressions.forall(_.resolved) && childrenResolved
override protected def statePrefix = if (!resolved) "'" else super.statePrefix
/**
* Returns true if all its children of this query plan have been resolved.
*/
- def childrenResolved: Boolean = !children.exists(!_.resolved)
+ def childrenResolved: Boolean = children.forall(_.resolved)
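The two forall rewrites above are behavior-preserving; by De Morgan's law over collections the forms agree, as a trivial check shows:

val xs = Seq(1, 2, 3)
// "no element fails the predicate" is the same as "every element satisfies it"
assert(!xs.exists(x => !(x > 0)) == xs.forall(_ > 0))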
/**
* Returns true when the given logical plan will return the same results as this logical plan.
*
- * Since its likely undecideable to generally determine if two given plans will produce the same
+ * Since it's likely undecidable to generally determine if two given plans will produce the same
* results, it is okay for this function to return false, even if the results are actually
* the same. Such behavior will not affect correctness, only the application of performance
* enhancements like caching. However, it is not acceptable to return true if the results could
@@ -111,9 +111,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
def resolveChildren(
nameParts: Seq[String],
- resolver: Resolver,
- throwErrors: Boolean = false): Option[NamedExpression] =
- resolve(nameParts, children.flatMap(_.output), resolver, throwErrors)
+ resolver: Resolver): Option[NamedExpression] =
+ resolve(nameParts, children.flatMap(_.output), resolver)
/**
* Optionally resolves the given strings to a [[NamedExpression]] based on the output of this
@@ -122,9 +121,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
def resolve(
nameParts: Seq[String],
- resolver: Resolver,
- throwErrors: Boolean = false): Option[NamedExpression] =
- resolve(nameParts, output, resolver, throwErrors)
+ resolver: Resolver): Option[NamedExpression] =
+ resolve(nameParts, output, resolver)
/**
* Given an attribute name, split it to name parts by dot, but
@@ -134,7 +132,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
def resolveQuoted(
name: String,
resolver: Resolver): Option[NamedExpression] = {
- resolve(parseAttributeName(name), resolver, true)
+ resolve(parseAttributeName(name), output, resolver)
}
/**
@@ -219,8 +217,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
protected def resolve(
nameParts: Seq[String],
input: Seq[Attribute],
- resolver: Resolver,
- throwErrors: Boolean): Option[NamedExpression] = {
+ resolver: Resolver): Option[NamedExpression] = {
// A sequence of possible candidate matches.
// Each candidate is a tuple. The first element is a resolved attribute, followed by a list
@@ -254,19 +251,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
- try {
- // The foldLeft adds GetFields for every remaining parts of the identifier,
- // and aliases it with the last part of the identifier.
- // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add GetField("c", GetField("b", a)), and alias
- // the final expression as "c".
- val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
- ExtractValue(expr, Literal(fieldName), resolver))
- val aliasName = nestedFields.last
- Some(Alias(fieldExprs, aliasName)())
- } catch {
- case a: AnalysisException if !throwErrors => None
- }
+ // The foldLeft adds an ExtractValue for every remaining part of the identifier,
+ // and aliases the result with the last part of the identifier.
+ // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
+ // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
+ // the final expression as "c".
+ val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
+ ExtractValue(expr, Literal(fieldName), resolver))
+ val aliasName = nestedFields.last
+ Some(Alias(fieldExprs, aliasName)())
// No matches.
case Seq() =>
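A compact illustration of the foldLeft above, using plain strings in place of Catalyst expressions (hypothetical, for intuition only):

// Resolving "a.b.c": "a" anchors to an existing attribute, the remaining
// parts become nested extractions, and the alias takes the last name part.
val nestedFields = Seq("b", "c")
val fieldExpr = nestedFields.foldLeft("a") { (expr, field) =>
  s"ExtractValue($expr, $field)"
}
assert(fieldExpr == "ExtractValue(ExtractValue(a, b), c)")  // aliased as "c"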
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index f304597bc9..09f6c6b0ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -285,7 +285,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param rule the function use to transform this nodes children
*/
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
- val afterRuleOnChildren = transformChildrenUp(rule);
+ val afterRuleOnChildren = transformChildrenUp(rule)
if (this fastEquals afterRuleOnChildren) {
CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 1a6ee8169c..30db840166 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1440,4 +1440,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
checkAnswer(sql("select i <=> b from t"), sql("select r2 from t"))
}
}
+
+ test("SPARK-7067: order by queries for complex ExtractValue chain") {
+ withTempTable("t") {
+ sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ """{"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}""" :: Nil)).registerTempTable("t")
+ checkAnswer(sql("SELECT a.b FROM t ORDER BY b[0].d"), Row(Seq(Row(1))))
+ }
+ }
}
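A walk-through of why the regression test above exercises the whole-chain check (an inference from the resolution rule described in the commit message, not text from the commit):

// Row: {"a": {"b": [{"c": 1}]}, "b": [{"d": 1}]}
// SELECT a.b       -> output attribute `b` : array<struct<c: int>>
// ORDER BY b[0].d  -> field `d` does not exist on the SELECT-clause `b`, so
//                     the first pass (throws = false) keeps the SortOrder
//                     unresolved, and ResolveSortReferences then resolves the
//                     full b[0].d chain against the child's top-level
//                     `b` : array<struct<d: int>> and projects it back away.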