author    Michael Armbrust <michael@databricks.com>  2015-03-31 11:23:18 -0700
committer Michael Armbrust <michael@databricks.com>  2015-03-31 11:23:18 -0700
commit    cd48ca50129e8952f487051796244e7569275416 (patch)
tree      86e5163226596544e711fb253e0ce76d7aeeacc4 /sql/catalyst/src/main
parent    81020144708773ba3af4932288ffa09ef901269e (diff)
[SPARK-6145][SQL] fix ORDER BY on nested fields
This PR is based on work by cloud-fan in #4904, but with two differences:

- We isolate the logic for Sort's special handling into `ResolveSortReferences`.
- We avoid creating UnresolvedGetField expressions during resolution. Instead we either resolve GetField or we return None. This avoids us going down the wrong path early on.

Author: Michael Armbrust <michael@databricks.com>

Closes #5189 from marmbrus/nestedOrderBy and squashes the following commits:

b8cae45 [Michael Armbrust] fix another test
0f36a11 [Michael Armbrust] WIP
91820cd [Michael Armbrust] Fix bug.
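For context, the failing case looks roughly like the following (a minimal sketch; the table, class, and field names are hypothetical, and `sc`/`sqlContext` are assumed to be a Spark 1.3-era SparkContext and SQLContext):

    // Hypothetical repro for SPARK-6145: ORDER BY on a nested struct field.
    case class Inner(a: Int, b: String)
    case class Outer(data: Inner)

    import sqlContext.implicits._
    val df = sc.parallelize(Seq(Outer(Inner(2, "x")), Outer(Inner(1, "y")))).toDF()
    df.registerTempTable("t")

    // Before this patch, analysis failed to resolve `data.a` in the ORDER BY
    // clause; with ResolveSortReferences it resolves to a GetField and sorts.
    sqlContext.sql("SELECT data.b FROM t ORDER BY data.a").collect()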
Diffstat (limited to 'sql/catalyst/src/main')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala        | 76
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala   | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 76
4 files changed, 127 insertions(+), 39 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dc14f49e6e..c578d084a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -37,11 +37,12 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true
* [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and
* a [[FunctionRegistry]].
*/
-class Analyzer(catalog: Catalog,
- registry: FunctionRegistry,
- caseSensitive: Boolean,
- maxIterations: Int = 100)
- extends RuleExecutor[LogicalPlan] with HiveTypeCoercion {
+class Analyzer(
+ catalog: Catalog,
+ registry: FunctionRegistry,
+ caseSensitive: Boolean,
+ maxIterations: Int = 100)
+ extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis {
val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution
@@ -354,19 +355,16 @@ class Analyzer(catalog: Catalog,
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ Sort(ordering, global, p @ Project(projectList, child))
if !s.resolved && p.resolved =>
- val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
- val resolved = unresolved.flatMap(child.resolve(_, resolver))
- val requiredAttributes =
- AttributeSet(resolved.flatMap(_.collect { case a: Attribute => a }))
+ val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child)
- val missingInProject = requiredAttributes -- p.output
- if (missingInProject.nonEmpty) {
+ // If this rule was not a no-op, return the transformed plan; otherwise return the original.
+ if (missing.nonEmpty) {
// Add missing attributes and then project them away after the sort.
- Project(projectList.map(_.toAttribute),
- Sort(ordering, global,
- Project(projectList ++ missingInProject, child)))
+ Project(p.output,
+ Sort(resolvedOrdering, global,
+ Project(projectList ++ missing, child)))
} else {
- logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}")
+ logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
}
case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child))
@@ -378,18 +376,54 @@ class Analyzer(catalog: Catalog,
grouping.collect { case ne: NamedExpression => ne.toAttribute }
)
- logDebug(s"Grouping expressions: $groupingRelation")
- val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver))
- val missingInAggs = resolved.filterNot(a.outputSet.contains)
- logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
- if (missingInAggs.nonEmpty) {
+ val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation)
+
+ if (missing.nonEmpty) {
// Add missing grouping exprs and then project them away after the sort.
Project(a.output,
- Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child)))
+ Sort(resolvedOrdering, global,
+ Aggregate(grouping, aggs ++ missing, child)))
} else {
s // Nothing we can do here. Return original plan.
}
}
+
+ /**
+ * Given a child and a grandchild that are present beneath a sort operator, returns
+ * a resolved sort ordering and a list of attributes that are missing from the child
+ * but are present in the grandchild.
+ */
+ def resolveAndFindMissing(
+ ordering: Seq[SortOrder],
+ child: LogicalPlan,
+ grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+ // Find any attributes that remain unresolved in the sort.
+ val unresolved: Seq[String] =
+ ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+
+ // Create a map from name to resolved attribute, for the cases where the desired name
+ // can be found prior to the projection.
+ val resolved: Map[String, NamedExpression] =
+ unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap
+
+ // Construct a set that contains all of the attributes that we need to evaluate the
+ // ordering.
+ val requiredAttributes = AttributeSet(resolved.values)
+
+ // Figure out which ones are missing from the projection, so that we can add them and
+ // remove them after the sort.
+ val missingInProject = requiredAttributes -- child.output
+
+ // Now that we have all the attributes we need, reconstruct a resolved ordering.
+ // It is important to do it here, instead of waiting for the standard resolution rules, as
+ // adding attributes to the project below can actually introduce ambiguity that was not
+ // present before.
+ val resolvedOrdering = ordering.map(_ transform {
+ case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u)
+ }).asInstanceOf[Seq[SortOrder]]
+
+ (resolvedOrdering, missingInProject.toSeq)
+ }
}
/**
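The shape of the rewrite that both Sort cases perform can be pictured with a small standalone sketch (toy stand-in types, not the Catalyst operators): attributes needed only by the ordering are exposed beneath the Sort and projected away above it.

    // Toy model of the "add missing attributes, sort, project them away" rewrite.
    sealed trait Plan { def output: Seq[String] }
    case class Scan(output: Seq[String]) extends Plan
    case class Proj(output: Seq[String], child: Plan) extends Plan
    case class SortBy(ordering: Seq[String], child: Plan) extends Plan {
      def output: Seq[String] = child.output
    }

    def resolveSort(ordering: Seq[String], p: Proj): Plan = {
      val missing = ordering.filterNot(p.output.contains)
      if (missing.nonEmpty)
        // Expose the sort columns beneath the Sort, then restore the original
        // projection above it, mirroring ResolveSortReferences.
        Proj(p.output, SortBy(ordering, Proj(p.output ++ missing, p.child)))
      else SortBy(ordering, p)
    }

    // resolveSort(Seq("x"), Proj(Seq("y"), Scan(Seq("x", "y"))))
    // => Proj(List(y), SortBy(List(x), Proj(List(y, x), Scan(List(x, y)))))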
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 40472a1cbb..fa02111385 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.types._
/**
* Throws user facing errors when passed invalid queries that fail to analyze.
*/
-class CheckAnalysis {
+trait CheckAnalysis {
+ self: Analyzer =>
/**
* Override to provide additional checks for correct analysis.
@@ -33,17 +34,22 @@ class CheckAnalysis {
*/
val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
- def failAnalysis(msg: String): Nothing = {
+ protected def failAnalysis(msg: String): Nothing = {
throw new AnalysisException(msg)
}
- def apply(plan: LogicalPlan): Unit = {
+ def checkAnalysis(plan: LogicalPlan): Unit = {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
plan.foreachUp {
case operator: LogicalPlan =>
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
+ if (operator.childrenResolved) {
+ // Throw errors for specific problems with get field.
+ operator.resolveChildren(a.name, resolver, throwErrors = true)
+ }
+
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
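The new call into `resolveChildren` with `throwErrors = true` sets up a two-phase error-reporting pattern: during iterative analysis, resolution failures on nested fields are swallowed (resolution returns None so other rules can still fire), and only this final check re-runs resolution in throwing mode, so the user sees the specific GetField error rather than the generic "cannot resolve" message. A minimal sketch of the pattern (illustrative names, not the Catalyst API):

    // Resolution is forgiving while rules iterate; the final check is not.
    def resolveField(name: String, fields: Set[String], throwErrors: Boolean = false): Option[String] =
      try {
        if (fields.contains(name)) Some(name)
        else throw new IllegalArgumentException(s"No such struct field $name")
      } catch {
        // During analysis, swallow the failure and let another rule try.
        case e: IllegalArgumentException if !throwErrors => None
      }

    def checkAnalysis(name: String, fields: Set[String]): Unit =
      if (resolveField(name, fields).isEmpty) {
        resolveField(name, fields, throwErrors = true) // surfaces the specific error
        sys.error(s"cannot resolve '$name'")           // generic fallback
      }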
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 11b4eb5c88..5345696570 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -34,7 +34,7 @@ object AttributeSet {
def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))
/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
- def apply(baseSet: Seq[Expression]): AttributeSet = {
+ def apply(baseSet: Iterable[Expression]): AttributeSet = {
new AttributeSet(
baseSet
.flatMap(_.references)
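The `Seq` to `Iterable` widening matters because the new `resolveAndFindMissing` above builds its attribute set from a Map: `resolved.values` is an `Iterable[NamedExpression]`, not a `Seq`, so under the old signature that call site would have needed an extra `.toSeq`. A trivial illustration (stand-in element type):

    val resolved = Map("a" -> 1, "b" -> 2)
    val vs: Iterable[Int] = resolved.values // fine
    // val sq: Seq[Int] = resolved.values   // would not compile without .toSeq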
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index b01a61d7bf..2e9f3aa4ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types.{ArrayType, StructType, StructField}
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
@@ -109,16 +110,22 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* nodes of this LogicalPlan. The attribute is expressed as a
* string in the following form: `[scope].AttributeName.[nested].[fields]...`.
*/
- def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] =
- resolve(name, children.flatMap(_.output), resolver)
+ def resolveChildren(
+ name: String,
+ resolver: Resolver,
+ throwErrors: Boolean = false): Option[NamedExpression] =
+ resolve(name, children.flatMap(_.output), resolver, throwErrors)
/**
* Optionally resolves the given string to a [[NamedExpression]] based on the output of this
* LogicalPlan. The attribute is expressed as a string in the following form:
* `[scope].AttributeName.[nested].[fields]...`.
*/
- def resolve(name: String, resolver: Resolver): Option[NamedExpression] =
- resolve(name, output, resolver)
+ def resolve(
+ name: String,
+ resolver: Resolver,
+ throwErrors: Boolean = false): Option[NamedExpression] =
+ resolve(name, output, resolver, throwErrors)
/**
* Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
@@ -162,7 +169,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
protected def resolve(
name: String,
input: Seq[Attribute],
- resolver: Resolver): Option[NamedExpression] = {
+ resolver: Resolver,
+ throwErrors: Boolean): Option[NamedExpression] = {
val parts = name.split("\\.")
@@ -196,14 +204,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
- // The foldLeft adds UnresolvedGetField for every remaining parts of the name,
- // and aliased it with the last part of the name.
- // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias
- // the final expression as "c".
- val fieldExprs = nestedFields.foldLeft(a: Expression)(UnresolvedGetField)
- val aliasName = nestedFields.last
- Some(Alias(fieldExprs, aliasName)())
+ try {
+
+ // The foldLeft applies resolveGetField to every remaining part of the name,
+ // and aliases the result with the last part of the name.
+ // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
+ // Then this will resolve the field "b" and then "c", and alias
+ // the final expression as "c".
+ val fieldExprs = nestedFields.foldLeft(a: Expression)(resolveGetField(_, _, resolver))
+ val aliasName = nestedFields.last
+ Some(Alias(fieldExprs, aliasName)())
+ } catch {
+ case a: AnalysisException if !throwErrors => None
+ }
// No matches.
case Seq() =>
@@ -212,11 +225,46 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// More than one match.
case ambiguousReferences =>
- val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ")
+ val referenceNames = ambiguousReferences.map(_._1).mkString(", ")
throw new AnalysisException(
s"Reference '$name' is ambiguous, could be: $referenceNames.")
}
}
+
+ /**
+ * Returns the resolved `GetField`, and reports an error if the desired field is not
+ * found or if more than one matching field is found.
+ *
+ * TODO: this code is duplicated from Analyzer and should be refactored to avoid this.
+ */
+ protected def resolveGetField(
+ expr: Expression,
+ fieldName: String,
+ resolver: Resolver): Expression = {
+ def findField(fields: Array[StructField]): Int = {
+ val checkField = (f: StructField) => resolver(f.name, fieldName)
+ val ordinal = fields.indexWhere(checkField)
+ if (ordinal == -1) {
+ throw new AnalysisException(
+ s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+ } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+ throw new AnalysisException(
+ s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+ } else {
+ ordinal
+ }
+ }
+ expr.dataType match {
+ case StructType(fields) =>
+ val ordinal = findField(fields)
+ StructGetField(expr, fields(ordinal), ordinal)
+ case ArrayType(StructType(fields), containsNull) =>
+ val ordinal = findField(fields)
+ ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
+ case otherType =>
+ throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
+ }
+ }
}
/**
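The disambiguation rules in `findField` are easy to exercise in isolation; here is a standalone sketch with a toy field type (not Spark's StructField) and a case-insensitive resolver:

    // Toy stand-ins to exercise findField's three outcomes.
    case class Field(name: String)
    val caseInsensitive: (String, String) => Boolean = _ equalsIgnoreCase _

    def findField(fields: Array[Field], fieldName: String,
        resolver: (String, String) => Boolean): Int = {
      val checkField = (f: Field) => resolver(f.name, fieldName)
      val ordinal = fields.indexWhere(checkField)
      if (ordinal == -1)
        sys.error(s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
      else if (fields.indexWhere(checkField, ordinal + 1) != -1)
        sys.error(s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
      else ordinal
    }

    // findField(Array(Field("a"), Field("b")), "B", caseInsensitive) == 1
    // findField(Array(Field("b"), Field("B")), "b", caseInsensitive) // throws: ambiguous
    // findField(Array(Field("a")), "c", caseInsensitive)             // throws: no such field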