author     Sameer Agarwal <sameer@databricks.com>  2014-06-11 12:01:04 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-06-11 12:01:04 -0700
commit     4107cce58c41160a0dc20339621eacdf8a8b1191 (patch)
tree       ce7fce598c61190f702e8baed2517d7efc873e0e
parent     4d5c12aa1c54c49377a4bafe3bcc4993d5e1a552 (diff)
[SPARK-2042] Prevent unnecessary shuffle triggered by take()
This PR implements `take()` on a `SchemaRDD` by inserting a logical limit that is followed by a `collect()`. This is also accompanied by adding a catalyst optimizer rule for collapsing adjacent limits. Doing so prevents an unnecessary shuffle that is sometimes triggered by `take()`.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #1048 from sameeragarwal/master and squashes the following commits:

3eeb848 [Sameer Agarwal] Fixing Tests
1b76ff1 [Sameer Agarwal] Deprecating limit(limitExpr: Expression) in v1.1.0
b723ac4 [Sameer Agarwal] Added limit folding tests
a0ff7c4 [Sameer Agarwal] Adding catalyst rule to fold two consecutive limits
8d42d03 [Sameer Agarwal] Implement trigger() as limit() followed by collect()
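The folding described above can be pictured with a small standalone sketch. The `Plan`, `Relation`, and Int-based `Limit` types below are simplified stand-ins for illustration only, not the Catalyst classes touched by this patch: two adjacent limits always reduce to the smaller one, so the limit that `take()` inserts never widens an existing one.

```scala
// Simplified stand-ins for plan nodes; illustrative only.
sealed trait Plan
case class Relation(name: String) extends Plan
case class Limit(n: Int, child: Plan) extends Plan

// Collapse nested limits, keeping the smaller of each adjacent pair.
def combineLimits(plan: Plan): Plan = plan match {
  case Limit(outer, Limit(inner, child)) =>
    combineLimits(Limit(math.min(outer, inner), child))
  case other => other
}

// Limit(5, Limit(10, Relation("t"))) collapses to Limit(5, Relation("t")),
// so take(5) on an already-limited query never asks for more than 5 rows.
val collapsed = combineLimits(Limit(5, Limit(10, Relation("t"))))
```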
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala                     |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala             | 13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala    |  4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala  | 71
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala                                    | 12
5 files changed, 97 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3cf163f9a9..d177339d40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -175,6 +175,8 @@ package object dsl {
def where(condition: Expression) = Filter(condition, logicalPlan)
+ def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan)
+
def join(
otherPlan: LogicalPlan,
joinType: JoinType = Inner,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e41fd2db74..28d1aa2e3a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.catalyst.types._
object Optimizer extends RuleExecutor[LogicalPlan] {
val batches =
+ Batch("Combine Limits", FixedPoint(100),
+ CombineLimits) ::
Batch("ConstantFolding", FixedPoint(100),
NullPropagation,
ConstantFolding,
@@ -362,3 +364,14 @@ object SimplifyCasts extends Rule[LogicalPlan] {
case Cast(e, dataType) if e.dataType == dataType => e
}
}
+
+/**
+ * Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging the
+ * expressions into one single expression.
+ */
+object CombineLimits extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case ll @ Limit(le, nl @ Limit(ne, grandChild)) =>
+ Limit(If(LessThan(ne, le), ne, le), grandChild)
+ }
+}
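Because the two limit expressions are not necessarily literals when `CombineLimits` fires, the rule keeps the conditional `If(LessThan(ne, le), ne, le)` rather than computing the minimum eagerly; the subsequent constant-folding batch reduces it to a single literal once both sides are known. A rough sketch of that second step, using simplified expression types (`Expr`, `Lit`, etc. are illustrative, not Catalyst's):

```scala
// Simplified expression ADT, only for illustration.
sealed trait Expr
case class Lit(value: Int) extends Expr
case class LessThan(left: Expr, right: Expr) extends Expr
case class If(cond: Expr, whenTrue: Expr, whenFalse: Expr) extends Expr

// Fold a conditional whose predicate compares two known literals.
def fold(expr: Expr): Expr = expr match {
  case If(LessThan(Lit(a), Lit(b)), whenTrue, whenFalse) =>
    if (a < b) fold(whenTrue) else fold(whenFalse)
  case other => other
}

// The merged limit If(LessThan(7, 5), 7, 5) folds down to Lit(5).
assert(fold(If(LessThan(Lit(7), Lit(5)), Lit(7), Lit(5))) == Lit(5))
```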
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d3347b622f..b777cf4249 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -135,9 +135,9 @@ case class Aggregate(
def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
}
-case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
+case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
def output = child.output
- def references = limit.references
+ def references = limitExpr.references
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
new file mode 100644
index 0000000000..714f01843c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class CombiningLimitsSuite extends OptimizerTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Combine Limit", FixedPoint(2),
+ CombineLimits) ::
+ Batch("Constant Folding", FixedPoint(3),
+ NullPropagation,
+ ConstantFolding,
+ BooleanSimplification) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ test("limits: combines two limits") {
+ val originalQuery =
+ testRelation
+ .select('a)
+ .limit(10)
+ .limit(5)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(5).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("limits: combines three limits") {
+ val originalQuery =
+ testRelation
+ .select('a)
+ .limit(2)
+ .limit(7)
+ .limit(5)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 8855c4e876..7ad8edf5a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -178,14 +178,18 @@ class SchemaRDD(
def orderBy(sortExprs: SortOrder*): SchemaRDD =
new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan))
+ @deprecated("use limit with integer argument", "1.1.0")
+ def limit(limitExpr: Expression): SchemaRDD =
+ new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
+
/**
- * Limits the results by the given expressions.
+ * Limits the results by the given integer.
* {{{
* schemaRDD.limit(10)
* }}}
*/
- def limit(limitExpr: Expression): SchemaRDD =
- new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
+ def limit(limitNum: Int): SchemaRDD =
+ new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan))
/**
* Performs a grouping followed by an aggregation.
@@ -374,6 +378,8 @@ class SchemaRDD(
override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
+ override def take(num: Int): Array[Row] = limit(num).collect()
+
// =======================================================================
// Base RDD functions that do NOT change schema
// =======================================================================
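From the user's side, the change to `take()` can be sketched as below. The SparkContext `sc` and the registered "people" table are assumptions made for the example, not part of the patch.

```scala
import org.apache.spark.sql.SQLContext

// `sc` is an existing SparkContext and "people" a registered table;
// both are assumed here for the sake of the example.
val sqlContext = new SQLContext(sc)

// After this patch, take(5) is planned as limit(5) followed by collect(),
// so the 5-row bound is part of the query plan instead of being applied
// after a full (and possibly shuffled) execution of the query.
val firstFive = sqlContext.sql("SELECT name FROM people").take(5)
// Equivalent to: sqlContext.sql("SELECT name FROM people").limit(5).collect()
```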