author     Josh Rosen <joshrosen@databricks.com>  2016-02-10 11:00:38 -0800
committer  Josh Rosen <joshrosen@databricks.com>  2016-02-10 11:00:38 -0800
commit     5cf20598cec4e60b53c0e40dc4243f436396e7fc (patch)
tree       b29ee05e5e5b165c78dcea076ad0611788aabc62
parent     80cb963ad963e26c3a7f8388bdd4ffd5e99aad1a (diff)
[SPARK-13254][SQL] Fix planning of TakeOrderedAndProject operator
The patch for SPARK-8964 ("use Exchange to perform shuffle in Limit" / #7334) inadvertently broke the planning of the TakeOrderedAndProject operator: because ReturnAnswer became the new root of the query plan, the TakeOrderedAndProject rule was unable to match before BasicOperators.

This patch fixes the problem by moving the `TakeOrderedAndProject` and `CollectLimit` rules into a single strategy (`SpecialLimits`). In addition, I made changes to the TakeOrderedAndProject operator in order to make its `doExecute()` method lazy (previously it eagerly collected the top rows to the driver and re-parallelized them into a single-partition RDD) and added a new TakeOrderedAndProjectSuite which tests the new code path.

/cc davies and marmbrus for review.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #11145 from JoshRosen/take-ordered-and-project-fix.
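To make the planning issue concrete, the stand-alone Scala sketch below models it with toy case classes. It is illustrative only: `Plan`, `Relation`, `oldStrategy`, and `newStrategy` are made-up stand-ins for Catalyst's logical plans and strategies; the point is that a rule which only matches `Limit(Sort(...))` stops firing once a `ReturnAnswer` wrapper becomes the root, while a rule that also matches the wrapper (as the `SpecialLimits` strategy in the diff does) still fires.

```scala
// Toy model only: these case classes stand in for Catalyst logical plans.
object PlanningDemo extends App {
  sealed trait Plan
  case class Relation(name: String) extends Plan
  case class Sort(child: Plan) extends Plan
  case class Limit(n: Int, child: Plan) extends Plan
  case class ReturnAnswer(child: Plan) extends Plan // wrapper added at the root by SPARK-8964

  // In the spirit of the old TakeOrderedAndProject rule: it only recognises
  // Limit(Sort(...)) when that exact pattern is the plan it is handed.
  def oldStrategy(plan: Plan): Option[String] = plan match {
    case Limit(n, Sort(_)) => Some(s"TakeOrderedAndProject(limit=$n)")
    case _                 => None
  }

  // In the spirit of the new SpecialLimits rule: it also looks inside the
  // ReturnAnswer wrapper that now sits at the root of every terminal plan.
  def newStrategy(plan: Plan): Option[String] = plan match {
    case ReturnAnswer(Limit(n, Sort(_))) => Some(s"TakeOrderedAndProject(limit=$n)")
    case ReturnAnswer(Limit(n, _))       => Some(s"CollectLimit(limit=$n)")
    case Limit(n, Sort(_))               => Some(s"TakeOrderedAndProject(limit=$n)")
    case _                               => None
  }

  val root = ReturnAnswer(Limit(2, Sort(Relation("testData"))))
  println(oldStrategy(root)) // None: the wrapper hides the pattern from the old rule
  println(newStrategy(root)) // Some(TakeOrderedAndProject(limit=2))
}
```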
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala                  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala              40
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala                        30
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala                 44
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala   85
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala                        2
6 files changed, 159 insertions(+), 44 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 6e9a4df828..d1569a4ec2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -31,7 +31,7 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
    sqlContext.experimental.extraStrategies ++ (
      DataSourceStrategy ::
      DDLStrategy ::
-      TakeOrderedAndProject ::
+      SpecialLimits ::
      Aggregation ::
      LeftSemiJoin ::
      EquiJoinSelection ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ee392e4e8d..598ddd7161 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -33,6 +33,31 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
  self: SparkPlanner =>
+  /**
+   * Plans special cases of limit operators.
+   */
+  object SpecialLimits extends Strategy {
+    override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case logical.ReturnAnswer(rootPlan) => rootPlan match {
+        case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+          execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+        case logical.Limit(
+            IntegerLiteral(limit),
+            logical.Project(projectList, logical.Sort(order, true, child))) =>
+          execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
+        case logical.Limit(IntegerLiteral(limit), child) =>
+          execution.CollectLimit(limit, planLater(child)) :: Nil
+        case other => planLater(other) :: Nil
+      }
+      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
+        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+      case logical.Limit(
+          IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) =>
+        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
+      case _ => Nil
+    }
+  }
+
  object LeftSemiJoin extends Strategy with PredicateHelper {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case ExtractEquiJoinKeys(
@@ -264,18 +289,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
  protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
-  object TakeOrderedAndProject extends Strategy {
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
-      case logical.Limit(
-          IntegerLiteral(limit),
-          logical.Project(projectList, logical.Sort(order, true, child))) =>
-        execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
-      case _ => Nil
-    }
-  }
-
  object InMemoryScans extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      case PhysicalOperation(projectList, filters, mem: InMemoryRelation) =>
@@ -338,8 +351,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
      case logical.LocalRelation(output, data) =>
        LocalTableScan(output, data) :: Nil
-      case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) =>
-        execution.CollectLimit(limit, planLater(child)) :: Nil
      case logical.Limit(IntegerLiteral(limit), child) =>
        val perPartitionLimit = execution.LocalLimit(limit, planLater(child))
        val globalLimit = execution.GlobalLimit(limit, perPartitionLimit)
@@ -362,7 +373,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
        BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
      case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
      case BroadcastHint(child) => planLater(child) :: Nil
-      case logical.ReturnAnswer(child) => planLater(child) :: Nil
      case _ => Nil
    }
  }
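As a quick way to see the effect of the strategy change, a check along the following lines can be run in a Spark shell built from this tree. It is a hedged sketch, not part of the patch: it assumes a `SQLContext` named `sqlContext` and a registered `testData` table, mirroring what the updated PlannerSuite below asserts.

```scala
// Hypothetical spark-shell check; `sqlContext` and the `testData` table are assumed to exist.
val sorted = sqlContext.table("testData").sort("key").limit(2)
println(sorted.queryExecution.executedPlan)
// A terminal sort + limit should now plan to a single TakeOrderedAndProject node.

val plain = sqlContext.table("testData").limit(2)
println(plain.queryExecution.sparkPlan)
// A terminal limit without a sort should plan to a CollectLimit node.
```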
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 256f4228ae..04daf9d0ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -83,8 +83,7 @@ case class TakeOrderedAndProject(
    child: SparkPlan) extends UnaryNode {
  override def output: Seq[Attribute] = {
-    val projectOutput = projectList.map(_.map(_.toAttribute))
-    projectOutput.getOrElse(child.output)
+    projectList.map(_.map(_.toAttribute)).getOrElse(child.output)
  }
  override def outputPartitioning: Partitioning = SinglePartition
@@ -93,7 +92,7 @@ case class TakeOrderedAndProject(
  // and this ordering needs to be created on the driver in order to be passed into Spark core code.
  private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
-  private def collectData(): Array[InternalRow] = {
+  override def executeCollect(): Array[InternalRow] = {
    val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
    if (projectList.isDefined) {
      val proj = UnsafeProjection.create(projectList.get, child.output)
@@ -103,13 +102,26 @@ case class TakeOrderedAndProject(
    }
  }
-  override def executeCollect(): Array[InternalRow] = {
-    collectData()
-  }
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
-  // TODO: Terminal split should be implemented differently from non-terminal split.
-  // TODO: Pick num splits based on |limit|.
-  protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
+  protected override def doExecute(): RDD[InternalRow] = {
+    val localTopK: RDD[InternalRow] = {
+      child.execute().map(_.copy()).mapPartitions { iter =>
+        org.apache.spark.util.collection.Utils.takeOrdered(iter, limit)(ord)
+      }
+    }
+    val shuffled = new ShuffledRowRDD(
+      Exchange.prepareShuffleDependency(localTopK, child.output, SinglePartition, serializer))
+    shuffled.mapPartitions { iter =>
+      val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
+      if (projectList.isDefined) {
+        val proj = UnsafeProjection.create(projectList.get, child.output)
+        topK.map(r => proj(r))
+      } else {
+        topK
+      }
+    }
+  }
  override def outputOrdering: Seq[SortOrder] = sortOrder
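The new `doExecute()` is a two-phase top-K: each input partition keeps only its own top `limit` rows, those candidates are shuffled into a single partition, and one final pass selects the global top `limit` rows (applying the optional projection). The snippet below is a minimal stand-alone sketch of that algorithm on plain Scala collections, not Spark code; the sample data, `limit`, and ordering are invented for illustration.

```scala
// Stand-alone sketch of the two-phase top-K used by the new doExecute().
// Plain Seq[Seq[Int]] "partitions" stand in for the partitions of an RDD.
object TopKDemo extends App {
  val limit = 3
  val ord = Ordering.Int.reverse // "top" = largest values first

  val partitions: Seq[Seq[Int]] = Seq(
    Seq(5, 1, 9, 3),
    Seq(8, 2, 7),
    Seq(4, 6, 10, 0))

  // Phase 1: local top-K per partition (mapPartitions + takeOrdered in the operator).
  val localTopK: Seq[Seq[Int]] = partitions.map(_.sorted(ord).take(limit))

  // Phase 2: gather all candidates into one "partition" and take the global top-K.
  val globalTopK: Seq[Int] = localTopK.flatten.sorted(ord).take(limit)

  println(globalTopK) // List(10, 9, 8)
}
```

Because only `limit` rows per partition are shuffled, the final pass stays cheap no matter how large the input is.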
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index a64ad4038c..250ce8f866 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -161,30 +162,37 @@ class PlannerSuite extends SharedSQLContext {
    }
  }
-  test("efficient limit -> project -> sort") {
-    {
-      val query =
-        testData.select('key, 'value).sort('key).limit(2).logicalPlan
-      val planned = sqlContext.planner.TakeOrderedAndProject(query)
-      assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
-      assert(planned.head.output === testData.select('key, 'value).logicalPlan.output)
-    }
+  test("efficient terminal limit -> sort should use TakeOrderedAndProject") {
+    val query = testData.select('key, 'value).sort('key).limit(2)
+    val planned = query.queryExecution.executedPlan
+    assert(planned.isInstanceOf[execution.TakeOrderedAndProject])
+    assert(planned.output === testData.select('key, 'value).logicalPlan.output)
+  }
-    {
-      // We need to make sure TakeOrderedAndProject's output is correct when we push a project
-      // into it.
-      val query =
-        testData.select('key, 'value).sort('key).select('value, 'key).limit(2).logicalPlan
-      val planned = sqlContext.planner.TakeOrderedAndProject(query)
-      assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
-      assert(planned.head.output === testData.select('value, 'key).logicalPlan.output)
-    }
+  test("terminal limit -> project -> sort should use TakeOrderedAndProject") {
+    val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2)
+    val planned = query.queryExecution.executedPlan
+    assert(planned.isInstanceOf[execution.TakeOrderedAndProject])
+    assert(planned.output === testData.select('value, 'key).logicalPlan.output)
  }
-  test("terminal limits use CollectLimit") {
+  test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") {
    val query = testData.select('value).limit(2)
    val planned = query.queryExecution.sparkPlan
    assert(planned.isInstanceOf[CollectLimit])
+    assert(planned.output === testData.select('value).logicalPlan.output)
+  }
+
+  test("TakeOrderedAndProject can appear in the middle of plans") {
+    val query = testData.select('key, 'value).sort('key).limit(2).filter('key === 3)
+    val planned = query.queryExecution.executedPlan
+    assert(planned.find(_.isInstanceOf[TakeOrderedAndProject]).isDefined)
+  }
+
+  test("CollectLimit can appear in the middle of a plan when caching is used") {
+    val query = testData.select('key, 'value).limit(2).cache()
+    val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation]
+    assert(planned.child.isInstanceOf[CollectLimit])
  }
  test("PartitioningCollection") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
new file mode 100644
index 0000000000..03cb04a5f7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
+
+
+class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
+
+  private var rand: Random = _
+  private var seed: Long = 0
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    seed = System.currentTimeMillis()
+    rand = new Random(seed)
+  }
+
+  private def generateRandomInputData(): DataFrame = {
+    val schema = new StructType()
+      .add("a", IntegerType, nullable = false)
+      .add("b", IntegerType, nullable = false)
+    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
+    sqlContext.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
+  }
+
+  /**
+   * Adds a no-op filter to the child plan in order to prevent executeCollect() from being
+   * called directly on the child plan.
+   */
+  private def noOpFilter(plan: SparkPlan): SparkPlan = Filter(Literal(true), plan)
+
+  val limit = 250
+  val sortOrder = 'a.desc :: 'b.desc :: Nil
+
+  test("TakeOrderedAndProject.doExecute without project") {
+    withClue(s"seed = $seed") {
+      checkThatPlansAgree(
+        generateRandomInputData(),
+        input =>
+          noOpFilter(TakeOrderedAndProject(limit, sortOrder, None, input)),
+        input =>
+          GlobalLimit(limit,
+            LocalLimit(limit,
+              Sort(sortOrder, global = true, input))),
+        sortAnswers = false)
+    }
+  }
+
+  test("TakeOrderedAndProject.doExecute with project") {
+    withClue(s"seed = $seed") {
+      checkThatPlansAgree(
+        generateRandomInputData(),
+        input =>
+          noOpFilter(TakeOrderedAndProject(limit, sortOrder, Some(Seq(input.output.last)), input)),
+        input =>
+          GlobalLimit(limit,
+            LocalLimit(limit,
+              Project(Seq(input.output.last),
+                Sort(sortOrder, global = true, input)))),
+        sortAnswers = false)
+    }
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 05863ae183..2433b54ffc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -559,7 +559,7 @@ class HiveContext private[hive](
        HiveCommandStrategy(self),
        HiveDDLStrategy,
        DDLStrategy,
-        TakeOrderedAndProject,
+        SpecialLimits,
        InMemoryScans,
        HiveTableScans,
        DataSinks,