author     Michael Armbrust <michael@databricks.com>  2014-08-05 20:55:02 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-08-05 20:55:02 -0700
commit     1d70c4f66d3c688bd6750b344dff422d1c88cc22 (patch)
tree       75d11b6f3853ef5d03d6153d2dea12ad26e14d11 /sql
parent     69ec678d3aaeb6ece85e5e82353bf083bfc83667 (diff)
[SPARK-2866][SQL] Support attributes in ORDER BY that aren't in SELECT
Minor refactoring to allow resolution either using a node's input or output.

Author: Michael Armbrust <michael@databricks.com>

Closes #1795 from marmbrus/ordering and squashes the following commits:

237f580 [Michael Armbrust] style
74d833b [Michael Armbrust] newline
705d963 [Michael Armbrust] Add a rule for resolving ORDER BY expressions that reference attributes not present in the SELECT clause.
82cabda [Michael Armbrust] Generalize attribute resolution.
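Illustration (not part of the commit message): the kind of query this change makes resolvable, next to the equivalent formulation the new test suite below compares it against. This assumes a TestHive session with Hive's src table, exactly as the tests use it.

    import org.apache.spark.sql.hive.test.TestHive._

    // Newly supported: `value` is used only for ordering and is not selected.
    sql("SELECT key FROM src ORDER BY value")

    // The formulation the test compares against: keep the ordering column in a
    // subquery, then drop it in the outer SELECT.
    sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a")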
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala        | 48
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 25
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala           | 50
3 files changed, 116 insertions(+), 7 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2ba68cab11..0293d578b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -48,6 +48,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
Batch("Resolution", fixedPoint,
ResolveReferences ::
ResolveRelations ::
+ ResolveSortReferences ::
NewRelationInstances ::
ImplicitGenerate ::
StarExpansion ::
@@ -113,7 +114,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
q transformExpressions {
case u @ UnresolvedAttribute(name) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
- val result = q.resolve(name).getOrElse(u)
+ val result = q.resolveChildren(name).getOrElse(u)
logDebug(s"Resolving $u to $result")
result
}
@@ -121,6 +122,51 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
}
/**
+ * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
+ * clause. This rule detects such queries and adds the required attributes to the original
+ * projection, so that they will be available during sorting. Another projection is added to
+ * remove these attributes after sorting.
+ */
+ object ResolveSortReferences extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
+ val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+ val resolved = unresolved.flatMap(child.resolveChildren)
+ val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet
+
+ val missingInProject = requiredAttributes -- p.output
+ if (missingInProject.nonEmpty) {
+ // Add missing attributes and then project them away after the sort.
+ Project(projectList,
+ Sort(ordering,
+ Project(projectList ++ missingInProject, child)))
+ } else {
+ s // Nothing we can do here. Return original plan.
+ }
+ case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
+ val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
+ // A small hack to create an object that will allow us to resolve any references that
+ // refer to named expressions that are present in the grouping expressions.
+ val groupingRelation = LocalRelation(
+ grouping.collect { case ne: NamedExpression => ne.toAttribute }
+ )
+
+ logWarning(s"Grouping expressions: $groupingRelation")
+ val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
+ val missingInAggs = resolved -- a.outputSet
+ logWarning(s"Resolved: $resolved Missing in aggs: $missingInAggs")
+ if (missingInAggs.nonEmpty) {
+ // Add missing grouping exprs and then project them away after the sort.
+ Project(a.output,
+ Sort(ordering,
+ Aggregate(grouping, aggs ++ missingInAggs, child)))
+ } else {
+ s // Nothing we can do here. Return original plan.
+ }
+ }
+ }
+
+ /**
* Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]].
*/
object ResolveFunctions extends Rule[LogicalPlan] {
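Illustration (not part of the patch): a minimal sketch of the rewrite ResolveSortReferences performs. Sort, Project, Aggregate, LocalRelation, and UnresolvedAttribute appear in the hunk above; AttributeReference, SortOrder, Ascending, and the data-type imports are assumptions about this branch of Catalyst, and the attribute names and types are hypothetical stand-ins for Hive's src table.

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
    import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, Project, Sort}
    import org.apache.spark.sql.catalyst.types.{IntegerType, StringType}

    val key      = AttributeReference("key", IntegerType)()
    val value    = AttributeReference("value", StringType)()
    val src      = LocalRelation(Seq(key, value))
    val ordering = Seq(SortOrder(UnresolvedAttribute("value"), Ascending))

    // SELECT key FROM src ORDER BY value: the sort cannot resolve `value` because
    // the projection below it only outputs `key`.
    val unresolvable = Sort(ordering, Project(Seq(key), src))

    // After the rule: `value` is added to the inner projection so the sort resolves,
    // and an outer Project restores the original output.
    val projectShape = Project(Seq(key), Sort(ordering, Project(Seq(key, value), src)))

    // The aggregate case is analogous: the missing grouping expression is added to
    // the aggregate output and projected away again after the sort.
    val aggShape = Project(Seq(key), Sort(ordering, Aggregate(Seq(key, value), Seq(key, value), src)))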
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 888cb08e95..278569f0cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -72,16 +72,29 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
def childrenResolved: Boolean = !children.exists(!_.resolved)
/**
- * Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as
+ * Optionally resolves the given string to a [[NamedExpression]] using the input from all child
+ * nodes of this LogicalPlan. The attribute is expressed as a string in the following form:
* `[scope].AttributeName.[nested].[fields]...`.
*/
- def resolve(name: String): Option[NamedExpression] = {
+ def resolveChildren(name: String): Option[NamedExpression] =
+ resolve(name, children.flatMap(_.output))
+
+ /**
+ * Optionally resolves the given string to a [[NamedExpression]] based on the output of this
+ * LogicalPlan. The attribute is expressed as a string in the following form:
+ * `[scope].AttributeName.[nested].[fields]...`.
+ */
+ def resolve(name: String): Option[NamedExpression] =
+ resolve(name, output)
+
+ /** Performs attribute resolution given a name and a sequence of possible attributes. */
+ protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = {
val parts = name.split("\\.")
// Collect all attributes that are output by this node's children where either the first part
// matches the name or where the first part matches the scope and the second part matches the
// name. Return these matches along with any remaining parts, which represent dotted access to
// struct fields.
- val options = children.flatMap(_.output).flatMap { option =>
+ val options = input.flatMap { option =>
// If the first part of the desired name matches a qualifier for this possible match, drop it.
val remainingParts =
if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
@@ -89,15 +102,15 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
}
options.distinct match {
- case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it.
+ case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
// One match, but we also need to extract the requested nested field.
- case (a, nestedFields) :: Nil =>
+ case Seq((a, nestedFields)) =>
a.dataType match {
case StructType(fields) =>
Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
case _ => None // Don't know how to resolve these field references
}
- case Nil => None // No matches.
+ case Seq() => None // No matches.
case ambiguousReferences =>
throw new TreeNodeException(
this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
new file mode 100644
index 0000000000..635a9fb0d5
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.{SQLConf, QueryTest}
+import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin}
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+
+/**
+ * A collection of hive query tests where we generate the answers ourselves instead of depending on
+ * Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is
+ * valid, but Hive currently cannot execute it.
+ */
+class SQLQuerySuite extends QueryTest {
+ test("ordering not in select") {
+ checkAnswer(
+ sql("SELECT key FROM src ORDER BY value"),
+ sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq)
+ }
+
+ test("ordering not in agg") {
+ checkAnswer(
+ sql("SELECT key FROM src GROUP BY key, value ORDER BY value"),
+ sql("""
+ SELECT key
+ FROM (
+ SELECT key, value
+ FROM src
+ GROUP BY key, value
+ ORDER BY value) a""").collect().toSeq)
+ }
+}
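A quick way to see the rule at work (not part of the patch) is to print the analyzed plans for the two test queries; they should show the extra Project that ResolveSortReferences inserts and then strips. This sketch assumes the SchemaRDD returned by sql() still exposes queryExecution as a developer API.

    import org.apache.spark.sql.hive.test.TestHive._

    // Sort on a column that the SELECT clause drops.
    println(sql("SELECT key FROM src ORDER BY value").queryExecution.analyzed)

    // Sort on a grouping expression that is not among the aggregate expressions.
    println(sql("SELECT key FROM src GROUP BY key, value ORDER BY value").queryExecution.analyzed)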