diff options
author | Sameer Agarwal <sameer@databricks.com> | 2014-06-11 12:01:04 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-06-11 12:01:04 -0700 |
commit | 4107cce58c41160a0dc20339621eacdf8a8b1191 (patch) | |
tree | ce7fce598c61190f702e8baed2517d7efc873e0e /sql/catalyst | |
parent | 4d5c12aa1c54c49377a4bafe3bcc4993d5e1a552 (diff) | |
download | spark-4107cce58c41160a0dc20339621eacdf8a8b1191.tar.gz spark-4107cce58c41160a0dc20339621eacdf8a8b1191.tar.bz2 spark-4107cce58c41160a0dc20339621eacdf8a8b1191.zip |
[SPARK-2042] Prevent unnecessary shuffle triggered by take()
This PR implements `take()` on a `SchemaRDD` by inserting a logical limit that is followed by a `collect()`. This is also accompanied by adding a catalyst optimizer rule for collapsing adjacent limits. Doing so prevents an unnecessary shuffle that is sometimes triggered by `take()`.
Author: Sameer Agarwal <sameer@databricks.com>
Closes #1048 from sameeragarwal/master and squashes the following commits:
3eeb848 [Sameer Agarwal] Fixing Tests
1b76ff1 [Sameer Agarwal] Deprecating limit(limitExpr: Expression) in v1.1.0
b723ac4 [Sameer Agarwal] Added limit folding tests
a0ff7c4 [Sameer Agarwal] Adding catalyst rule to fold two consecutive limits
8d42d03 [Sameer Agarwal] Implement trigger() as limit() followed by collect()
Diffstat (limited to 'sql/catalyst')
4 files changed, 88 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3cf163f9a9..d177339d40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -175,6 +175,8 @@ package object dsl { def where(condition: Expression) = Filter(condition, logicalPlan) + def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan) + def join( otherPlan: LogicalPlan, joinType: JoinType = Inner, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e41fd2db74..28d1aa2e3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.types._ object Optimizer extends RuleExecutor[LogicalPlan] { val batches = + Batch("Combine Limits", FixedPoint(100), + CombineLimits) :: Batch("ConstantFolding", FixedPoint(100), NullPropagation, ConstantFolding, @@ -362,3 +364,14 @@ object SimplifyCasts extends Rule[LogicalPlan] { case Cast(e, dataType) if e.dataType == dataType => e } } + +/** + * Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging the + * expressions into one single expression. + */ +object CombineLimits extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case ll @ Limit(le, nl @ Limit(ne, grandChild)) => + Limit(If(LessThan(ne, le), ne, le), grandChild) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d3347b622f..b777cf4249 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -135,9 +135,9 @@ case class Aggregate( def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } -case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode { +case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { def output = child.output - def references = limit.references + def references = limitExpr.references } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala new file mode 100644 index 0000000000..714f01843c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class CombiningLimitsSuite extends OptimizerTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Combine Limit", FixedPoint(2), + CombineLimits) :: + Batch("Constant Folding", FixedPoint(3), + NullPropagation, + ConstantFolding, + BooleanSimplification) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("limits: combines two limits") { + val originalQuery = + testRelation + .select('a) + .limit(10) + .limit(5) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(5).analyze + + comparePlans(optimized, correctAnswer) + } + + test("limits: combines three limits") { + val originalQuery = + testRelation + .select('a) + .limit(2) + .limit(7) + .limit(5) + + val optimized = Optimize(originalQuery.analyze) + val correctAnswer = + testRelation + .select('a) + .limit(2).analyze + + comparePlans(optimized, correctAnswer) + } +} |