diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2016-03-01 08:43:02 -0800 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-03-01 08:43:02 -0800 |
commit | c43899a04e4de18e238a1761bf4fe9f54e182320 (patch) | |
tree | 34c22b64f5034e14f7fe33b2975469ee0e09d2f5 /sql/core/src/main/scala/org/apache | |
parent | 12a2a57e1af21da0aa4275971365d76a8fc84a43 (diff) | |
download | spark-c43899a04e4de18e238a1761bf4fe9f54e182320.tar.gz spark-c43899a04e4de18e238a1761bf4fe9f54e182320.tar.bz2 spark-c43899a04e4de18e238a1761bf4fe9f54e182320.zip |
[SPARK-13511] [SQL] Add wholestage codegen for limit
JIRA: https://issues.apache.org/jira/browse/SPARK-13511
## What changes were proposed in this pull request?
Current limit operator doesn't support wholestage codegen. This is open to add support for it.
In the `doConsume` of `GlobalLimit` and `LocalLimit`, we use a count term to count the processed rows. Once the row numbers catches the limit number, we set the variable `stopEarly` of `BufferedRowIterator` newly added in this pr to `true` that indicates we want to stop processing remaining rows. Then when the wholestage codegen framework checks `shouldStop()`, it will stop the processing of the row iterator.
Before this, the executed plan for a query `sqlContext.range(N).limit(100).groupBy().sum()` is:
TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Final,isDistinct=false)], output=[sum(id)#6L])
+- TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Partial,isDistinct=false)], output=[sum#9L])
+- GlobalLimit 100
+- Exchange SinglePartition, None
+- LocalLimit 100
+- Range 0, 1, 1, 524288000, [id#5L]
After add wholestage codegen support:
WholeStageCodegen
: +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Final,isDistinct=false)], output=[sum(id)#41L])
: +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Partial,isDistinct=false)], output=[sum#44L])
: +- GlobalLimit 100
: +- INPUT
+- Exchange SinglePartition, None
+- WholeStageCodegen
: +- LocalLimit 100
: +- Range 0, 1, 1, 524288000, [id#40L]
## How was this patch tested?
A test is added into BenchmarkWholeStageCodegen.
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #11391 from viirya/wholestage-limit.
Diffstat (limited to 'sql/core/src/main/scala/org/apache')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala | 35 |
1 files changed, 33 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index cd543d4195..45175d36d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.metric.SQLMetrics /** @@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode { /** * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]]. */ -trait BaseLimit extends UnaryNode { +trait BaseLimit extends UnaryNode with CodegenSupport { val limit: Int override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = child.outputOrdering @@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter => iter.take(limit) } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val stopEarly = ctx.freshName("stopEarly") + ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") + + ctx.addNewFunction("shouldStop", s""" + @Override + protected boolean shouldStop() { + return !currentRows.isEmpty() || $stopEarly; + } + """) + val countTerm = ctx.freshName("count") + ctx.addMutableState("int", countTerm, s"$countTerm = 0;") + s""" + | if ($countTerm < $limit) { + | $countTerm += 1; + | ${consume(ctx, input)} + | } else { + | $stopEarly = true; + | } + """.stripMargin + } } /** |