 sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala                      | 35 +++++++++++++++++++++++++++++++++--
 sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala | 14 ++++++++++++++
 2 files changed, 47 insertions(+), 2 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index cd543d4195..45175d36d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
@@ -48,7 +49,7 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
/**
* Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]].
*/
-trait BaseLimit extends UnaryNode {
+trait BaseLimit extends UnaryNode with CodegenSupport {
val limit: Int
override def output: Seq[Attribute] = child.output
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -56,6 +57,36 @@ trait BaseLimit extends UnaryNode {
protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
iter.take(limit)
}
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val stopEarly = ctx.freshName("stopEarly")
+    ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
+
+    ctx.addNewFunction("shouldStop", s"""
+      @Override
+      protected boolean shouldStop() {
+        return !currentRows.isEmpty() || $stopEarly;
+      }
+    """)
+    val countTerm = ctx.freshName("count")
+    ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
+    s"""
+       | if ($countTerm < $limit) {
+       |   $countTerm += 1;
+       |   ${consume(ctx, input)}
+       | } else {
+       |   $stopEarly = true;
+       | }
+     """.stripMargin
+  }
}
/**
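
Note on the generated control flow above: doConsume registers two pieces of per-stage mutable state (a row counter and a stopEarly flag) and overrides shouldStop() so that the generated produce loop exits as soon as the limit is reached, rather than draining the rest of the partition. The following self-contained Scala sketch mimics those semantics for intuition only (assumption: this is hand-written Scala with a hypothetical takeWithEarlyStop helper, not the Java that whole-stage codegen actually emits):

    // Sketch only: models the counter + stopEarly pattern emitted by doConsume.
    // In real whole-stage codegen this logic is inlined into a generated Java class.
    object LimitEarlyStopSketch {
      def takeWithEarlyStop(upstream: Iterator[Int], limit: Int): Seq[Int] = {
        var count = 0                 // mirrors ctx.addMutableState("int", countTerm, ...)
        var stopEarly = false         // mirrors ctx.addMutableState("boolean", stopEarly, ...)
        val out = Seq.newBuilder[Int]
        while (upstream.hasNext && !stopEarly) {  // !stopEarly plays the role of shouldStop()
          val row = upstream.next()
          if (count < limit) {
            count += 1
            out += row                // corresponds to ${consume(ctx, input)}
          } else {
            stopEarly = true          // no further rows are pulled from upstream
          }
        }
        out.result()
      }

      def main(args: Array[String]): Unit = {
        // Prints the first 5 elements; only limit + 1 loop iterations run, not 100.
        println(takeWithEarlyStop(Iterator.range(0, 100), 5))
      }
    }
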
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 6d6cc0186a..2d3e34d0e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -70,6 +70,20 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
}
+ ignore("range/limit/sum") {
+ val N = 500 << 20
+ runBenchmark("range/limit/sum", N) {
+ sqlContext.range(N).limit(1000000).groupBy().sum().collect()
+ }
+ /*
+ Westmere E56xx/L56xx/X56xx (Nehalem-C)
+ range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ range/limit/sum codegen=false 609 / 672 861.6 1.2 1.0X
+ range/limit/sum codegen=true 561 / 621 935.3 1.1 1.1X
+ */
+ }
+
ignore("stat functions") {
val N = 100 << 20
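
Beyond the benchmark numbers, a quick way to see the change take effect is to inspect the physical plan (a minimal sketch, assuming a SQLContext named sqlContext in scope as in the test class above; the exact plan rendering varies across Spark versions):

    // Sketch: confirm that limit now participates in whole-stage codegen.
    val df = sqlContext.range(1000).limit(100).groupBy().sum()
    df.explain()  // the Range -> Limit -> Aggregate pipeline should appear under a WholeStageCodegen node
    df.collect()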