Diffstat (limited to 'sql')
 sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java     | 10 ++++++++++
 sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala |  2 +-
 sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala                 |  6 +++---
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala         |  8 ++++++++
 4 files changed, 22 insertions(+), 4 deletions(-)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index 086547c793..730a4ae8d5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -70,6 +70,16 @@ public abstract class BufferedRowIterator {
   }

   /**
+   * Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]].
+   *
+   * If it returns true, the caller should exit the loop that [[InputAdapter]] generates.
+   * This interface is mainly used to limit the number of input rows.
+   */
+  protected boolean stopEarly() {
+    return false;
+  }
+
+  /**
    * Returns whether `processNext()` should stop processing next row from `input` or not.
    *
    * If it returns true, the caller should exit the loop (return from processNext()).
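
To make the new contract concrete, below is a minimal hand-written sketch of a
subclass that honors it. This is not the real generated code: the class and
member names (LimitedIterator, LIMIT, input, count) are invented for
illustration, and the actual equivalent is assembled by the codegen in
WholeStageCodegenExec.scala and limit.scala below.

import java.io.IOException;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.BufferedRowIterator;
import scala.collection.Iterator;

// Hypothetical hand-written stand-in for a generated iterator with LIMIT 2.
public class LimitedIterator extends BufferedRowIterator {
  private static final int LIMIT = 2;
  private Iterator<InternalRow> input;
  private boolean stopEarly = false;
  private int count = 0;

  @Override
  public void init(int index, Iterator<InternalRow>[] inputs) {
    this.input = inputs[0];
  }

  // The new hook: turns true once the limit is hit, so the input loop
  // stops pulling rows. The base-class default above returns false.
  @Override
  protected boolean stopEarly() {
    return stopEarly;
  }

  @Override
  protected void processNext() throws IOException {
    // Same shape as the loop InputAdapter generates (next file).
    while (input.hasNext() && !stopEarly()) {
      InternalRow row = input.next();
      if (count < LIMIT) {
        count += 1;
        append(row);        // buffer the row; shouldStop() turns true
      } else {
        stopEarly = true;   // limit reached: stop fetching input
      }
      if (shouldStop()) return;  // hand buffered rows back to the caller
    }
  }
}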
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 516b9d5444..2ead8f6baa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -241,7 +241,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
     ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
     val row = ctx.freshName("row")
     s"""
-       | while ($input.hasNext()) {
+       | while ($input.hasNext() && !stopEarly()) {
        |   InternalRow $row = (InternalRow) $input.next();
        |   ${consume(ctx, null, row).trim}
        |   if (shouldStop()) return;
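
After interpolation, the template above emits Java roughly like the following.
The identifiers inputadapter_input and inputadapter_row are illustrative
stand-ins for the ctx.freshName results, whose suffixes vary per query, and
the loop body is whatever consume(ctx, null, row) generated for the parent
operator. Because the default stopEarly() in BufferedRowIterator returns
false, plans without a limit behave as before, modulo one extra virtual call
per row that the JIT can typically inline.

// Sketch of the emitted fetch loop; identifiers are illustrative.
while (inputadapter_input.hasNext() && !stopEarly()) {
  InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
  // ... code from consume(): the downstream operator's doConsume() ...
  if (shouldStop()) return;
}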
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 9918ac327f..757fe2185d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -70,10 +70,10 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
     val stopEarly = ctx.freshName("stopEarly")
     ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")

-    ctx.addNewFunction("shouldStop", s"""
+    ctx.addNewFunction("stopEarly", s"""
       @Override
-      protected boolean shouldStop() {
-        return !currentRows.isEmpty() || $stopEarly;
+      protected boolean stopEarly() {
+        return $stopEarly;
       }
     """)
     val countTerm = ctx.freshName("count")
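
The substance of the fix: the limit operator no longer overrides shouldStop(),
whose contract is "a row is buffered, return it to the caller", but the new
stopEarly() hook instead. The old override reported true once the limit flag
was set even when no row was buffered, so a stage that also contained another
operator, such as the aggregate in the test below, could exit processNext()
prematurely. After interpolation, addNewFunction splices roughly the following
into the generated class; limit_stopEarly is an illustrative stand-in for the
freshName identifier registered with addMutableState above.

// Sketch of the generated override; the field name is illustrative.
private boolean limit_stopEarly = false;   // from ctx.addMutableState

@Override
protected boolean stopEarly() {
  return limit_stopEarly;
}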
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 7aa4f0026f..645175900f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -513,4 +513,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))),
       Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d"))))
   }
+
+  test("SPARK-18004 limit + aggregates") {
+    val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value")
+    val limit2Df = df.limit(2)
+    checkAnswer(
+      limit2Df.groupBy("id").count().select($"id"),
+      limit2Df.select($"id"))
+  }
 }