author    Josh Rosen <joshrosen@databricks.com>  2016-02-08 11:38:21 -0800
committer Davies Liu <davies.liu@gmail.com>      2016-02-08 11:38:21 -0800
commit    06f0df6df204c4722ff8a6bf909abaa32a715c41 (patch)
tree      2e2772173ea0dc38a555c3edf08094aac2cd4ec8
parent    edf4a0e62e6fdb849cca4f23a7060da5ec782b07 (diff)
[SPARK-8964] [SQL] Use Exchange to perform shuffle in Limit
This patch changes the implementation of the physical `Limit` operator so that it relies on the `Exchange` operator to perform data movement rather than directly using `ShuffledRDD`. In addition to improving efficiency, this lays the necessary groundwork for further optimizations of limit, such as limit pushdown or whole-stage codegen.

At a high level, this replaces the old physical `Limit` operator with two new operators, `LocalLimit` and `GlobalLimit`. `LocalLimit` performs per-partition limits, while `GlobalLimit` applies the final limit to a single partition; `GlobalLimit` declares `AllTuples` as its `requiredChildDistribution`, which causes the planner to insert an `Exchange` that shuffles all rows into a single partition. Thus, a logical `Limit` appearing in the middle of a query plan is expanded into `LocalLimit -> Exchange to one partition -> GlobalLimit`.

In the old code, calling `someDataFrame.limit(100).collect()` or `someDataFrame.take(100)` would skip the shuffle and use a fast path based on `executeTake()` in order to avoid computing all partitions when only a small number of rows were requested. This patch preserves that optimization by treating logical `Limit` operators specially when they appear as the terminal operator in a query plan: if a `Limit` is the final operator, we plan a special `CollectLimit` physical operator which implements the old `take()`-based logic.

In order to match on operators only at the root of the query plan, this patch introduces a special `ReturnAnswer` logical operator which functions similarly to `BroadcastHint`: this dummy operator is inserted at the root of the optimized logical plan before invoking the physical planner, allowing the planner to pattern-match on it.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7334 from JoshRosen/remove-copy-in-limit.
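A minimal sketch of the two planning paths described above. The DataFrame and column names are illustrative, and the commented plan shapes describe what the new strategy is expected to produce rather than captured explain() output:

    // Terminal limit: the optimized plan is wrapped in ReturnAnswer, so the planner matches
    // ReturnAnswer(Limit(...)) and emits CollectLimit, whose executeCollect() delegates to
    // executeTake() and avoids computing every partition.
    someDataFrame.limit(100).collect()
    // => CollectLimit(100, ...)

    // Non-terminal limit: the Limit sits below other operators, so it is planned as a
    // per-partition LocalLimit, followed by a shuffle to one partition and a GlobalLimit.
    someDataFrame.limit(100).groupBy("key").count().collect()
    // => ... Aggregate ... GlobalLimit(100, Exchange(SinglePartition, LocalLimit(100, ...)))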
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala |  11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                        | 130
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala                  |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                 |   7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                  |  95
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala                           | 122
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala                    |  10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala                       |   4
8 files changed, 223 insertions(+), 160 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 03a79520cb..57575f9ee0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -25,6 +25,17 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.types._
+/**
+ * When planning take() or collect() operations, this special node is inserted at the top of
+ * the logical plan before invoking the query planner.
+ *
+ * Rules can pattern-match on this node in order to apply transformations that only take effect
+ * at the top of the logical query plan.
+ */
+case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+}
+
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 3770883af1..97f65f18bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -57,6 +57,69 @@ case class Exchange(
override def output: Seq[Attribute] = child.output
+ private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+
+ override protected def doPrepare(): Unit = {
+ // If an ExchangeCoordinator is needed, we register this Exchange operator
+ // to the coordinator when we do prepare. It is important to make sure
+ // we register this operator right before the execution instead of register it
+ // in the constructor because it is possible that we create new instances of
+ // Exchange operators when we transform the physical plan
+ // (then the ExchangeCoordinator will hold references of unneeded Exchanges).
+ // So, we should only call registerExchange just before we start to execute
+ // the plan.
+ coordinator match {
+ case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this)
+ case None =>
+ }
+ }
+
+ /**
+ * Returns a [[ShuffleDependency]] that will partition rows of its child based on
+ * the partitioning scheme defined in `newPartitioning`. Those partitions of
+ * the returned ShuffleDependency will be the input of shuffle.
+ */
+ private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
+ Exchange.prepareShuffleDependency(child.execute(), child.output, newPartitioning, serializer)
+ }
+
+ /**
+ * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset.
+ * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional
+ * partition start indices array. If this optional array is defined, the returned
+ * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array.
+ */
+ private[sql] def preparePostShuffleRDD(
+ shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
+ specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
+ // If an array of partition start indices is provided, we need to use this array
+ // to create the ShuffledRowRDD. Also, we need to update newPartitioning to
+ // update the number of post-shuffle partitions.
+ specifiedPartitionStartIndices.foreach { indices =>
+ assert(newPartitioning.isInstanceOf[HashPartitioning])
+ newPartitioning = UnknownPartitioning(indices.length)
+ }
+ new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ coordinator match {
+ case Some(exchangeCoordinator) =>
+ val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
+ assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
+ shuffleRDD
+ case None =>
+ val shuffleDependency = prepareShuffleDependency()
+ preparePostShuffleRDD(shuffleDependency)
+ }
+ }
+}
+
+object Exchange {
+ def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
+ Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
+ }
+
/**
* Determines whether records must be defensively copied before being sent to the shuffle.
* Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -82,7 +145,7 @@ case class Exchange(
// passed instead of directly passing the number of partitions in order to guard against
// corner-cases where a partitioner constructed with `numPartitions` partitions may output
// fewer partitions (like RangePartitioner, for example).
- val conf = child.sqlContext.sparkContext.conf
+ val conf = SparkEnv.get.conf
val shuffleManager = SparkEnv.get.shuffleManager
val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager]
val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
@@ -117,30 +180,16 @@ case class Exchange(
}
}
- private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
-
- override protected def doPrepare(): Unit = {
- // If an ExchangeCoordinator is needed, we register this Exchange operator
- // to the coordinator when we do prepare. It is important to make sure
- // we register this operator right before the execution instead of register it
- // in the constructor because it is possible that we create new instances of
- // Exchange operators when we transform the physical plan
- // (then the ExchangeCoordinator will hold references of unneeded Exchanges).
- // So, we should only call registerExchange just before we start to execute
- // the plan.
- coordinator match {
- case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this)
- case None =>
- }
- }
-
/**
* Returns a [[ShuffleDependency]] that will partition rows of its child based on
* the partitioning scheme defined in `newPartitioning`. Those partitions of
* the returned ShuffleDependency will be the input of shuffle.
*/
- private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
- val rdd = child.execute()
+ private[sql] def prepareShuffleDependency(
+ rdd: RDD[InternalRow],
+ outputAttributes: Seq[Attribute],
+ newPartitioning: Partitioning,
+ serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, n) =>
@@ -160,7 +209,7 @@ case class Exchange(
// We need to use an interpreted ordering here because generated orderings cannot be
// serialized and this ordering needs to be created on the driver in order to be passed into
// Spark core code.
- implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output)
+ implicit val ordering = new InterpretedOrdering(sortingExpressions, outputAttributes)
new RangePartitioner(numPartitions, rddForSampling, ascending = true)
case SinglePartition =>
new Partitioner {
@@ -180,7 +229,7 @@ case class Exchange(
position
}
case h: HashPartitioning =>
- val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output)
+ val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
row => projection(row).getInt(0)
case RangePartitioning(_, _) | SinglePartition => identity
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
@@ -211,43 +260,6 @@ case class Exchange(
dependency
}
-
- /**
- * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset.
- * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional
- * partition start indices array. If this optional array is defined, the returned
- * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array.
- */
- private[sql] def preparePostShuffleRDD(
- shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
- specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
- // If an array of partition start indices is provided, we need to use this array
- // to create the ShuffledRowRDD. Also, we need to update newPartitioning to
- // update the number of post-shuffle partitions.
- specifiedPartitionStartIndices.foreach { indices =>
- assert(newPartitioning.isInstanceOf[HashPartitioning])
- newPartitioning = UnknownPartitioning(indices.length)
- }
- new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
- }
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
- coordinator match {
- case Some(exchangeCoordinator) =>
- val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
- assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
- shuffleRDD
- case None =>
- val shuffleDependency = prepareShuffleDependency()
- preparePostShuffleRDD(shuffleDependency)
- }
- }
-}
-
-object Exchange {
- def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
- Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
- }
}
/**
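With `prepareShuffleDependency` moved to the `Exchange` companion object, other physical operators can stage a shuffle without instantiating an `Exchange` node. A minimal sketch of that reuse pattern, assuming a `plan: SparkPlan` variable (this mirrors what the new `CollectLimit` operator in limit.scala does below):

    // Shuffle all of `plan`'s output rows into a single partition.
    val serializer: Serializer = new UnsafeRowSerializer(plan.output.size)
    val dependency = Exchange.prepareShuffleDependency(
      plan.execute(),   // pre-shuffle rows
      plan.output,      // attributes used to compute partition ids
      SinglePartition,  // target partitioning
      serializer)
    val singlePartitionRDD = new ShuffledRowRDD(dependency)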
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 107570f9db..8616fe3170 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
/**
* The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -44,7 +44,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sqlContext)
- sqlContext.planner.plan(optimizedPlan).next()
+ sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next()
}
// executedPlan should not be used to initialize any SparkPlan. It should be
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 830bb011be..ee392e4e8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -338,8 +338,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
case logical.LocalRelation(output, data) =>
LocalTableScan(output, data) :: Nil
+ case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) =>
+ execution.CollectLimit(limit, planLater(child)) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
- execution.Limit(limit, planLater(child)) :: Nil
+ val perPartitionLimit = execution.LocalLimit(limit, planLater(child))
+ val globalLimit = execution.GlobalLimit(limit, perPartitionLimit)
+ globalLimit :: Nil
case logical.Union(unionChildren) =>
execution.Union(unionChildren.map(planLater)) :: Nil
case logical.Except(left, right) =>
@@ -358,6 +362,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
+ case logical.ReturnAnswer(child) => planLater(child) :: Nil
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6e51c4d848..f63e8a9b6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,16 +17,13 @@
package org.apache.spark.sql.execution
-import org.apache.spark.{HashPartitioner, SparkEnv}
-import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
-import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
-import org.apache.spark.util.MutablePair
import org.apache.spark.util.random.PoissonSampler
case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
@@ -307,96 +304,6 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan {
}
/**
- * Take the first limit elements. Note that the implementation is different depending on whether
- * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
- * this operator uses something similar to Spark's take method on the Spark driver. If it is not
- * terminal or is invoked using execute, we first take the limit on each partition, and then
- * repartition all the data to a single partition to compute the global limit.
- */
-case class Limit(limit: Int, child: SparkPlan)
- extends UnaryNode {
- // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
- // partition local limit -> exchange into one partition -> partition local limit again
-
- /** We must copy rows when sort based shuffle is on */
- private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
-
- override def output: Seq[Attribute] = child.output
- override def outputPartitioning: Partitioning = SinglePartition
-
- override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
-
- protected override def doExecute(): RDD[InternalRow] = {
- val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) {
- child.execute().mapPartitionsInternal { iter =>
- iter.take(limit).map(row => (false, row.copy()))
- }
- } else {
- child.execute().mapPartitionsInternal { iter =>
- val mutablePair = new MutablePair[Boolean, InternalRow]()
- iter.take(limit).map(row => mutablePair.update(false, row))
- }
- }
- val part = new HashPartitioner(1)
- val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part)
- shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
- shuffled.mapPartitionsInternal(_.take(limit).map(_._2))
- }
-}
-
-/**
- * Take the first limit elements as defined by the sortOrder, and do projection if needed.
- * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
- * or having a [[Project]] operator between them.
- * This could have been named TopK, but Spark's top operator does the opposite in ordering
- * so we name it TakeOrdered to avoid confusion.
- */
-case class TakeOrderedAndProject(
- limit: Int,
- sortOrder: Seq[SortOrder],
- projectList: Option[Seq[NamedExpression]],
- child: SparkPlan) extends UnaryNode {
-
- override def output: Seq[Attribute] = {
- val projectOutput = projectList.map(_.map(_.toAttribute))
- projectOutput.getOrElse(child.output)
- }
-
- override def outputPartitioning: Partitioning = SinglePartition
-
- // We need to use an interpreted ordering here because generated orderings cannot be serialized
- // and this ordering needs to be created on the driver in order to be passed into Spark core code.
- private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
-
- private def collectData(): Array[InternalRow] = {
- val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
- if (projectList.isDefined) {
- val proj = UnsafeProjection.create(projectList.get, child.output)
- data.map(r => proj(r).copy())
- } else {
- data
- }
- }
-
- override def executeCollect(): Array[InternalRow] = {
- collectData()
- }
-
- // TODO: Terminal split should be implemented differently from non-terminal split.
- // TODO: Pick num splits based on |limit|.
- protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
-
- override def outputOrdering: Seq[SortOrder] = sortOrder
-
- override def simpleString: String = {
- val orderByString = sortOrder.mkString("[", ",", "]")
- val outputString = output.mkString("[", ",", "]")
-
- s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
- }
-}
-
-/**
* Return a new RDD that has exactly `numPartitions` partitions.
* Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
* if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
new file mode 100644
index 0000000000..256f4228ae
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+
+
+/**
+ * Take the first `limit` elements and collect them to a single partition.
+ *
+ * This operator will be used when a logical `Limit` operation is the final operator in a
+ * logical plan, which happens when the user is collecting results back to the driver.
+ */
+case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+ override def outputPartitioning: Partitioning = SinglePartition
+ override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
+ private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+ protected override def doExecute(): RDD[InternalRow] = {
+ val shuffled = new ShuffledRowRDD(
+ Exchange.prepareShuffleDependency(child.execute(), child.output, SinglePartition, serializer))
+ shuffled.mapPartitionsInternal(_.take(limit))
+ }
+}
+
+/**
+ * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]].
+ */
+trait BaseLimit extends UnaryNode {
+ val limit: Int
+ override def output: Seq[Attribute] = child.output
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+ override def outputPartitioning: Partitioning = child.outputPartitioning
+ protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
+ iter.take(limit)
+ }
+}
+
+/**
+ * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
+ */
+case class LocalLimit(limit: Int, child: SparkPlan) extends BaseLimit {
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
+
+/**
+ * Take the first `limit` elements of the child's single output partition.
+ */
+case class GlobalLimit(limit: Int, child: SparkPlan) extends BaseLimit {
+ override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+}
+
+/**
+ * Take the first limit elements as defined by the sortOrder, and do projection if needed.
+ * This is logically equivalent to having a Limit operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrdered to avoid confusion.
+ */
+case class TakeOrderedAndProject(
+ limit: Int,
+ sortOrder: Seq[SortOrder],
+ projectList: Option[Seq[NamedExpression]],
+ child: SparkPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] = {
+ val projectOutput = projectList.map(_.map(_.toAttribute))
+ projectOutput.getOrElse(child.output)
+ }
+
+ override def outputPartitioning: Partitioning = SinglePartition
+
+ // We need to use an interpreted ordering here because generated orderings cannot be serialized
+ // and this ordering needs to be created on the driver in order to be passed into Spark core code.
+ private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
+
+ private def collectData(): Array[InternalRow] = {
+ val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+ if (projectList.isDefined) {
+ val proj = UnsafeProjection.create(projectList.get, child.output)
+ data.map(r => proj(r).copy())
+ } else {
+ data
+ }
+ }
+
+ override def executeCollect(): Array[InternalRow] = {
+ collectData()
+ }
+
+ // TODO: Terminal split should be implemented differently from non-terminal split.
+ // TODO: Pick num splits based on |limit|.
+ protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
+
+ override def simpleString: String = {
+ val orderByString = sortOrder.mkString("[", ",", "]")
+ val outputString = output.mkString("[", ",", "]")
+
+ s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index adaeb513bc..a64ad4038c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -181,6 +181,12 @@ class PlannerSuite extends SharedSQLContext {
}
}
+ test("terminal limits use CollectLimit") {
+ val query = testData.select('value).limit(2)
+ val planned = query.queryExecution.sparkPlan
+ assert(planned.isInstanceOf[CollectLimit])
+ }
+
test("PartitioningCollection") {
withTempTable("normal", "small", "tiny") {
testData.registerTempTable("normal")
@@ -200,7 +206,7 @@ class PlannerSuite extends SharedSQLContext {
).queryExecution.executedPlan.collect {
case exchange: Exchange => exchange
}.length
- assert(numExchanges === 3)
+ assert(numExchanges === 5)
}
{
@@ -215,7 +221,7 @@ class PlannerSuite extends SharedSQLContext {
).queryExecution.executedPlan.collect {
case exchange: Exchange => exchange
}.length
- assert(numExchanges === 3)
+ assert(numExchanges === 5)
}
}
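The new PlannerSuite test above only covers the terminal `CollectLimit` path. A hypothetical companion check for the non-terminal path (not part of this patch, sketched under the same PlannerSuite conventions) could assert that the planned query contains both limit operators plus the `Exchange` between them:

    test("non-terminal limits use LocalLimit, Exchange and GlobalLimit") {
      // Placing a filter above the limit keeps the Limit away from the plan root.
      val query = testData.select('value).limit(2).filter('value === 3)
      val planned = query.queryExecution.executedPlan
      assert(planned.collect { case l: LocalLimit => l }.nonEmpty)
      assert(planned.collect { case e: Exchange => e }.nonEmpty)
      assert(planned.collect { case g: GlobalLimit => g }.nonEmpty)
    }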
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index 6259453da2..cb6d68dc3a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -56,8 +56,8 @@ class SortSuite extends SparkPlanTest with SharedSQLContext {
test("sort followed by limit") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
- (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child = child)),
- (child: SparkPlan) => Limit(10, ReferenceSort('a.asc :: Nil, global = true, child)),
+ (child: SparkPlan) => GlobalLimit(10, Sort('a.asc :: Nil, global = true, child = child)),
+ (child: SparkPlan) => GlobalLimit(10, ReferenceSort('a.asc :: Nil, global = true, child)),
sortAnswers = false
)
}