From 3f6a2bb3f7beac4ce928eb660ee36258b5b9e8c8 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Tue, 13 Sep 2016 12:54:03 +0200
Subject: [SPARK-17515] CollectLimit.execute() should perform per-partition limits

## What changes were proposed in this pull request?

CollectLimit.execute() incorrectly omits per-partition limits, leading to performance regressions when this code path is hit. That should not happen in normal operation, but it can occur in some cases (see #15068 for one example).

## How was this patch tested?

Regression test in SQLQuerySuite that asserts the number of records scanned from the input RDD.

Author: Josh Rosen

Closes #15070 from JoshRosen/SPARK-17515.
---
 .../src/main/scala/org/apache/spark/sql/execution/limit.scala    | 3 ++-
 sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 +++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 01fbe5b7c2..86a8770715 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -39,9 +39,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
   private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
   protected override def doExecute(): RDD[InternalRow] = {
+    val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
     val shuffled = new ShuffledRowRDD(
       ShuffleExchange.prepareShuffleDependency(
-        child.execute(), child.output, SinglePartition, serializer))
+        locallyLimited, child.output, SinglePartition, serializer))
     shuffled.mapPartitionsInternal(_.take(limit))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index eac266cba5..a2164f9ae3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2661,4 +2661,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         data.selectExpr("`part.col1`", "`col.1`"))
     }
   }
+
+  test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") {
+    val numRecordsRead = spark.sparkContext.longAccumulator
+    spark.range(1, 100, 1, numPartitions = 10).map { x =>
+      numRecordsRead.add(1)
+      x
+    }.limit(1).queryExecution.toRdd.count()
+    assert(numRecordsRead.value === 10)
+  }
 }
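
Note (not part of the patch): a minimal, hypothetical sketch of the same local-then-global limit pattern expressed with the public RDD API. The object name and toy dataset below are made up for illustration; the point is that truncating each partition to `limit` rows before the single-partition shuffle bounds the shuffled data at `limit * numPartitions` records, which is what the change above does inside `CollectLimitExec.doExecute()` via `mapPartitionsInternal`.

```scala
// Hypothetical illustration only; not code from this patch.
import org.apache.spark.sql.SparkSession

object LocalThenGlobalLimitSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("limit-sketch").master("local[4]").getOrCreate()
    val limit = 1

    // Toy dataset spread across 10 partitions.
    val data = spark.sparkContext.parallelize(1 to 100, numSlices = 10)

    // Local limit: keep at most `limit` records per partition; no shuffle yet.
    val locallyLimited = data.mapPartitions(_.take(limit))

    // Global limit: collapse to a single partition and apply the limit once more.
    val globallyLimited = locallyLimited.coalesce(1, shuffle = true).mapPartitions(_.take(limit))

    println(globallyLimited.collect().mkString(", "))
    spark.stop()
  }
}
```

The regression test above checks the same property from the outside: with the per-partition limit in place, a `limit(1)` over 10 partitions reads one record per partition (10 in total) rather than all 99 input records.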