diff options
6 files changed, 294 insertions, 9 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 902e18081b..567010f23f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -66,6 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { PushPredicateThroughProject, PushPredicateThroughGenerate, PushPredicateThroughAggregate, + // LimitPushDown, // Disabled until we have whole-stage codegen for limit ColumnPruning, // Operator combine CollapseRepartition, @@ -130,6 +131,69 @@ object EliminateSerialization extends Rule[LogicalPlan] { } /** + * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins. + */ +object LimitPushDown extends Rule[LogicalPlan] { + + private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = { + plan match { + case GlobalLimit(expr, child) => child + case _ => plan + } + } + + private def maybePushLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { + (limitExp, plan.maxRows) match { + case (IntegerLiteral(maxRow), Some(childMaxRows)) if maxRow < childMaxRows => + LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) + case (_, None) => + LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) + case _ => plan + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Adding extra Limits below UNION ALL for children which are not Limit or do not have Limit + // descendants whose maxRow is larger. This heuristic is valid assuming there does not exist any + // Limit push-down rule that is unable to infer the value of maxRows. + // Note: right now Union means UNION ALL, which does not de-duplicate rows, so it is safe to + // pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to + // pushdown Limit. + case LocalLimit(exp, Union(children)) => + LocalLimit(exp, Union(children.map(maybePushLimit(exp, _)))) + // Add extra limits below OUTER JOIN. For LEFT OUTER and FULL OUTER JOIN we push limits to the + // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side + // because we need to ensure that rows from the limited side still have an opportunity to match + // against all candidates from the non-limited side. We also need to ensure that this limit + // pushdown rule will not eventually introduce limits on both sides if it is applied multiple + // times. Therefore: + // - If one side is already limited, stack another limit on top if the new limit is smaller. + // The redundant limit will be collapsed by the CombineLimits rule. + // - If neither side is limited, limit the side that is estimated to be bigger. + case LocalLimit(exp, join @ Join(left, right, joinType, condition)) => + val newJoin = joinType match { + case RightOuter => join.copy(right = maybePushLimit(exp, right)) + case LeftOuter => join.copy(left = maybePushLimit(exp, left)) + case FullOuter => + (left.maxRows, right.maxRows) match { + case (None, None) => + if (left.statistics.sizeInBytes >= right.statistics.sizeInBytes) { + join.copy(left = maybePushLimit(exp, left)) + } else { + join.copy(right = maybePushLimit(exp, right)) + } + case (Some(_), Some(_)) => join + case (Some(_), None) => join.copy(left = maybePushLimit(exp, left)) + case (None, Some(_)) => join.copy(right = maybePushLimit(exp, right)) + + } + case _ => join + } + LocalLimit(exp, newJoin) + } +} + +/** * Pushes certain operations to both sides of a Union or Except operator. * Operations that are safe to pushdown are listed as follows. * Union: @@ -985,8 +1049,12 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] { */ object CombineLimits extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ll @ GlobalLimit(le, nl @ GlobalLimit(ne, grandChild)) => + GlobalLimit(Least(Seq(ne, le)), grandChild) + case ll @ LocalLimit(le, nl @ LocalLimit(ne, grandChild)) => + LocalLimit(Least(Seq(ne, le)), grandChild) case ll @ Limit(le, nl @ Limit(ne, grandChild)) => - Limit(If(LessThan(ne, le), ne, le), grandChild) + Limit(Least(Seq(ne, le)), grandChild) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 18b7bde906..35e0f5d563 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -91,6 +91,14 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } /** + * Returns the maximum number of rows that this plan may compute. + * + * Any operator that a Limit can be pushed passed should override this function (e.g., Union). + * Any operator that can push through a Limit should override this function (e.g., Project). + */ + def maxRows: Option[Long] = None + + /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan * can override this (e.g. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e8e0a78904..502d898fea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -38,6 +38,7 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) + override def maxRows: Option[Long] = child.maxRows override lazy val resolved: Boolean = { val hasSpecialExpressions = projectList.exists ( _.collect { @@ -56,6 +57,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. @@ -102,6 +104,8 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = child.maxRows + override protected def validConstraints: Set[Expression] = { child.constraints.union(splitConjunctivePredicates(condition).toSet) } @@ -144,6 +148,14 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation left.output.length == right.output.length && left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && duplicateResolved + + override def maxRows: Option[Long] = { + if (children.exists(_.maxRows.isEmpty)) { + None + } else { + Some(children.flatMap(_.maxRows).min) + } + } } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { @@ -166,6 +178,13 @@ object Union { } case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { + override def maxRows: Option[Long] = { + if (children.exists(_.maxRows.isEmpty)) { + None + } else { + Some(children.flatMap(_.maxRows).sum) + } + } // updating nullability to make all the children consistent override def output: Seq[Attribute] = @@ -305,6 +324,7 @@ case class InsertIntoTable( /** * A container for holding named common table expressions (CTEs) and a query plan. * This operator will be removed during analysis and the relations will be substituted into child. + * * @param child The final query of this CTE. * @param cteRelations Queries that this CTE defined, * key is the alias of the CTE definition, @@ -331,6 +351,7 @@ case class Sort( global: Boolean, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = child.maxRows } /** Factory for constructing new `Range` nodes. */ @@ -384,6 +405,7 @@ case class Aggregate( } override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) + override def maxRows: Option[Long] = child.maxRows } case class Window( @@ -505,6 +527,7 @@ trait GroupingAnalytics extends UnaryNode { * to generated by a UNION ALL of multiple simple GROUP BY clauses. * * We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer + * * @param bitmasks A list of bitmasks, each of the bitmask indicates the selected * GroupBy expressions * @param groupByExprs The Group By expressions candidates, take effective only if the @@ -537,9 +560,42 @@ case class Pivot( } } -case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +object Limit { + def apply(limitExpr: Expression, child: LogicalPlan): UnaryNode = { + GlobalLimit(limitExpr, LocalLimit(limitExpr, child)) + } + + def unapply(p: GlobalLimit): Option[(Expression, LogicalPlan)] = { + p match { + case GlobalLimit(le1, LocalLimit(le2, child)) if le1 == le2 => Some((le1, child)) + case _ => None + } + } +} + +case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = { + limitExpr match { + case IntegerLiteral(limit) => Some(limit) + case _ => None + } + } + override lazy val statistics: Statistics = { + val limit = limitExpr.eval().asInstanceOf[Int] + val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum + Statistics(sizeInBytes = sizeInBytes) + } +} +case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def maxRows: Option[Long] = { + limitExpr match { + case IntegerLiteral(limit) => Some(limit) + case _ => None + } + } override lazy val statistics: Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum @@ -576,6 +632,7 @@ case class Sample( * Returns a new logical plan that dedups input rows. */ case class Distinct(child: LogicalPlan) extends UnaryNode { + override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output } @@ -594,6 +651,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) * A relation with one row. This is used in "SELECT ..." without a from clause. */ case object OneRowRelation extends LeafNode { + override def maxRows: Option[Long] = Some(1) override def output: Seq[Attribute] = Nil /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala new file mode 100644 index 0000000000..fc1e994581 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.Add +import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, PlanTest, RightOuter} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class LimitPushdownSuite extends PlanTest { + + private object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Limit pushdown", FixedPoint(100), + LimitPushDown, + CombineLimits, + ConstantFolding, + BooleanSimplification) :: Nil + } + + private val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + private val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + private val x = testRelation.subquery('x) + private val y = testRelation.subquery('y) + + // Union --------------------------------------------------------------------------------------- + + test("Union: limit to each side") { + val unionQuery = Union(testRelation, testRelation2).limit(1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(1, Union(LocalLimit(1, testRelation), LocalLimit(1, testRelation2))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each side with constant-foldable limit expressions") { + val unionQuery = Union(testRelation, testRelation2).limit(Add(1, 1)) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each side with the new limit number") { + val unionQuery = Union(testRelation, testRelation2.limit(3)).limit(1) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(1, Union(LocalLimit(1, testRelation), LocalLimit(1, testRelation2))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: no limit to both sides if children having smaller limit values") { + val unionQuery = Union(testRelation.limit(1), testRelation2.select('d).limit(1)).limit(2) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(testRelation.limit(1), testRelation2.select('d).limit(1))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + test("Union: limit to each sides if children having larger limit values") { + val testLimitUnion = Union(testRelation.limit(3), testRelation2.select('d).limit(4)) + val unionQuery = testLimitUnion.limit(2) + val unionOptimized = Optimize.execute(unionQuery.analyze) + val unionCorrectAnswer = + Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d)))).analyze + comparePlans(unionOptimized, unionCorrectAnswer) + } + + // Outer join ---------------------------------------------------------------------------------- + + test("left outer join") { + val originalQuery = x.join(y, LeftOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, y).join(y, LeftOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("right outer join") { + val originalQuery = x.join(y, RightOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("larger limits are not pushed on top of smaller ones in right outer join") { + val originalQuery = x.join(y.limit(5), RightOuter).limit(10) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(10, x.join(Limit(5, y), RightOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where neither side is limited and both sides have same statistics") { + assert(x.statistics.sizeInBytes === y.statistics.sizeInBytes) + val originalQuery = x.join(y, FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where neither side is limited and left side has larger statistics") { + val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x) + assert(xBig.statistics.sizeInBytes > y.statistics.sizeInBytes) + val originalQuery = xBig.join(y, FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where neither side is limited and right side has larger statistics") { + val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y) + assert(x.statistics.sizeInBytes < yBig.statistics.sizeInBytes) + val originalQuery = x.join(yBig, FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } + + test("full outer join where both sides are limited") { + val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1) + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = Limit(1, Limit(2, x).join(Limit(2, y), FullOuter)).analyze + comparePlans(optimized, correctAnswer) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 73fd22b38e..042c99db4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -351,10 +351,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => LocalTableScan(output, data) :: Nil - case logical.Limit(IntegerLiteral(limit), child) => - val perPartitionLimit = execution.LocalLimit(limit, planLater(child)) - val globalLimit = execution.GlobalLimit(limit, perPartitionLimit) - globalLimit :: Nil + case logical.LocalLimit(IntegerLiteral(limit), child) => + execution.LocalLimit(limit, planLater(child)) :: Nil + case logical.GlobalLimit(IntegerLiteral(limit), child) => + execution.GlobalLimit(limit, planLater(child)) :: Nil case logical.Union(unionChildren) => execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index d7bae913f8..bf5edb4759 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -77,8 +77,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case p: Aggregate => aggregateToSQL(p) - case p: Limit => - s"${toSQL(p.child)} LIMIT ${p.limitExpr.sql}" + case Limit(limitExpr, child) => + s"${toSQL(child)} LIMIT ${limitExpr.sql}" case p: Filter => val whereOrHaving = p.child match { @@ -203,7 +203,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi wrapChildWithSubquery(plan) case plan @ Project(_, - _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit + _: Subquery + | _: Filter + | _: Join + | _: MetastoreRelation + | OneRowRelation + | _: LocalLimit + | _: GlobalLimit ) => plan case plan: Project => |