aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala/org
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/main/scala/org')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala23
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala52
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala18
3 files changed, 81 insertions, 12 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a2e276e8a2..a2a022c247 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,17 +22,16 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
-import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogRelation, InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
-import org.apache.spark.sql.catalyst.planning.IntegerIndex
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _}
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.trees.{TreeNodeRef}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.types._
@@ -84,7 +83,8 @@ class Analyzer(
Batch("Substitution", fixedPoint,
CTESubstitution,
WindowsSubstitution,
- EliminateUnions),
+ EliminateUnions,
+ new UnresolvedOrdinalSubstitution(conf)),
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
@@ -545,7 +545,7 @@ class Analyzer(
p.copy(projectList = buildExpandedProjectList(p.projectList, p.child))
// If the aggregate function argument contains Stars, expand it.
case a: Aggregate if containsStar(a.aggregateExpressions) =>
- if (conf.groupByOrdinal && a.groupingExpressions.exists(IntegerIndex.unapply(_).nonEmpty)) {
+ if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) {
failAnalysis(
"Star (*) is not allowed in select list when GROUP BY ordinal position is used")
} else {
@@ -716,9 +716,9 @@ class Analyzer(
// Replace the index with the related attribute for ORDER BY,
// which is a 1-base position of the projection list.
case s @ Sort(orders, global, child)
- if conf.orderByOrdinal && orders.exists(o => IntegerIndex.unapply(o.child).nonEmpty) =>
+ if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
val newOrders = orders map {
- case s @ SortOrder(IntegerIndex(index), direction) =>
+ case s @ SortOrder(UnresolvedOrdinal(index), direction) =>
if (index > 0 && index <= child.output.size) {
SortOrder(child.output(index - 1), direction)
} else {
@@ -732,11 +732,10 @@ class Analyzer(
// Replace the index with the corresponding expression in aggregateExpressions. The index is
// a 1-base position of aggregateExpressions, which is output columns (select expression)
- case a @ Aggregate(groups, aggs, child)
- if conf.groupByOrdinal && aggs.forall(_.resolved) &&
- groups.exists(IntegerIndex.unapply(_).nonEmpty) =>
+ case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) &&
+ groups.exists(_.isInstanceOf[UnresolvedOrdinal]) =>
val newGroups = groups.map {
- case ordinal @ IntegerIndex(index) if index > 0 && index <= aggs.size =>
+ case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size =>
aggs(index - 1) match {
case e if ResolveAggregateFunctions.containsAggregate(e) =>
ordinal.failAnalysis(
@@ -744,7 +743,7 @@ class Analyzer(
"aggregate functions are not allowed in GROUP BY")
case o => o
}
- case ordinal @ IntegerIndex(index) =>
+ case ordinal @ UnresolvedOrdinal(index) =>
ordinal.failAnalysis(
s"GROUP BY position $index is not in select list " +
s"(valid range is [1, ${aggs.size}])")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala
new file mode 100644
index 0000000000..e21cd08af8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder}
+import org.apache.spark.sql.catalyst.planning.IntegerIndex
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
+
+/**
+ * Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression.
+ */
+class UnresolvedOrdinalSubstitution(conf: CatalystConf) extends Rule[LogicalPlan] {
+ private def isIntegerLiteral(sorter: Expression) = IntegerIndex.unapply(sorter).nonEmpty
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case s @ Sort(orders, global, child) if conf.orderByOrdinal &&
+ orders.exists(o => isIntegerLiteral(o.child)) =>
+ val newOrders = orders.map {
+ case order @ SortOrder(ordinal @ IntegerIndex(index: Int), _) =>
+ val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
+ withOrigin(order.origin)(order.copy(child = newOrdinal))
+ case other => other
+ }
+ withOrigin(s.origin)(s.copy(order = newOrders))
+ case a @ Aggregate(groups, aggs, child) if conf.groupByOrdinal &&
+ groups.exists(isIntegerLiteral(_)) =>
+ val newGroups = groups.map {
+ case ordinal @ IntegerIndex(index) =>
+ withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
+ case other => other
+ }
+ withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 609089a302..42e7aae0b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -370,3 +370,21 @@ case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpr
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
}
+
+/**
+ * Represents unresolved ordinal used in order by or group by.
+ *
+ * For example:
+ * {{{
+ * select a from table order by 1
+ * select a from table group by 1
+ * }}}
+ * @param ordinal ordinal starts from 1, instead of 0
+ */
+case class UnresolvedOrdinal(ordinal: Int)
+ extends LeafExpression with Unevaluable with NonSQLExpression {
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override lazy val resolved = false
+}