author | Josh Rosen <joshrosen@databricks.com> | 2016-09-13 12:54:03 +0200
---|---|---
committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-09-13 12:54:03 +0200
commit | 3f6a2bb3f7beac4ce928eb660ee36258b5b9e8c8 |
tree | e397fdba23c4cc536f452ba5fdf88934fbe74b07 |
parent | 46f5c201e70053635bdeab4984ba1b649478bd12 |
[SPARK-17515] CollectLimit.execute() should perform per-partition limits
## What changes were proposed in this pull request?
CollectLimit.execute() incorrectly omits per-partition limits, leading to performance regressions when this code path is hit. That should not happen in normal operation, but it can occur in some cases (see #15068 for one example).
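To illustrate the idea behind the fix, here is a minimal sketch of the per-partition limit pattern using the public RDD API. The object name, `SparkContext` setup, and example data are illustrative only; the actual change lives in `CollectLimitExec.doExecute()`, shown in the diff below.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Sketch only (not CollectLimitExec itself): take at most `limit` rows from
// each partition before collapsing to a single partition, so the shuffle
// moves at most numPartitions * limit rows instead of every row.
object PerPartitionLimitSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("limit-sketch").setMaster("local[4]"))

    val data = sc.parallelize(1 to 1000, numSlices = 10)
    val limit = 5

    // Local limit: each partition contributes at most `limit` rows.
    val locallyLimited = data.mapPartitions(_.take(limit))

    // Global limit: gather the survivors into one partition and take `limit` again.
    val globallyLimited =
      locallyLimited.coalesce(1, shuffle = true).mapPartitions(_.take(limit))

    println(globallyLimited.collect().mkString(", "))
    sc.stop()
  }
}
```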
## How was this patch tested?
Regression test in SQLQuerySuite that asserts the number of records scanned from the input RDD.
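The counting technique can also be run outside the suite. Below is a self-contained sketch that mirrors the regression test; the object name and the `local[*]` session setup are assumptions and are not part of the patch.

```scala
import org.apache.spark.sql.SparkSession

object ScanCountSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("scan-count")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Counts how many input rows are actually materialized by the scan.
    val numRecordsRead = spark.sparkContext.longAccumulator

    // With the per-partition limit in place, each of the 10 partitions stops
    // after producing a single row, so the accumulator reads 10 rather than 99.
    spark.range(1, 100, 1, numPartitions = 10).map { x =>
      numRecordsRead.add(1)
      x
    }.limit(1).queryExecution.toRdd.count()

    println(s"records read: ${numRecordsRead.value}")
    spark.stop()
  }
}
```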
Author: Josh Rosen <joshrosen@databricks.com>
Closes #15070 from JoshRosen/SPARK-17515.
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala | 3 |
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 9 |
2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 01fbe5b7c2..86a8770715 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -39,9 +39,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
   private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
   protected override def doExecute(): RDD[InternalRow] = {
+    val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
     val shuffled = new ShuffledRowRDD(
       ShuffleExchange.prepareShuffleDependency(
-        child.execute(), child.output, SinglePartition, serializer))
+        locallyLimited, child.output, SinglePartition, serializer))
     shuffled.mapPartitionsInternal(_.take(limit))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index eac266cba5..a2164f9ae3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2661,4 +2661,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         data.selectExpr("`part.col1`", "`col.1`"))
     }
   }
+
+  test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") {
+    val numRecordsRead = spark.sparkContext.longAccumulator
+    spark.range(1, 100, 1, numPartitions = 10).map { x =>
+      numRecordsRead.add(1)
+      x
+    }.limit(1).queryExecution.toRdd.count()
+    assert(numRecordsRead.value === 10)
+  }
 }