author    Liang-Chi Hsieh <viirya@gmail.com>    2016-03-01 08:43:02 -0800
committer Davies Liu <davies.liu@gmail.com>    2016-03-01 08:43:02 -0800
commit    c43899a04e4de18e238a1761bf4fe9f54e182320 (patch)
tree      34c22b64f5034e14f7fe33b2975469ee0e09d2f5 /sql/core/src
parent    12a2a57e1af21da0aa4275971365d76a8fc84a43 (diff)
[SPARK-13511] [SQL] Add wholestage codegen for limit
JIRA: https://issues.apache.org/jira/browse/SPARK-13511

## What changes were proposed in this pull request?

The limit operators currently do not support wholestage codegen; this patch adds that support. In the `doConsume` of `GlobalLimit` and `LocalLimit`, we use a count term to track the number of processed rows. Once the count reaches the limit, we set the `stopEarly` variable of `BufferedRowIterator`, newly added in this PR, to `true`, indicating that the remaining rows should not be processed. When the wholestage codegen framework then checks `shouldStop()`, it stops processing the row iterator.

Before this change, the executed plan for the query `sqlContext.range(N).limit(100).groupBy().sum()` is:

    TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Final,isDistinct=false)], output=[sum(id)#6L])
    +- TungstenAggregate(key=[], functions=[(sum(id#5L),mode=Partial,isDistinct=false)], output=[sum#9L])
       +- GlobalLimit 100
          +- Exchange SinglePartition, None
             +- LocalLimit 100
                +- Range 0, 1, 1, 524288000, [id#5L]

After adding wholestage codegen support:

    WholeStageCodegen
    :  +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Final,isDistinct=false)], output=[sum(id)#41L])
    :     +- TungstenAggregate(key=[], functions=[(sum(id#40L),mode=Partial,isDistinct=false)], output=[sum#44L])
    :        +- GlobalLimit 100
    :           +- INPUT
    +- Exchange SinglePartition, None
       +- WholeStageCodegen
          :  +- LocalLimit 100
          :     +- Range 0, 1, 1, 524288000, [id#40L]

## How was this patch tested?

A benchmark test is added to BenchmarkWholeStageCodegen.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #11391 from viirya/wholestage-limit.
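As a quick illustration of the new path (a minimal sketch, not part of the patch; it assumes a `sqlContext` in scope, as in the Spark shell of this era), explaining a limited aggregation should now show the limit operators inside WholeStageCodegen:

    // Hypothetical verification snippet using the 2016-era DataFrame API.
    val df = sqlContext.range(1000).limit(100).groupBy().sum()
    df.explain()  // GlobalLimit/LocalLimit should appear under WholeStageCodegen
    df.collect()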
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala                     | 35
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala | 14
2 files changed, 47 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index cd543d4195..45175d36d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
@@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
/**
* Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]].
*/
-trait BaseLimit extends UnaryNode {
+trait BaseLimit extends UnaryNode with CodegenSupport {
val limit: Int
override def output: Seq[Attribute] = child.output
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode {
protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
iter.take(limit)
}
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val stopEarly = ctx.freshName("stopEarly")
+ ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
+
+ ctx.addNewFunction("shouldStop", s"""
+ @Override
+ protected boolean shouldStop() {
+ return !currentRows.isEmpty() || $stopEarly;
+ }
+ """)
+ val countTerm = ctx.freshName("count")
+ ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
+ s"""
+ | if ($countTerm < $limit) {
+ | $countTerm += 1;
+ | ${consume(ctx, input)}
+ | } else {
+ | $stopEarly = true;
+ | }
+ """.stripMargin
+ }
}
/**
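To make the generated control flow concrete, the following is a simplified Scala model (an illustrative sketch only, not Spark source; `currentRows`, `processNext`, and `shouldStop` mirror the `BufferedRowIterator` contract with simplified types) of how the count term and the `stopEarly` flag interact:

    object StopEarlyModel {
      import scala.collection.mutable

      // Mutable state mirroring the codegen'd fields: the stop flag and the row count.
      var stopEarly = false
      var count = 0
      val currentRows = mutable.Queue.empty[Int]

      // Analogue of the generated shouldStop(): stop pulling input once a row is
      // buffered for the consumer or the limit has been hit.
      def shouldStop(): Boolean = currentRows.nonEmpty || stopEarly

      // Analogue of BufferedRowIterator.processNext(): consume input until shouldStop().
      def processNext(input: Iterator[Int], limit: Int): Unit = {
        while (!shouldStop() && input.hasNext) {
          val row = input.next()
          if (count < limit) {       // the generated `if ($countTerm < $limit)` branch
            count += 1
            currentRows.enqueue(row) // consume(ctx, input): hand the row downstream
          } else {
            stopEarly = true         // the generated `$stopEarly = true;` branch
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val input = (0 until 1000000).iterator
        // Driver loop: keep calling processNext() until it buffers no more rows.
        while ({ processNext(input, 100); currentRows.nonEmpty }) {
          currentRows.dequeue()
        }
        println(s"consumed $count rows before stopping early") // prints 100
      }
    }

Once `stopEarly` flips, every subsequent `shouldStop()` call returns true, so the loop exits without draining the remaining input; that is the point of the patch.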
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 6d6cc0186a..2d3e34d0e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -70,6 +70,20 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
}
+ ignore("range/limit/sum") {
+ val N = 500 << 20
+ runBenchmark("range/limit/sum", N) {
+ sqlContext.range(N).limit(1000000).groupBy().sum().collect()
+ }
+ /*
+ Westmere E56xx/L56xx/X56xx (Nehalem-C)
+ range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X
+ range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X
+ */
+ }
+
ignore("stat functions") {
val N = 100 << 20
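The benchmarks in this suite are marked `ignore` so they do not run as part of CI; `runBenchmark` (defined earlier in the suite) executes the body with whole-stage codegen off and on, which is how the codegen=false/true rows above are produced. A minimal standalone sketch of that pattern (illustrative only; `sqlContext` and the config key follow the 2016-era API):

    // Time a body with whole-stage codegen toggled; returns elapsed milliseconds.
    def timeWithCodegen(enabled: Boolean)(body: => Unit): Long = {
      sqlContext.setConf("spark.sql.codegen.wholeStage", enabled.toString)
      val start = System.nanoTime()
      body
      (System.nanoTime() - start) / 1000000
    }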