about summary refs log tree commit diff
path: root/sql/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala9
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala1
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala8
8 files changed, 25 insertions, 8 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index a84e180ad1..05627ba9c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -187,6 +187,7 @@ case class Expand(
val i = ctx.freshName("i")
// these columns have to be declared before the loop.
val evaluate = evaluateVariables(outputColumns)
+ ctx.copyResult = true
s"""
|$evaluate
|for (int $i = 0; $i < ${projections.length}; $i ++) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
index ef84992e69..431f02102e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala
@@ -115,7 +115,8 @@ class GroupedIterator private(
false
} else {
// Skip to next group.
- while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) {
+ // currentRow may be overwritten by `hasNext`, so we should compare them first.
+ while (keyOrdering.compare(currentGroup, currentRow) == 0 && input.hasNext) {
currentRow = input.next()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 5a67cd0c24..b4dd77041e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -111,7 +111,6 @@ case class Sort(
val needToSort = ctx.freshName("needToSort")
ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
-
// Initialize the class member variables. This includes the instance of the Sorter and
// the iterator to return sorted rows.
val thisPlan = ctx.addReferenceObj("plan", this)
@@ -132,6 +131,10 @@ case class Sort(
| }
""".stripMargin.trim)
+ // The child could change `copyResult` to true, but we had already consumed all the rows,
+ // so `copyResult` should be reset to `false`.
+ ctx.copyResult = false
+
val outputRow = ctx.freshName("outputRow")
val dataSize = metricTerm(ctx, "dataSize")
val spillSize = metricTerm(ctx, "spillSize")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index a54b772045..67aef72ded 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -379,10 +379,15 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
input: Seq[ExprCode],
row: String = null): String = {
+ val doCopy = if (ctx.copyResult) {
+ ".copy()"
+ } else {
+ ""
+ }
if (row != null) {
// There is an UnsafeRow already
s"""
- |append($row.copy());
+ |append($row$doCopy);
""".stripMargin.trim
} else {
assert(input != null)
@@ -397,7 +402,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
s"""
|$evaluateInputs
|${code.code.trim}
- |append(${code.value}.copy());
+ |append(${code.value}$doCopy);
""".stripMargin.trim
} else {
// There is no columns
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 1c4d594cd8..28945a507c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -465,6 +465,10 @@ case class TungstenAggregate(
val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
val numOutput = metricTerm(ctx, "numOutputRows")
+ // The child could change `copyResult` to true, but we had already consumed all the rows,
+ // so `copyResult` should be reset to `false`.
+ ctx.copyResult = false
+
s"""
if (!$initAgg) {
$initAgg = true;
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index f84ed41f1d..aa2da283b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -230,6 +230,7 @@ case class BroadcastHashJoin(
""".stripMargin
} else {
+ ctx.copyResult = true
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
@@ -303,6 +304,7 @@ case class BroadcastHashJoin(
""".stripMargin
} else {
+ ctx.copyResult = true
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index d0724ff6e5..807b39ace6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -404,6 +404,7 @@ case class SortMergeJoin(
}
override def doProduce(ctx: CodegenContext): String = {
+ ctx.copyResult = true
val leftInput = ctx.freshName("leftInput")
ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
val rightInput = ctx.freshName("rightInput")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index cb672643f1..b6051b07c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -457,12 +457,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
benchmark.run()
/**
- * Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------
- collect 1 million 775 / 1170 1.4 738.9 1.0X
- collect 2 millions 1153 / 1758 0.9 1099.3 0.7X
- collect 4 millions 4451 / 5124 0.2 4244.9 0.2X
+ collect 1 million 439 / 654 2.4 418.7 1.0X
+ collect 2 millions 961 / 1907 1.1 916.4 0.5X
+ collect 4 millions 3193 / 3895 0.3 3044.7 0.1X
*/
}
}