-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala          |  70
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala    |   8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala |  60
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala | 145
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                 |   8
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala                           |  12
6 files changed, 294 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 902e18081b..567010f23f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -66,6 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
PushPredicateThroughProject,
PushPredicateThroughGenerate,
PushPredicateThroughAggregate,
+ // LimitPushDown, // Disabled until we have whole-stage codegen for limit
ColumnPruning,
// Operator combine
CollapseRepartition,
@@ -130,6 +131,69 @@ object EliminateSerialization extends Rule[LogicalPlan] {
}
/**
+ * Pushes down [[LocalLimit]] beneath UNION ALL and beneath the streamed inputs of outer joins.
+ */
+object LimitPushDown extends Rule[LogicalPlan] {
+
+ private def stripGlobalLimitIfPresent(plan: LogicalPlan): LogicalPlan = {
+ plan match {
+ case GlobalLimit(_, child) => child
+ case _ => plan
+ }
+ }
+
+ private def maybePushLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = {
+ (limitExp, plan.maxRows) match {
+ case (IntegerLiteral(maxRow), Some(childMaxRows)) if maxRow < childMaxRows =>
+ LocalLimit(limitExp, stripGlobalLimitIfPresent(plan))
+ case (_, None) =>
+ LocalLimit(limitExp, stripGlobalLimitIfPresent(plan))
+ case _ => plan
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // Add extra limits below UNION ALL for children whose maxRows is unknown or larger than the
+ // new limit; children already guaranteed to produce no more rows than the limit are left
+ // untouched. This is sound only if every operator that a limit is pushed past reports its
+ // maxRows correctly.
+ // Note: right now Union means UNION ALL, which does not de-duplicate rows, so it is safe to
+ // push a Limit through it. Once we add UNION DISTINCT, we will no longer be able to do so.
+ case LocalLimit(exp, Union(children)) =>
+ LocalLimit(exp, Union(children.map(maybePushLimit(exp, _))))
+ // Add extra limits below OUTER JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to the
+ // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side
+ // because we need to ensure that rows from the limited side still have an opportunity to match
+ // against all candidates from the non-limited side. We also need to ensure that this limit
+ // pushdown rule will not eventually introduce limits on both sides if it is applied multiple
+ // times. Therefore:
+ // - If one side is already limited, stack another limit on top if the new limit is smaller.
+ // The redundant limit will be collapsed by the CombineLimits rule.
+ // - If neither side is limited, limit the side that is estimated to be bigger.
+ case LocalLimit(exp, join @ Join(left, right, joinType, condition)) =>
+ val newJoin = joinType match {
+ case RightOuter => join.copy(right = maybePushLimit(exp, right))
+ case LeftOuter => join.copy(left = maybePushLimit(exp, left))
+ case FullOuter =>
+ (left.maxRows, right.maxRows) match {
+ case (None, None) =>
+ if (left.statistics.sizeInBytes >= right.statistics.sizeInBytes) {
+ join.copy(left = maybePushLimit(exp, left))
+ } else {
+ join.copy(right = maybePushLimit(exp, right))
+ }
+ case (Some(_), Some(_)) => join
+ case (Some(_), None) => join.copy(left = maybePushLimit(exp, left))
+ case (None, Some(_)) => join.copy(right = maybePushLimit(exp, right))
+ }
+ case _ => join
+ }
+ LocalLimit(exp, newJoin)
+ }
+}
+
+/**
* Pushes certain operations to both sides of a Union or Except operator.
* Operations that are safe to pushdown are listed as follows.
* Union:
@@ -985,8 +1049,12 @@ object RemoveDispensableExpressions extends Rule[LogicalPlan] {
*/
object CombineLimits extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case ll @ GlobalLimit(le, nl @ GlobalLimit(ne, grandChild)) =>
+ GlobalLimit(Least(Seq(ne, le)), grandChild)
+ case ll @ LocalLimit(le, nl @ LocalLimit(ne, grandChild)) =>
+ LocalLimit(Least(Seq(ne, le)), grandChild)
case ll @ Limit(le, nl @ Limit(ne, grandChild)) =>
- Limit(If(LessThan(ne, le), ne, le), grandChild)
+ Limit(Least(Seq(ne, le)), grandChild)
}
}
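
For orientation, here is a minimal sketch of what LimitPushDown and CombineLimits do together, written against the Catalyst test DSL used by the suite below (the DSL imports and the LocalRelation helper are assumed to be available on the test classpath):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.optimizer.{CombineLimits, LimitPushDown}
    import org.apache.spark.sql.catalyst.plans.logical._

    val a = LocalRelation('a.int)
    val b = LocalRelation('b.int)

    // Limit(5, ...) expands to GlobalLimit(5, LocalLimit(5, ...)), and LimitPushDown then
    // copies the per-partition limit into each union child:
    //   GlobalLimit(5, LocalLimit(5, Union(LocalLimit(5, a), LocalLimit(5, b))))
    val pushed = LimitPushDown(Limit(5, Union(a, b)))

    // Stacked limits collapse to the smaller value via Least, which ConstantFolding can
    // reduce more readily than the old If(LessThan(ne, le), ne, le) form:
    //   LocalLimit(Least(Seq(2, 5)), a)
    val combined = CombineLimits(LocalLimit(5, LocalLimit(2, a)))
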
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 18b7bde906..35e0f5d563 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -91,6 +91,14 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
}
/**
+ * Returns the maximum number of rows that this plan may compute.
+ *
+ * Any operator that a Limit can be pushed past should override this function (e.g., Union).
+ * Any operator that can push through a Limit should override this function (e.g., Project).
+ */
+ def maxRows: Option[Long] = None
+
+ /**
* Returns true if this expression and all its children have been resolved to a specific schema
* and false if it still contains any unresolved placeholders. Implementations of LogicalPlan
* can override this (e.g.
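
As a quick sanity check of the new contract (a sketch, reusing the DSL imports from the snippet above):

    val r = LocalRelation('a.int, 'b.int)
    assert(r.maxRows.isEmpty)                            // leaf relations default to None
    assert(r.limit(10).maxRows == Some(10L))             // GlobalLimit reports its literal bound
    assert(r.limit(10).select('a).maxRows == Some(10L))  // Project forwards child.maxRows
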
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e8e0a78904..502d898fea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -38,6 +38,7 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
+ override def maxRows: Option[Long] = child.maxRows
override lazy val resolved: Boolean = {
val hasSpecialExpressions = projectList.exists ( _.collect {
@@ -56,6 +57,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
+ *
* @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
@@ -102,6 +104,8 @@ case class Filter(condition: Expression, child: LogicalPlan)
extends UnaryNode with PredicateHelper {
override def output: Seq[Attribute] = child.output
+ override def maxRows: Option[Long] = child.maxRows
+
override protected def validConstraints: Set[Expression] = {
child.constraints.union(splitConjunctivePredicates(condition).toSet)
}
@@ -144,6 +148,14 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
left.output.length == right.output.length &&
left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } &&
duplicateResolved
+
+ override def maxRows: Option[Long] = {
+ if (children.exists(_.maxRows.isEmpty)) {
+ None
+ } else {
+ Some(children.flatMap(_.maxRows).min)
+ }
+ }
}
case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
@@ -166,6 +178,13 @@ object Union {
}
case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
+ override def maxRows: Option[Long] = {
+ if (children.exists(_.maxRows.isEmpty)) {
+ None
+ } else {
+ Some(children.flatMap(_.maxRows).sum)
+ }
+ }
// updating nullability to make all the children consistent
override def output: Seq[Attribute] =
@@ -305,6 +324,7 @@ case class InsertIntoTable(
/**
* A container for holding named common table expressions (CTEs) and a query plan.
* This operator will be removed during analysis and the relations will be substituted into child.
+ *
* @param child The final query of this CTE.
* @param cteRelations Queries that this CTE defined,
* key is the alias of the CTE definition,
@@ -331,6 +351,7 @@ case class Sort(
global: Boolean,
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
+ override def maxRows: Option[Long] = child.maxRows
}
/** Factory for constructing new `Range` nodes. */
@@ -384,6 +405,7 @@ case class Aggregate(
}
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
+ override def maxRows: Option[Long] = child.maxRows
}
case class Window(
@@ -505,6 +527,7 @@ trait GroupingAnalytics extends UnaryNode {
* to that generated by a UNION ALL of multiple simple GROUP BY clauses.
*
* We will transform GROUPING SETS into logical plan Aggregate(.., Expand) in Analyzer
+ *
* @param bitmasks A list of bitmasks, each of which indicates the selected
* GroupBy expressions
* @param groupByExprs The candidate Group By expressions, taking effect only if the
@@ -537,9 +560,42 @@ case class Pivot(
}
}
-case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
+object Limit {
+ def apply(limitExpr: Expression, child: LogicalPlan): UnaryNode = {
+ GlobalLimit(limitExpr, LocalLimit(limitExpr, child))
+ }
+
+ def unapply(p: GlobalLimit): Option[(Expression, LogicalPlan)] = {
+ p match {
+ case GlobalLimit(le1, LocalLimit(le2, child)) if le1 == le2 => Some((le1, child))
+ case _ => None
+ }
+ }
+}
+
+case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
+ override def maxRows: Option[Long] = {
+ limitExpr match {
+ case IntegerLiteral(limit) => Some(limit)
+ case _ => None
+ }
+ }
+ override lazy val statistics: Statistics = {
+ val limit = limitExpr.eval().asInstanceOf[Int]
+ val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+ Statistics(sizeInBytes = sizeInBytes)
+ }
+}
+
+case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+ override def maxRows: Option[Long] = {
+ limitExpr match {
+ case IntegerLiteral(limit) => Some(limit)
+ case _ => None
+ }
+ }
override lazy val statistics: Statistics = {
val limit = limitExpr.eval().asInstanceOf[Int]
val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
@@ -576,6 +632,7 @@ case class Sample(
* Returns a new logical plan that dedups input rows.
*/
case class Distinct(child: LogicalPlan) extends UnaryNode {
+ override def maxRows: Option[Long] = child.maxRows
override def output: Seq[Attribute] = child.output
}
@@ -594,6 +651,7 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan)
* A relation with one row. This is used in "SELECT ..." without a from clause.
*/
case object OneRowRelation extends LeafNode {
+ override def maxRows: Option[Long] = Some(1)
override def output: Seq[Attribute] = Nil
/**
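
Note the shape of the Limit companion: apply keeps existing Limit(n, child) call sites compiling by expanding into the two-node form, while unapply only recovers the exact GlobalLimit-over-LocalLimit pair with equal limit expressions. A sketch (same assumed DSL imports, plus Literal):

    import org.apache.spark.sql.catalyst.expressions.Literal

    val t = LocalRelation('a.int)
    val p = Limit(5, t)                      // builds GlobalLimit(5, LocalLimit(5, t))

    p match {
      case Limit(expr, child) => assert(expr == Literal(5) && child == t)
      case _ => sys.error("unexpected shape")
    }

    // A lone GlobalLimit, with no matching inner LocalLimit, is not recovered:
    GlobalLimit(5, t) match {
      case Limit(_, _) => sys.error("not reached")
      case _ =>                              // falls through, as intended
    }
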
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
new file mode 100644
index 0000000000..fc1e994581
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.Add
+import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, PlanTest, RightOuter}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+class LimitPushdownSuite extends PlanTest {
+
+ private object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Subqueries", Once,
+ EliminateSubQueries) ::
+ Batch("Limit pushdown", FixedPoint(100),
+ LimitPushDown,
+ CombineLimits,
+ ConstantFolding,
+ BooleanSimplification) :: Nil
+ }
+
+ private val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ private val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+ private val x = testRelation.subquery('x)
+ private val y = testRelation.subquery('y)
+
+ // Union ---------------------------------------------------------------------------------------
+
+ test("Union: limit to each side") {
+ val unionQuery = Union(testRelation, testRelation2).limit(1)
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ val unionCorrectAnswer =
+ Limit(1, Union(LocalLimit(1, testRelation), LocalLimit(1, testRelation2))).analyze
+ comparePlans(unionOptimized, unionCorrectAnswer)
+ }
+
+ test("Union: limit to each side with constant-foldable limit expressions") {
+ val unionQuery = Union(testRelation, testRelation2).limit(Add(1, 1))
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ val unionCorrectAnswer =
+ Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2))).analyze
+ comparePlans(unionOptimized, unionCorrectAnswer)
+ }
+
+ test("Union: limit to each side with the new limit number") {
+ val unionQuery = Union(testRelation, testRelation2.limit(3)).limit(1)
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ val unionCorrectAnswer =
+ Limit(1, Union(LocalLimit(1, testRelation), LocalLimit(1, testRelation2))).analyze
+ comparePlans(unionOptimized, unionCorrectAnswer)
+ }
+
+ test("Union: no limit to both sides if children having smaller limit values") {
+ val unionQuery = Union(testRelation.limit(1), testRelation2.select('d).limit(1)).limit(2)
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ val unionCorrectAnswer =
+ Limit(2, Union(testRelation.limit(1), testRelation2.select('d).limit(1))).analyze
+ comparePlans(unionOptimized, unionCorrectAnswer)
+ }
+
+ test("Union: limit to each sides if children having larger limit values") {
+ val testLimitUnion = Union(testRelation.limit(3), testRelation2.select('d).limit(4))
+ val unionQuery = testLimitUnion.limit(2)
+ val unionOptimized = Optimize.execute(unionQuery.analyze)
+ val unionCorrectAnswer =
+ Limit(2, Union(LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d)))).analyze
+ comparePlans(unionOptimized, unionCorrectAnswer)
+ }
+
+ // Outer join ----------------------------------------------------------------------------------
+
+ test("left outer join") {
+ val originalQuery = x.join(y, LeftOuter).limit(1)
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("right outer join") {
+ val originalQuery = x.join(y, RightOuter).limit(1)
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("larger limits are not pushed on top of smaller ones in right outer join") {
+ val originalQuery = x.join(y.limit(5), RightOuter).limit(10)
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = Limit(10, x.join(Limit(5, y), RightOuter)).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("full outer join where neither side is limited and both sides have same statistics") {
+ assert(x.statistics.sizeInBytes === y.statistics.sizeInBytes)
+ val originalQuery = x.join(y, FullOuter).limit(1)
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = Limit(1, LocalLimit(1, x).join(y, FullOuter)).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("full outer join where neither side is limited and left side has larger statistics") {
+ val xBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('x)
+ assert(xBig.statistics.sizeInBytes > y.statistics.sizeInBytes)
+ val originalQuery = xBig.join(y, FullOuter).limit(1)
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = Limit(1, LocalLimit(1, xBig).join(y, FullOuter)).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("full outer join where neither side is limited and right side has larger statistics") {
+ val yBig = testRelation.copy(data = Seq.fill(2)(null)).subquery('y)
+ assert(x.statistics.sizeInBytes < yBig.statistics.sizeInBytes)
+ val originalQuery = x.join(yBig, FullOuter).limit(1)
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = Limit(1, x.join(LocalLimit(1, yBig), FullOuter)).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("full outer join where both sides are limited") {
+ val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1)
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = Limit(1, Limit(2, x).join(Limit(2, y), FullOuter)).analyze
+ comparePlans(optimized, correctAnswer)
+ }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 73fd22b38e..042c99db4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -351,10 +351,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
- case logical.Limit(IntegerLiteral(limit), child) =>
- val perPartitionLimit = execution.LocalLimit(limit, planLater(child))
- val globalLimit = execution.GlobalLimit(limit, perPartitionLimit)
- globalLimit :: Nil
+ case logical.LocalLimit(IntegerLiteral(limit), child) =>
+ execution.LocalLimit(limit, planLater(child)) :: Nil
+ case logical.GlobalLimit(IntegerLiteral(limit), child) =>
+ execution.GlobalLimit(limit, planLater(child)) :: Nil
case logical.Union(unionChildren) =>
execution.Union(unionChildren.map(planLater)) :: Nil
case logical.Except(left, right) =>
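
This strategy change is what makes the pushed-down plans executable: each logical limit node now plans independently, so a LocalLimit sitting under a union child (with no GlobalLimit of its own) still gets a physical operator. A hypothetical spark-shell session on this branch (sqlContext assumed in scope) would show the split pair:

    val df = sqlContext.range(100).limit(5)
    println(df.queryExecution.executedPlan)
    // Schematically:
    //   GlobalLimit 5
    //   +- LocalLimit 5
    //      +- Range ...
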
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index d7bae913f8..bf5edb4759 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -77,8 +77,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
case p: Aggregate =>
aggregateToSQL(p)
- case p: Limit =>
- s"${toSQL(p.child)} LIMIT ${p.limitExpr.sql}"
+ case Limit(limitExpr, child) =>
+ s"${toSQL(child)} LIMIT ${limitExpr.sql}"
case p: Filter =>
val whereOrHaving = p.child match {
@@ -203,7 +203,13 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
wrapChildWithSubquery(plan)
case plan @ Project(_,
- _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit
+ _: Subquery
+ | _: Filter
+ | _: Join
+ | _: MetastoreRelation
+ | OneRowRelation
+ | _: LocalLimit
+ | _: GlobalLimit
) => plan
case plan: Project =>
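
Because the extractor recovers the single-LIMIT shape, SQL generation stays a one-liner instead of emitting two nested limit clauses. A hypothetical round-trip, assuming a HiveContext in scope and that toSQL still returns Option[String] on this branch:

    import org.apache.spark.sql.hive.SQLBuilder

    val analyzed = hiveContext.sql("SELECT a FROM t LIMIT 5").queryExecution.analyzed
    val regenerated = new SQLBuilder(analyzed, hiveContext).toSQL
    // e.g. Some("SELECT `a` FROM `t` LIMIT 5"): the GlobalLimit/LocalLimit pair collapses
    // back into a single LIMIT clause.
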