-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala    9
2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 01fbe5b7c2..86a8770715 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -39,9 +39,10 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
   override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
   private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
   protected override def doExecute(): RDD[InternalRow] = {
+    val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
     val shuffled = new ShuffledRowRDD(
       ShuffleExchange.prepareShuffleDependency(
-        child.execute(), child.output, SinglePartition, serializer))
+        locallyLimited, child.output, SinglePartition, serializer))
     shuffled.mapPartitionsInternal(_.take(limit))
   }
 }
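
The change above pushes the limit below the exchange: each input partition is cut down to at most `limit` rows before being shuffled into the single output partition, so the exchange moves at most numPartitions * limit rows instead of the child's entire output. Below is a minimal sketch of the same local-then-global limit pattern written against the public RDD API rather than the internal CollectLimitExec/ShuffledRowRDD code path; the object and variable names are illustrative only.

import org.apache.spark.sql.SparkSession

object LocalThenGlobalLimitSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("limit-sketch")
      .master("local[4]")
      .getOrCreate()
    val limit = 5

    // 1,000 integers spread across 10 partitions.
    val data = spark.sparkContext.parallelize(1 to 1000, numSlices = 10)

    // Local limit: each partition emits at most `limit` rows, so the
    // shuffle below moves at most 10 * limit rows, not all 1,000.
    val locallyLimited = data.mapPartitions(_.take(limit))

    // Collapse everything into one partition (a shuffle when shuffle = true),
    // then apply the global limit once more on that single partition.
    val globallyLimited = locallyLimited
      .coalesce(numPartitions = 1, shuffle = true)
      .mapPartitions(_.take(limit))

    println(globallyLimited.collect().mkString(", "))
    spark.stop()
  }
}
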
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index eac266cba5..a2164f9ae3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2661,4 +2661,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         data.selectExpr("`part.col1`", "`col.1`"))
     }
   }
+
+  test("SPARK-17515: CollectLimit.execute() should perform per-partition limits") {
+    val numRecordsRead = spark.sparkContext.longAccumulator
+    spark.range(1, 100, 1, numPartitions = 10).map { x =>
+      numRecordsRead.add(1)
+      x
+    }.limit(1).queryExecution.toRdd.count()
+    assert(numRecordsRead.value === 10)
+  }
 }
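
A note on what the added test checks: the input is 99 records spread over 10 partitions, and the query asks for limit(1). With the per-partition limit applied before the shuffle, take(1) pulls only one record through the map function in each partition, so the accumulator counts 10 records (one per partition) instead of all 99. The assertion on exactly 10 is what guards the per-partition behaviour.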