aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2016-09-13 12:54:03 +0200
committerHerman van Hovell <hvanhovell@databricks.com>2016-09-13 12:54:03 +0200
commit3f6a2bb3f7beac4ce928eb660ee36258b5b9e8c8 (patch)
treee397fdba23c4cc536f452ba5fdf88934fbe74b07 /sql
parent46f5c201e70053635bdeab4984ba1b649478bd12 (diff)
downloadspark-3f6a2bb3f7beac4ce928eb660ee36258b5b9e8c8.tar.gz
spark-3f6a2bb3f7beac4ce928eb660ee36258b5b9e8c8.tar.bz2
spark-3f6a2bb3f7beac4ce928eb660ee36258b5b9e8c8.zip
[SPARK-17515] CollectLimit.execute() should perform per-partition limits
## What changes were proposed in this pull request? CollectLimit.execute() incorrectly omits per-partition limits, leading to performance regressions in case this case is hit (which should not happen in normal operation, but can occur in some cases (see #15068 for one example). ## How was this patch tested? Regression test in SQLQuerySuite that asserts the number of records scanned from the input RDD. Author: Josh Rosen <joshrosen@databricks.com> Closes #15070 from JoshRosen/SPARK-17515.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala9
2 files changed, 11 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 01fbe5b7c2..86a8770715 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -39,9 +39,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
protected override def doExecute(): RDD[InternalRow] = {
+ val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
val shuffled = new ShuffledRowRDD(
ShuffleExchange.prepareShuffleDependency(
- child.execute(), child.output, SinglePartition, serializer))
+ locallyLimited, child.output, SinglePartition, serializer))
shuffled.mapPartitionsInternal(_.take(limit))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index eac266cba5..a2164f9ae3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2661,4 +2661,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
data.selectExpr("`part.col1`", "`col.1`"))
}
}
+
+ test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") {
+ val numRecordsRead = spark.sparkContext.longAccumulator
+ spark.range(1, 100, 1, numPartitions = 10).map { x =>
+ numRecordsRead.add(1)
+ x
+ }.limit(1).queryExecution.toRdd.count()
+ assert(numRecordsRead.value === 10)
+ }
}