-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala                        2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala          15
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala     2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala          27
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala                       16
5 files changed, 55 insertions(+), 7 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index cc576bbc4c..f98ae82574 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -177,6 +177,8 @@ case class SortExec(
""".stripMargin.trim
}
+ protected override val shouldStopRequired = false
+
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
s"""
|${row.code}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index c58474eba0..c31fd92447 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -206,6 +206,21 @@ trait CodegenSupport extends SparkPlan {
def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
throw new UnsupportedOperationException
}
+
+ /**
+ * An optimization hook to suppress shouldStop() in the row-producing loop of WholeStageCodegen.
+ * Returning true means shouldStop() must be inserted into the loop that produces rows, if any.
+ */
+ def isShouldStopRequired: Boolean = {
+ return shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired)
+ }
+
+ /**
+ * Set to false if this plan consumes all rows produced by children but doesn't output rows
+ * to the buffer by calling append(), so the children don't require shouldStop()
+ * in their row-producing loops.
+ */
+ protected def shouldStopRequired: Boolean = true
}
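
The two members above implement a simple parent-chain check: a plan's row-producing loop needs shouldStop() only if every operator above it in the same whole-stage pipeline still forwards rows to the output buffer. A minimal sketch of that propagation in plain Scala (Node, blocking, and the wiring below are illustrative stand-ins, not Spark classes):

    // `blocking = true` models operators such as SortExec or HashAggregateExec that
    // buffer all input before emitting anything, i.e. shouldStopRequired = false.
    final case class Node(blocking: Boolean, parent: Option[Node]) {
      def shouldStopRequired: Boolean = !blocking
      def isShouldStopRequired: Boolean =
        shouldStopRequired && parent.forall(_.isShouldStopRequired)
    }

    val sort  = Node(blocking = true,  parent = None)        // SortExec-like consumer
    val range = Node(blocking = false, parent = Some(sort))  // RangeExec-like producer
    assert(!range.isShouldStopRequired)  // Range's loop may omit the shouldStop() check

Because the flag only disables the check when the whole chain agrees, a plan that does append rows (the default, shouldStopRequired = true) keeps the early-exit behaviour unchanged.
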
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4529ed067e..68c8e6ce62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -238,6 +238,8 @@ case class HashAggregateExec(
""".stripMargin
}
+ protected override val shouldStopRequired = false
+
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 87e90ed685..d876688a8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -387,8 +387,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
// How many values should be generated in the next batch.
val nextBatchTodo = ctx.freshName("nextBatchTodo")
- // The default size of a batch.
- val batchSize = 1000L
+ // The default size of a batch, which must be a positive integer
+ val batchSize = 1000
ctx.addNewFunction("initRange",
s"""
@@ -434,6 +434,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val input = ctx.freshName("input")
// Right now, Range is only used when there is one upstream.
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+
+ val localIdx = ctx.freshName("localIdx")
+ val localEnd = ctx.freshName("localEnd")
+ val range = ctx.freshName("range")
+ val shouldStop = if (isShouldStopRequired) {
+ s"if (shouldStop()) { $number = $value + ${step}L; return; }"
+ } else {
+ "// shouldStop check is eliminated"
+ }
s"""
| // initialize Range
| if (!$initTerm) {
@@ -442,11 +451,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
| }
|
| while (true) {
- | while ($number != $batchEnd) {
- | long $value = $number;
- | $number += ${step}L;
- | ${consume(ctx, Seq(ev))}
- | if (shouldStop()) return;
+ | long $range = $batchEnd - $number;
+ | if ($range != 0L) {
+ | int $localEnd = (int)($range / ${step}L);
+ | for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+ | long $value = ((long)$localIdx * ${step}L) + $number;
+ | ${consume(ctx, Seq(ev))}
+ | $shouldStop
+ | }
+ | $number = $batchEnd;
| }
|
| if ($taskContext.isInterrupted()) {
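
As a sanity check on the loop rewrite above, here is a plain Scala sketch (not the generated Java) showing that the new int-indexed batch loop visits exactly the same values as the old while loop; number, batchEnd, and step play the same roles as in the generated code, and the concrete arguments in the assertion are arbitrary examples:

    // Old shape: advance `number` until it reaches `batchEnd`, one shouldStop() per row.
    def oldBatch(number: Long, batchEnd: Long, step: Long): Seq[Long] = {
      val out = scala.collection.mutable.ArrayBuffer.empty[Long]
      var n = number
      while (n != batchEnd) { out += n; n += step }
      out.toList
    }

    // New shape: derive an int trip count (a batch is at most `batchSize` rows, so it
    // fits in an int), index locally, and jump `number` to `batchEnd` once at the end.
    def newBatch(number: Long, batchEnd: Long, step: Long): Seq[Long] = {
      val localEnd = ((batchEnd - number) / step).toInt
      (0 until localEnd).map(i => i.toLong * step + number)
    }

    assert(oldBatch(-100L, -100L + 3L * 1000, 3L) == newBatch(-100L, -100L + 3L * 1000, 3L))

The int induction variable also keeps the hot loop a simple counted loop, which JIT compilers generally optimize well; presumably that is part of the motivation alongside dropping the per-row shouldStop() call.
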
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
index acf393a9b0..5e323c02b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -89,6 +89,22 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
val res13 = spark.range(-n, n, n / 9).select("id")
assert(res13.count == 18)
+
+ // range followed by non-aggregation operations
+ val res14 = spark.range(0, 100, 2).toDF.filter("50 <= id")
+ val len14 = res14.collect.length
+ assert(len14 == 25)
+
+ val res15 = spark.range(100, -100, -2).toDF.filter("id <= 0")
+ val len15 = res15.collect.length
+ assert(len15 == 50)
+
+ val res16 = spark.range(-1500, 1500, 3).toDF.filter("0 <= id")
+ val len16 = res16.collect.length
+ assert(len16 == 500)
+
+ val res17 = spark.range(10, 0, -1, 1).toDF.sortWithinPartitions("id")
+ assert(res17.collect === (1 to 10).map(i => Row(i)).toArray)
}
test("Range with randomized parameters") {