diff options
Diffstat (limited to 'sql/core/src')
8 files changed, 25 insertions, 8 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index a84e180ad1..05627ba9c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -187,6 +187,7 @@ case class Expand( val i = ctx.freshName("i") // these column have to declared before the loop. val evaluate = evaluateVariables(outputColumns) + ctx.copyResult = true s""" |$evaluate |for (int $i = 0; $i < ${projections.length}; $i ++) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala index ef84992e69..431f02102e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala @@ -115,7 +115,8 @@ class GroupedIterator private( false } else { // Skip to next group. - while (input.hasNext && keyOrdering.compare(currentGroup, currentRow) == 0) { + // currentRow may be overwritten by `hasNext`, so we should compare them first. + while (keyOrdering.compare(currentGroup, currentRow) == 0 && input.hasNext) { currentRow = input.next() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 5a67cd0c24..b4dd77041e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -111,7 +111,6 @@ case class Sort( val needToSort = ctx.freshName("needToSort") ctx.addMutableState("boolean", needToSort, s"$needToSort = true;") - // Initialize the class member variables. This includes the instance of the Sorter and // the iterator to return sorted rows. val thisPlan = ctx.addReferenceObj("plan", this) @@ -132,6 +131,10 @@ case class Sort( | } """.stripMargin.trim) + // The child could change `copyResult` to true, but we had already consumed all the rows, + // so `copyResult` should be reset to `false`. + ctx.copyResult = false + val outputRow = ctx.freshName("outputRow") val dataSize = metricTerm(ctx, "dataSize") val spillSize = metricTerm(ctx, "spillSize") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index a54b772045..67aef72ded 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -379,10 +379,15 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup input: Seq[ExprCode], row: String = null): String = { + val doCopy = if (ctx.copyResult) { + ".copy()" + } else { + "" + } if (row != null) { // There is an UnsafeRow already s""" - |append($row.copy()); + |append($row$doCopy); """.stripMargin.trim } else { assert(input != null) @@ -397,7 +402,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup s""" |$evaluateInputs |${code.code.trim} - |append(${code.value}.copy()); + |append(${code.value}$doCopy); """.stripMargin.trim } else { // There is no columns diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 1c4d594cd8..28945a507c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -465,6 +465,10 @@ case class TungstenAggregate( val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan) val numOutput = metricTerm(ctx, "numOutputRows") + // The child could change `copyResult` to true, but we had already consumed all the rows, + // so `copyResult` should be reset to `false`. + ctx.copyResult = false + s""" if (!$initAgg) { $initAgg = true; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index f84ed41f1d..aa2da283b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -230,6 +230,7 @@ case class BroadcastHashJoin( """.stripMargin } else { + ctx.copyResult = true val matches = ctx.freshName("matches") val bufferType = classOf[CompactBuffer[UnsafeRow]].getName val i = ctx.freshName("i") @@ -303,6 +304,7 @@ case class BroadcastHashJoin( """.stripMargin } else { + ctx.copyResult = true val matches = ctx.freshName("matches") val bufferType = classOf[CompactBuffer[UnsafeRow]].getName val i = ctx.freshName("i") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index d0724ff6e5..807b39ace6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -404,6 +404,7 @@ case class SortMergeJoin( } override def doProduce(ctx: CodegenContext): String = { + ctx.copyResult = true val leftInput = ctx.freshName("leftInput") ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];") val rightInput = ctx.freshName("rightInput") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index cb672643f1..b6051b07c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -457,12 +457,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { benchmark.run() /** - * Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz collect: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - collect 1 million 775 / 1170 1.4 738.9 1.0X - collect 2 millions 1153 / 1758 0.9 1099.3 0.7X - collect 4 millions 4451 / 5124 0.2 4244.9 0.2X + collect 1 million 439 / 654 2.4 418.7 1.0X + collect 2 millions 961 / 1907 1.1 916.4 0.5X + collect 4 millions 3193 / 3895 0.3 3044.7 0.1X */ } } |