diff options
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala (renamed from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala) | 26 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 13 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala) | 24 |
4 files changed, 28 insertions, 37 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f540816366..cfab6ae7bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -84,7 +84,7 @@ class Analyzer( CTESubstitution, WindowsSubstitution, EliminateUnions, - new UnresolvedOrdinalSubstitution(conf)), + new SubstituteUnresolvedOrdinals(conf)), Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index e21cd08af8..6d8dc86282 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -18,32 +18,34 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} -import org.apache.spark.sql.catalyst.planning.IntegerIndex +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.types.IntegerType /** * Replaces ordinal in 'order by' or 'group by' with UnresolvedOrdinal expression. */ -class UnresolvedOrdinalSubstitution(conf: CatalystConf) extends Rule[LogicalPlan] { - private def isIntegerLiteral(sorter: Expression) = IntegerIndex.unapply(sorter).nonEmpty +class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] { + private def isIntLiteral(e: Expression) = e match { + case Literal(_, IntegerType) => true + case _ => false + } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case s @ Sort(orders, global, child) if conf.orderByOrdinal && - orders.exists(o => isIntegerLiteral(o.child)) => - val newOrders = orders.map { - case order @ SortOrder(ordinal @ IntegerIndex(index: Int), _) => + case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => + val newOrders = s.order.map { + case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _) => val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) withOrigin(order.origin)(order.copy(child = newOrdinal)) case other => other } withOrigin(s.origin)(s.copy(order = newOrders)) - case a @ Aggregate(groups, aggs, child) if conf.groupByOrdinal && - groups.exists(isIntegerLiteral(_)) => - val newGroups = groups.map { - case ordinal @ IntegerIndex(index) => + + case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) => + val newGroups = a.groupingExpressions.map { + case ordinal @ Literal(index: Int, IntegerType) => withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index f42e67ca6e..476c66af76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -209,19 +209,6 @@ object Unions { } /** - * Extractor for retrieving Int value. - */ -object IntegerIndex { - def unapply(a: Any): Option[Int] = a match { - case Literal(a: Int, IntegerType) => Some(a) - // When resolving ordinal in Sort and Group By, negative values are extracted - // for issuing error messages. - case UnaryMinus(IntegerLiteral(v)) => Some(-v) - case _ => None - } -} - -/** * An extractor used when planning the physical execution of an aggregation. Compared with a logical * aggregation, the following transformations are performed: * - Unnamed grouping expressions are named so that they can be referred to across phases of diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala index 23995e96e1..3c429ebce1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnresolvedOrdinalSubstitutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala @@ -23,20 +23,21 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.SimpleCatalystConf -class UnresolvedOrdinalSubstitutionSuite extends AnalysisTest { - - test("test rule UnresolvedOrdinalSubstitution, replaces ordinal in order by or group by") { - val a = testRelation2.output(0) - val b = testRelation2.output(1) - val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true) +class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { + private lazy val conf = SimpleCatalystConf(caseSensitiveAnalysis = true) + private lazy val a = testRelation2.output(0) + private lazy val b = testRelation2.output(1) + test("unresolved ordinal should not be unresolved") { // Expression OrderByOrdinal is unresolved. assert(!UnresolvedOrdinal(0).resolved) + } + test("order by ordinal") { // Tests order by ordinal, apply single rule. val plan = testRelation2.orderBy(Literal(1).asc, Literal(2).asc) comparePlans( - new UnresolvedOrdinalSubstitution(conf).apply(plan), + new SubstituteUnresolvedOrdinals(conf).apply(plan), testRelation2.orderBy(UnresolvedOrdinal(1).asc, UnresolvedOrdinal(2).asc)) // Tests order by ordinal, do full analysis @@ -44,14 +45,15 @@ class UnresolvedOrdinalSubstitutionSuite extends AnalysisTest { // order by ordinal can be turned off by config comparePlans( - new UnresolvedOrdinalSubstitution(conf.copy(orderByOrdinal = false)).apply(plan), + new SubstituteUnresolvedOrdinals(conf.copy(orderByOrdinal = false)).apply(plan), testRelation2.orderBy(Literal(1).asc, Literal(2).asc)) + } - + test("group by ordinal") { // Tests group by ordinal, apply single rule. val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b) comparePlans( - new UnresolvedOrdinalSubstitution(conf).apply(plan2), + new SubstituteUnresolvedOrdinals(conf).apply(plan2), testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b)) // Tests group by ordinal, do full analysis @@ -59,7 +61,7 @@ class UnresolvedOrdinalSubstitutionSuite extends AnalysisTest { // group by ordinal can be turned off by config comparePlans( - new UnresolvedOrdinalSubstitution(conf.copy(groupByOrdinal = false)).apply(plan2), + new SubstituteUnresolvedOrdinals(conf.copy(groupByOrdinal = false)).apply(plan2), testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) } } |