path: root/sql
author     Michael Armbrust <michael@databricks.com>   2015-03-31 11:23:18 -0700
committer  Michael Armbrust <michael@databricks.com>   2015-03-31 11:23:18 -0700
commit     cd48ca50129e8952f487051796244e7569275416 (patch)
tree       86e5163226596544e711fb253e0ce76d7aeeacc4 /sql
parent     81020144708773ba3af4932288ffa09ef901269e (diff)
[SPARK-6145][SQL] fix ORDER BY on nested fields
This PR is based on work by cloud-fan in #4904, but with two differences:
- We isolate the logic for Sort's special handling into `ResolveSortReferences`.
- We avoid creating UnresolvedGetField expressions during resolution. Instead we either resolve GetField or we return None. This avoids us going down the wrong path early on.

Author: Michael Armbrust <michael@databricks.com>

Closes #5189 from marmbrus/nestedOrderBy and squashes the following commits:

b8cae45 [Michael Armbrust] fix another test
0f36a11 [Michael Armbrust] WIP
91820cd [Michael Armbrust] Fix bug.
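For reference, queries of the following shape (drawn from the new tests added to SQLQuerySuite in this diff) now analyze and return correct results; they exercise ORDER BY on nested struct fields and on fields of array elements:

    sql("SELECT a.b FROM nestedOrder ORDER BY a.b")         // nested struct field
    sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a")          // deeper nesting
    sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d")    // field of an array element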
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala        | 76
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala   | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 76
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala   | 39
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                            | 14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                         | 19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala                |  4
8 files changed, 185 insertions, 57 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index dc14f49e6e..c578d084a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -37,11 +37,12 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true
* [[UnresolvedRelation]]s into fully typed objects using information in a schema [[Catalog]] and
* a [[FunctionRegistry]].
*/
-class Analyzer(catalog: Catalog,
- registry: FunctionRegistry,
- caseSensitive: Boolean,
- maxIterations: Int = 100)
- extends RuleExecutor[LogicalPlan] with HiveTypeCoercion {
+class Analyzer(
+ catalog: Catalog,
+ registry: FunctionRegistry,
+ caseSensitive: Boolean,
+ maxIterations: Int = 100)
+ extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis {
val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution
@@ -354,19 +355,16 @@ class Analyzer(catalog: Catalog,
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case s @ Sort(ordering, global, p @ Project(projectList, child))
if !s.resolved && p.resolved =>
- val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
- val resolved = unresolved.flatMap(child.resolve(_, resolver))
- val requiredAttributes =
- AttributeSet(resolved.flatMap(_.collect { case a: Attribute => a }))
+ val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, p, child)
- val missingInProject = requiredAttributes -- p.output
- if (missingInProject.nonEmpty) {
+ // If this rule was not a no-op, return the transformed plan, otherwise return the original.
+ if (missing.nonEmpty) {
// Add missing attributes and then project them away after the sort.
- Project(projectList.map(_.toAttribute),
- Sort(ordering, global,
- Project(projectList ++ missingInProject, child)))
+ Project(p.output,
+ Sort(resolvedOrdering, global,
+ Project(projectList ++ missing, child)))
} else {
- logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}")
+ logDebug(s"Failed to find $missing in ${p.output.mkString(", ")}")
s // Nothing we can do here. Return original plan.
}
case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child))
@@ -378,18 +376,54 @@ class Analyzer(catalog: Catalog,
grouping.collect { case ne: NamedExpression => ne.toAttribute }
)
- logDebug(s"Grouping expressions: $groupingRelation")
- val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver))
- val missingInAggs = resolved.filterNot(a.outputSet.contains)
- logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
- if (missingInAggs.nonEmpty) {
+ val (resolvedOrdering, missing) = resolveAndFindMissing(ordering, a, groupingRelation)
+
+ if (missing.nonEmpty) {
// Add missing grouping exprs and then project them away after the sort.
Project(a.output,
- Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child)))
+ Sort(resolvedOrdering, global,
+ Aggregate(grouping, aggs ++ missing, child)))
} else {
s // Nothing we can do here. Return original plan.
}
}
+
+ /**
+ * Given a child and a grandchild that are present beneath a sort operator, returns
+ * a resolved sort ordering and a list of attributes that are missing from the child
+ * but are present in the grandchild.
+ */
+ def resolveAndFindMissing(
+ ordering: Seq[SortOrder],
+ child: LogicalPlan,
+ grandchild: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
+ // Find any attributes that remain unresolved in the sort.
+ val unresolved: Seq[String] =
+ ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+
+ // Create a map from name to resolved attribute, for each name that can be found
+ // prior to the projection.
+ val resolved: Map[String, NamedExpression] =
+ unresolved.flatMap(u => grandchild.resolve(u, resolver).map(a => u -> a)).toMap
+
+ // Construct a set that contains all of the attributes that we need to evaluate the
+ // ordering.
+ val requiredAttributes = AttributeSet(resolved.values)
+
+ // Figure out which ones are missing from the projection, so that we can add them and
+ // remove them after the sort.
+ val missingInProject = requiredAttributes -- child.output
+
+ // Now that we have all the attributes we need, reconstruct a resolved ordering.
+ // It is important to do this here, instead of waiting for standard resolution, as adding
+ // attributes to the project below can actually introduce ambiguity that was not present
+ // before.
+ val resolvedOrdering = ordering.map(_ transform {
+ case u @ UnresolvedAttribute(name) => resolved.getOrElse(name, u)
+ }).asInstanceOf[Seq[SortOrder]]
+
+ (resolvedOrdering, missingInProject.toSeq)
+ }
}
/**
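As a rough sketch of the rewrite performed by `ResolveSortReferences` above (hypothetical table `t` with columns `a` and `b`; the rule itself is the authoritative code), for a query like SELECT b FROM t ORDER BY a the unresolved plan

    // Sort('a.asc, global = true,
    //   Project(Seq('b), t))

is rewritten to

    // Project(Seq(b),                         // project the borrowed attribute away again
    //   Sort(a.asc, global = true,
    //     Project(Seq(b, a), t)))             // temporarily add the missing attribute a

so the ordering is resolved against the grandchild, evaluated by the sort, and the extra attribute is dropped afterwards.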
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 40472a1cbb..fa02111385 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.types._
/**
* Throws user facing errors when passed invalid queries that fail to analyze.
*/
-class CheckAnalysis {
+trait CheckAnalysis {
+ self: Analyzer =>
/**
* Override to provide additional checks for correct analysis.
@@ -33,17 +34,22 @@ class CheckAnalysis {
*/
val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
- def failAnalysis(msg: String): Nothing = {
+ protected def failAnalysis(msg: String): Nothing = {
throw new AnalysisException(msg)
}
- def apply(plan: LogicalPlan): Unit = {
+ def checkAnalysis(plan: LogicalPlan): Unit = {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
plan.foreachUp {
case operator: LogicalPlan =>
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
+ if (operator.childrenResolved) {
+ // Throw errors for specific problems with get field.
+ operator.resolveChildren(a.name, resolver, throwErrors = true)
+ }
+
val from = operator.inputSet.map(_.name).mkString(", ")
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
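The practical effect, sketched against the nested relations added to AnalysisSuite below, is that a bad nested-field reference now surfaces the specific get-field error from `resolveChildren` instead of falling through to the generic "cannot resolve" message above:

    // nestedRelation.select($"top.duplicateField")
    //   => AnalysisException: Ambiguous reference to fields ... duplicateField ...
    // nestedRelation2.select($"top.c")
    //   => AnalysisException: No such struct field c in aField, bField, cField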
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 11b4eb5c88..5345696570 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -34,7 +34,7 @@ object AttributeSet {
def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a)))
/** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */
- def apply(baseSet: Seq[Expression]): AttributeSet = {
+ def apply(baseSet: Iterable[Expression]): AttributeSet = {
new AttributeSet(
baseSet
.flatMap(_.references)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index b01a61d7bf..2e9f3aa4ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types.{ArrayType, StructType, StructField}
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
@@ -109,16 +110,22 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
* nodes of this LogicalPlan. The attribute is expressed as
* a string in the following form: `[scope].AttributeName.[nested].[fields]...`.
*/
- def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] =
- resolve(name, children.flatMap(_.output), resolver)
+ def resolveChildren(
+ name: String,
+ resolver: Resolver,
+ throwErrors: Boolean = false): Option[NamedExpression] =
+ resolve(name, children.flatMap(_.output), resolver, throwErrors)
/**
* Optionally resolves the given string to a [[NamedExpression]] based on the output of this
* LogicalPlan. The attribute is expressed as a string in the following form:
* `[scope].AttributeName.[nested].[fields]...`.
*/
- def resolve(name: String, resolver: Resolver): Option[NamedExpression] =
- resolve(name, output, resolver)
+ def resolve(
+ name: String,
+ resolver: Resolver,
+ throwErrors: Boolean = false): Option[NamedExpression] =
+ resolve(name, output, resolver, throwErrors)
/**
* Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
@@ -162,7 +169,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
protected def resolve(
name: String,
input: Seq[Attribute],
- resolver: Resolver): Option[NamedExpression] = {
+ resolver: Resolver,
+ throwErrors: Boolean): Option[NamedExpression] = {
val parts = name.split("\\.")
@@ -196,14 +204,19 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
- // The foldLeft adds UnresolvedGetField for every remaining parts of the name,
- // and aliased it with the last part of the name.
- // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias
- // the final expression as "c".
- val fieldExprs = nestedFields.foldLeft(a: Expression)(UnresolvedGetField)
- val aliasName = nestedFields.last
- Some(Alias(fieldExprs, aliasName)())
+ try {
+
+ // The foldLeft resolves a GetField for every remaining part of the name,
+ // and aliases the final expression with the last part of the name.
+ // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
+ // Then this will resolve the fields "b" and "c" in turn, and alias
+ // the final expression as "c".
+ val fieldExprs = nestedFields.foldLeft(a: Expression)(resolveGetField(_, _, resolver))
+ val aliasName = nestedFields.last
+ Some(Alias(fieldExprs, aliasName)())
+ } catch {
+ case a: AnalysisException if !throwErrors => None
+ }
// No matches.
case Seq() =>
@@ -212,11 +225,46 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// More than one match.
case ambiguousReferences =>
- val referenceNames = ambiguousReferences.map(_._1.qualifiedName).mkString(", ")
+ val referenceNames = ambiguousReferences.map(_._1).mkString(", ")
throw new AnalysisException(
s"Reference '$name' is ambiguous, could be: $referenceNames.")
}
}
+
+ /**
+ * Returns the resolved `GetField`, and reports an error if the desired field is not
+ * found or more than one matching field is found.
+ *
+ * TODO: this code is duplicated from Analyzer and should be refactored to avoid this.
+ */
+ protected def resolveGetField(
+ expr: Expression,
+ fieldName: String,
+ resolver: Resolver): Expression = {
+ def findField(fields: Array[StructField]): Int = {
+ val checkField = (f: StructField) => resolver(f.name, fieldName)
+ val ordinal = fields.indexWhere(checkField)
+ if (ordinal == -1) {
+ throw new AnalysisException(
+ s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+ } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+ throw new AnalysisException(
+ s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+ } else {
+ ordinal
+ }
+ }
+ expr.dataType match {
+ case StructType(fields) =>
+ val ordinal = findField(fields)
+ StructGetField(expr, fields(ordinal), ordinal)
+ case ArrayType(StructType(fields), containsNull) =>
+ val ordinal = findField(fields)
+ ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
+ case otherType =>
+ throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
+ }
+ }
}
/**
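For intuition, here is a minimal standalone sketch of the lookup rule that `resolveGetField` applies (simplified: plain strings stand in for StructFields and a boolean function stands in for Resolver; this is not the Catalyst code itself):

    // Simplified sketch of the findField helper; mirrors its no-match / ambiguous-match errors.
    def findFieldOrdinal(
        fields: Seq[String],
        fieldName: String,
        resolver: (String, String) => Boolean): Int = {
      // Collect every field whose name matches under the given resolver.
      val matches = fields.zipWithIndex.filter { case (f, _) => resolver(f, fieldName) }
      matches match {
        case Seq((_, ordinal)) => ordinal  // exactly one match: use its position
        case Seq() => throw new IllegalArgumentException(
          s"No such struct field $fieldName in ${fields.mkString(", ")}")
        case _ => throw new IllegalArgumentException(
          s"Ambiguous reference to fields ${matches.map(_._1).mkString(", ")}")
      }
    }

    // With a case-insensitive resolver, "differentCase" matches two fields and is rejected,
    // mirroring the "ambiguous field due to case insensitivity" test added below.
    val caseInsensitive = (a: String, b: String) => a.equalsIgnoreCase(b)
    findFieldOrdinal(Seq("differentCase", "differentcase"), "differentCase", caseInsensitive)  // throws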
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 756cd36f05..ee7b14c7a1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -40,14 +40,12 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
override val extendedResolutionRules = EliminateSubQueries :: Nil
}
- val checkAnalysis = new CheckAnalysis
-
def caseSensitiveAnalyze(plan: LogicalPlan) =
- checkAnalysis(caseSensitiveAnalyzer(plan))
+ caseSensitiveAnalyzer.checkAnalysis(caseSensitiveAnalyzer(plan))
def caseInsensitiveAnalyze(plan: LogicalPlan) =
- checkAnalysis(caseInsensitiveAnalyzer(plan))
+ caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalyzer(plan))
val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
val testRelation2 = LocalRelation(
@@ -57,6 +55,21 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())
+ val nestedRelation = LocalRelation(
+ AttributeReference("top", StructType(
+ StructField("duplicateField", StringType) ::
+ StructField("duplicateField", StringType) ::
+ StructField("differentCase", StringType) ::
+ StructField("differentcase", StringType) :: Nil
+ ))())
+
+ val nestedRelation2 = LocalRelation(
+ AttributeReference("top", StructType(
+ StructField("aField", StringType) ::
+ StructField("bField", StringType) ::
+ StructField("cField", StringType) :: Nil
+ ))())
+
before {
caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
@@ -169,6 +182,24 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
"'b'" :: "group by" :: Nil
)
+ errorTest(
+ "ambiguous field",
+ nestedRelation.select($"top.duplicateField"),
+ "Ambiguous reference to fields" :: "duplicateField" :: Nil,
+ caseSensitive = false)
+
+ errorTest(
+ "ambiguous field due to case insensitivity",
+ nestedRelation.select($"top.differentCase"),
+ "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil,
+ caseSensitive = false)
+
+ errorTest(
+ "missing field",
+ nestedRelation2.select($"top.c"),
+ "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil,
+ caseSensitive = false)
+
case class UnresolvedTestPlan() extends LeafNode {
override lazy val resolved = false
override def output = Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index b8100782ec..1794936a52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -120,6 +120,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
ExtractPythonUdfs ::
sources.PreInsertCastAndRename ::
Nil
+
+ override val extendedCheckRules = Seq(
+ sources.PreWriteCheck(catalog)
+ )
}
@transient
@@ -1065,14 +1069,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
Batch("Add exchange", Once, AddExchange(self)) :: Nil
}
- @transient
- protected[sql] lazy val checkAnalysis = new CheckAnalysis {
- override val extendedCheckRules = Seq(
- sources.PreWriteCheck(catalog)
- )
- }
-
-
protected[sql] def openSession(): SQLSession = {
detachSession()
val session = createSession()
@@ -1105,7 +1101,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@DeveloperApi
protected[sql] class QueryExecution(val logical: LogicalPlan) {
- def assertAnalyzed(): Unit = checkAnalysis(analyzed)
+ def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed)
lazy val analyzed: LogicalPlan = analyzer(logical)
lazy val withCachedData: LogicalPlan = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index a3c0076e16..87e7cf8c8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1084,10 +1084,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("SPARK-6145: ORDER BY test for nested fields") {
jsonRDD(sparkContext.makeRDD(
"""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)).registerTempTable("nestedOrder")
- // These should be successfully analyzed
- sql("SELECT 1 FROM nestedOrder ORDER BY a.b").queryExecution.analyzed
- sql("SELECT a.b FROM nestedOrder ORDER BY a.b").queryExecution.analyzed
- sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a").queryExecution.analyzed
- sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d").queryExecution.analyzed
+
+ checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1))
+ checkAnswer(sql("SELECT a.b FROM nestedOrder ORDER BY a.b"), Row(1))
+ checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.a.a"), Row(1))
+ checkAnswer(sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1))
+ checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY c[0].d"), Row(1))
+ checkAnswer(sql("SELECT c[0].d FROM nestedOrder ORDER BY c[0].d"), Row(1))
+ }
+
+ test("SPARK-6145: special cases") {
+ jsonRDD(sparkContext.makeRDD(
+ """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
+ checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
+ checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 91c6367371..33c6735596 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -32,6 +32,10 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
override val extendedResolutionRules =
PreInsertCastAndRename ::
Nil
+
+ override val extendedCheckRules = Seq(
+ sources.PreWriteCheck(catalog)
+ )
}
}
}