Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala  243
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala  52
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala  117
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala  115
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala  72
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala  97
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala  136
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala  233
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala  351
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala  33
10 files changed, 1157 insertions, 292 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
new file mode 100644
index 0000000000..458ac4ba36
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.ConcurrentModificationException
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
+
+/**
+ * An append-only array for [[UnsafeRow]]s that spills its contents to disk when a predefined
+ * threshold of rows is reached.
+ *
+ * Setting the spill threshold involves the following trade-off:
+ *
+ * - If the spill threshold is too high, the in-memory array may occupy more memory than is
+ * available, resulting in OOM.
+ * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes.
+ * This may lead to a performance regression compared to the normal case of using an
+ * [[ArrayBuffer]] or [[Array]].
+ */
+private[sql] class ExternalAppendOnlyUnsafeRowArray(
+ taskMemoryManager: TaskMemoryManager,
+ blockManager: BlockManager,
+ serializerManager: SerializerManager,
+ taskContext: TaskContext,
+ initialSize: Int,
+ pageSizeBytes: Long,
+ numRowsSpillThreshold: Int) extends Logging {
+
+ def this(numRowsSpillThreshold: Int) {
+ this(
+ TaskContext.get().taskMemoryManager(),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ TaskContext.get(),
+ 1024,
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ numRowsSpillThreshold)
+ }
+
+ private val initialSizeOfInMemoryBuffer =
+ Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold)
+
+ private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) {
+ new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer)
+ } else {
+ null
+ }
+
+ private var spillableArray: UnsafeExternalSorter = _
+ private var numRows = 0
+
+ // A counter to keep track of total modifications done to this array since its creation.
+ // This helps to invalidate iterators when there are changes done to the backing array.
+ private var modificationsCount: Long = 0
+
+ private var numFieldsPerRow = 0
+
+ def length: Int = numRows
+
+ def isEmpty: Boolean = numRows == 0
+
+ /**
+ * Clears up resources (e.g. memory) held by the backing storage
+ */
+ def clear(): Unit = {
+ if (spillableArray != null) {
+ // The last `spillableArray` of this task will be cleaned up via a task completion listener
+ // inside `UnsafeExternalSorter`
+ spillableArray.cleanupResources()
+ spillableArray = null
+ } else if (inMemoryBuffer != null) {
+ inMemoryBuffer.clear()
+ }
+ numFieldsPerRow = 0
+ numRows = 0
+ modificationsCount += 1
+ }
+
+ def add(unsafeRow: UnsafeRow): Unit = {
+ if (numRows < numRowsSpillThreshold) {
+ inMemoryBuffer += unsafeRow.copy()
+ } else {
+ if (spillableArray == null) {
+ logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " +
+ s"${classOf[UnsafeExternalSorter].getName}")
+
+ // We will not sort the rows, so prefixComparator and recordComparator are null
+ spillableArray = UnsafeExternalSorter.create(
+ taskMemoryManager,
+ blockManager,
+ serializerManager,
+ taskContext,
+ null,
+ null,
+ initialSize,
+ pageSizeBytes,
+ numRowsSpillThreshold,
+ false)
+
+ // populate with existing in-memory buffered rows
+ if (inMemoryBuffer != null) {
+ inMemoryBuffer.foreach(existingUnsafeRow =>
+ spillableArray.insertRecord(
+ existingUnsafeRow.getBaseObject,
+ existingUnsafeRow.getBaseOffset,
+ existingUnsafeRow.getSizeInBytes,
+ 0,
+ false)
+ )
+ inMemoryBuffer.clear()
+ }
+ numFieldsPerRow = unsafeRow.numFields()
+ }
+
+ spillableArray.insertRecord(
+ unsafeRow.getBaseObject,
+ unsafeRow.getBaseOffset,
+ unsafeRow.getSizeInBytes,
+ 0,
+ false)
+ }
+
+ numRows += 1
+ modificationsCount += 1
+ }
+
+ /**
+ * Creates an [[Iterator]] for the current rows in the array, starting from a user-provided index.
+ *
+ * If [[add()]] or [[clear()]] is called on this array after the iterator has been created,
+ * the iterator is invalidated. This prevents clients from assuming they have read all the
+ * data when new rows have in fact been added to the array.
+ */
+ def generateIterator(startIndex: Int): Iterator[UnsafeRow] = {
+ if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) {
+ throw new ArrayIndexOutOfBoundsException(
+ "Invalid `startIndex` provided for generating iterator over the array. " +
+ s"Total elements: $numRows, requested `startIndex`: $startIndex")
+ }
+
+ if (spillableArray == null) {
+ new InMemoryBufferIterator(startIndex)
+ } else {
+ new SpillableArrayIterator(spillableArray.getIterator, numFieldsPerRow, startIndex)
+ }
+ }
+
+ def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0)
+
+ private[this]
+ abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] {
+ private val expectedModificationsCount = modificationsCount
+
+ protected def isModified(): Boolean = expectedModificationsCount != modificationsCount
+
+ protected def throwExceptionIfModified(): Unit = {
+ if (expectedModificationsCount != modificationsCount) {
+ throw new ConcurrentModificationException(
+ s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " +
+ s"since the creation of this Iterator")
+ }
+ }
+ }
+
+ private[this] class InMemoryBufferIterator(startIndex: Int)
+ extends ExternalAppendOnlyUnsafeRowArrayIterator {
+
+ private var currentIndex = startIndex
+
+ override def hasNext(): Boolean = !isModified() && currentIndex < numRows
+
+ override def next(): UnsafeRow = {
+ throwExceptionIfModified()
+ val result = inMemoryBuffer(currentIndex)
+ currentIndex += 1
+ result
+ }
+ }
+
+ private[this] class SpillableArrayIterator(
+ iterator: UnsafeSorterIterator,
+ numFieldPerRow: Int,
+ startIndex: Int)
+ extends ExternalAppendOnlyUnsafeRowArrayIterator {
+
+ private val currentRow = new UnsafeRow(numFieldPerRow)
+
+ def init(): Unit = {
+ var i = 0
+ while (i < startIndex) {
+ if (iterator.hasNext) {
+ iterator.loadNext()
+ } else {
+ throw new ArrayIndexOutOfBoundsException(
+ "Invalid `startIndex` provided for generating iterator over the array. " +
+ s"Total elements: $numRows, requested `startIndex`: $startIndex")
+ }
+ i += 1
+ }
+ }
+
+ // Traverse up to the given [[startIndex]]
+ init()
+
+ override def hasNext(): Boolean = !isModified() && iterator.hasNext
+
+ override def next(): UnsafeRow = {
+ throwExceptionIfModified()
+ iterator.loadNext()
+ currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength)
+ currentRow
+ }
+ }
+}
+
+private[sql] object ExternalAppendOnlyUnsafeRowArray {
+ val DefaultInitialSizeOfInMemoryBuffer = 128
+}
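The new class above is the centerpiece of this change, so a short usage sketch may help. This is illustrative only and not part of the commit: it assumes a fake TaskContext set up the same way as in the benchmark further down (the single-argument constructor pulls the memory manager from the current task), and the 100-row threshold and single-column rows are arbitrary.

import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext}
import org.apache.spark.memory.MemoryTestingUtils
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray

object ExternalAppendOnlyUnsafeRowArrayExample {
  def main(args: Array[String]): Unit = {
    // The single-argument constructor reads the memory manager and page size from the
    // current TaskContext / SparkEnv, so set up a fake task the way the benchmark does.
    val sc = new SparkContext("local", "example", new SparkConf(false))
    TaskContext.setTaskContext(MemoryTestingUtils.fakeTaskContext(SparkEnv.get))

    val spillThreshold = 100 // illustrative value
    val array = new ExternalAppendOnlyUnsafeRowArray(spillThreshold)

    // Single-column rows, built the same way as in the test suite below.
    def makeRow(value: Long): UnsafeRow = {
      val row = new UnsafeRow(1)
      row.pointTo(new Array[Byte](64), 16)
      row.setLong(0, value)
      row
    }

    // Rows past the threshold are handed to an UnsafeExternalSorter, which may spill to disk.
    (1L to 150L).foreach(v => array.add(makeRow(v)))

    val iterator = array.generateIterator()
    var sum = 0L
    while (iterator.hasNext) {
      sum += iterator.next().getLong(0) // rows come back in insertion order
    }

    // Any further add() or clear() invalidates `iterator`: hasNext turns false and
    // next() would throw ConcurrentModificationException.
    array.add(makeRow(151L))
    array.clear() // releases the in-memory buffer or the sorter's resources
    sc.stop()
  }
}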
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 8341fe2ffd..f380986951 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -19,65 +19,39 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.CompletionIterator
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
* An optimized CartesianRDD for UnsafeRow, which caches the rows from the second child RDD.
* This is much faster than building the right partition for every row in the left RDD, and it
* also materializes the right RDD (in case the right RDD is nondeterministic).
*/
-class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
+class UnsafeCartesianRDD(
+ left : RDD[UnsafeRow],
+ right : RDD[UnsafeRow],
+ numFieldsOfRight: Int,
+ spillThreshold: Int)
extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
- // We will not sort the rows, so prefixComparator and recordComparator are null.
- val sorter = UnsafeExternalSorter.create(
- context.taskMemoryManager(),
- SparkEnv.get.blockManager,
- SparkEnv.get.serializerManager,
- context,
- null,
- null,
- 1024,
- SparkEnv.get.memoryManager.pageSizeBytes,
- SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
- UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
- false)
+ val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
val partition = split.asInstanceOf[CartesianPartition]
- for (y <- rdd2.iterator(partition.s2, context)) {
- sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false)
- }
+ rdd2.iterator(partition.s2, context).foreach(rowArray.add)
- // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
- def createIter(): Iterator[UnsafeRow] = {
- val iter = sorter.getIterator
- val unsafeRow = new UnsafeRow(numFieldsOfRight)
- new Iterator[UnsafeRow] {
- override def hasNext: Boolean = {
- iter.hasNext
- }
- override def next(): UnsafeRow = {
- iter.loadNext()
- unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
- unsafeRow
- }
- }
- }
+ // Create an iterator from rowArray
+ def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator()
val resultIter =
for (x <- rdd1.iterator(partition.s1, context);
y <- createIter()) yield (x, y)
CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
- resultIter, sorter.cleanupResources())
+ resultIter, rowArray.clear())
}
}
@@ -97,7 +71,9 @@ case class CartesianProductExec(
val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]]
val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
- val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
+ val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold
+
+ val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold)
pair.mapPartitionsWithIndexInternal { (index, iter) =>
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
val filtered = if (condition.isDefined) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ca9c0ed8ce..bcdc4dcdf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport,
+ ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.collection.BitSet
@@ -95,9 +96,13 @@ case class SortMergeJoinExec(
private def createRightKeyGenerator(): Projection =
UnsafeProjection.create(rightKeys, right.output)
+ private def getSpillThreshold: Int = {
+ sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
-
+ val spillThreshold = getSpillThreshold
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
val boundCondition: (InternalRow) => Boolean = {
condition.map { cond =>
@@ -115,39 +120,39 @@ case class SortMergeJoinExec(
case _: InnerLike =>
new RowIterator {
private[this] var currentLeftRow: InternalRow = _
- private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
- private[this] var currentMatchIdx: Int = -1
+ private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
+ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
private[this] val smjScanner = new SortMergeJoinScanner(
createLeftKeyGenerator(),
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
- currentMatchIdx = 0
+ rightMatchesIterator = currentRightMatches.generateIterator()
}
override def advanceNext(): Boolean = {
- while (currentMatchIdx >= 0) {
- if (currentMatchIdx == currentRightMatches.length) {
+ while (rightMatchesIterator != null) {
+ if (!rightMatchesIterator.hasNext) {
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
- currentMatchIdx = 0
+ rightMatchesIterator = currentRightMatches.generateIterator()
} else {
currentRightMatches = null
currentLeftRow = null
- currentMatchIdx = -1
+ rightMatchesIterator = null
return false
}
}
- joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
- currentMatchIdx += 1
+ joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
numOutputRows += 1
return true
@@ -165,7 +170,8 @@ case class SortMergeJoinExec(
bufferedKeyGenerator = createRightKeyGenerator(),
keyOrdering,
streamedIter = RowIterator.fromScala(leftIter),
- bufferedIter = RowIterator.fromScala(rightIter)
+ bufferedIter = RowIterator.fromScala(rightIter),
+ spillThreshold
)
val rightNullRow = new GenericInternalRow(right.output.length)
new LeftOuterIterator(
@@ -177,7 +183,8 @@ case class SortMergeJoinExec(
bufferedKeyGenerator = createLeftKeyGenerator(),
keyOrdering,
streamedIter = RowIterator.fromScala(rightIter),
- bufferedIter = RowIterator.fromScala(leftIter)
+ bufferedIter = RowIterator.fromScala(leftIter),
+ spillThreshold
)
val leftNullRow = new GenericInternalRow(left.output.length)
new RightOuterIterator(
@@ -209,7 +216,8 @@ case class SortMergeJoinExec(
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -217,14 +225,15 @@ case class SortMergeJoinExec(
while (smjScanner.findNextInnerJoinRows()) {
val currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
- var i = 0
- while (i < currentRightMatches.length) {
- joinRow(currentLeftRow, currentRightMatches(i))
- if (boundCondition(joinRow)) {
- numOutputRows += 1
- return true
+ if (currentRightMatches != null && currentRightMatches.length > 0) {
+ val rightMatchesIterator = currentRightMatches.generateIterator()
+ while (rightMatchesIterator.hasNext) {
+ joinRow(currentLeftRow, rightMatchesIterator.next())
+ if (boundCondition(joinRow)) {
+ numOutputRows += 1
+ return true
+ }
}
- i += 1
}
}
false
@@ -241,7 +250,8 @@ case class SortMergeJoinExec(
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -249,17 +259,16 @@ case class SortMergeJoinExec(
while (smjScanner.findNextOuterJoinRows()) {
currentLeftRow = smjScanner.getStreamedRow
val currentRightMatches = smjScanner.getBufferedMatches
- if (currentRightMatches == null) {
+ if (currentRightMatches == null || currentRightMatches.length == 0) {
return true
}
- var i = 0
var found = false
- while (!found && i < currentRightMatches.length) {
- joinRow(currentLeftRow, currentRightMatches(i))
+ val rightMatchesIterator = currentRightMatches.generateIterator()
+ while (!found && rightMatchesIterator.hasNext) {
+ joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
found = true
}
- i += 1
}
if (!found) {
numOutputRows += 1
@@ -281,7 +290,8 @@ case class SortMergeJoinExec(
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -290,14 +300,13 @@ case class SortMergeJoinExec(
currentLeftRow = smjScanner.getStreamedRow
val currentRightMatches = smjScanner.getBufferedMatches
var found = false
- if (currentRightMatches != null) {
- var i = 0
- while (!found && i < currentRightMatches.length) {
- joinRow(currentLeftRow, currentRightMatches(i))
+ if (currentRightMatches != null && currentRightMatches.length > 0) {
+ val rightMatchesIterator = currentRightMatches.generateIterator()
+ while (!found && rightMatchesIterator.hasNext) {
+ joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
found = true
}
- i += 1
}
}
result.setBoolean(0, found)
@@ -376,8 +385,11 @@ case class SortMergeJoinExec(
// A list to hold all matched rows from right side.
val matches = ctx.freshName("matches")
- val clsName = classOf[java.util.ArrayList[InternalRow]].getName
- ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
+ val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
+
+ val spillThreshold = getSpillThreshold
+
+ ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);")
// Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
@@ -428,7 +440,7 @@ case class SortMergeJoinExec(
| }
| $leftRow = null;
| } else {
- | $matches.add($rightRow.copy());
+ | $matches.add((UnsafeRow) $rightRow);
| $rightRow = null;;
| }
| } while ($leftRow != null);
@@ -517,8 +529,7 @@ case class SortMergeJoinExec(
val rightRow = ctx.freshName("rightRow")
val rightVars = createRightVar(ctx, rightRow)
- val size = ctx.freshName("size")
- val i = ctx.freshName("i")
+ val iterator = ctx.freshName("iterator")
val numOutput = metricTerm(ctx, "numOutputRows")
val (beforeLoop, condCheck) = if (condition.isDefined) {
// Split the code of creating variables based on whether it's used by condition or not.
@@ -551,10 +562,10 @@ case class SortMergeJoinExec(
s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
- | int $size = $matches.size();
| ${beforeLoop.trim}
- | for (int $i = 0; $i < $size; $i ++) {
- | InternalRow $rightRow = (InternalRow) $matches.get($i);
+ | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
+ | while ($iterator.hasNext()) {
+ | InternalRow $rightRow = (InternalRow) $iterator.next();
| ${condCheck.trim}
| $numOutput.add(1);
| ${consume(ctx, leftVars ++ rightVars)}
@@ -589,7 +600,8 @@ private[joins] class SortMergeJoinScanner(
bufferedKeyGenerator: Projection,
keyOrdering: Ordering[InternalRow],
streamedIter: RowIterator,
- bufferedIter: RowIterator) {
+ bufferedIter: RowIterator,
+ bufferThreshold: Int) {
private[this] var streamedRow: InternalRow = _
private[this] var streamedRowKey: InternalRow = _
private[this] var bufferedRow: InternalRow = _
@@ -600,7 +612,7 @@ private[joins] class SortMergeJoinScanner(
*/
private[this] var matchJoinKey: InternalRow = _
/** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
- private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+ private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold)
// Initialization (note: do _not_ want to advance streamed here).
advancedBufferedToRowWithNullFreeJoinKey()
@@ -609,7 +621,7 @@ private[joins] class SortMergeJoinScanner(
def getStreamedRow: InternalRow = streamedRow
- def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+ def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches
/**
* Advances both input iterators, stopping when we have found rows with matching join keys.
@@ -755,7 +767,7 @@ private[joins] class SortMergeJoinScanner(
matchJoinKey = streamedRowKey.copy()
bufferedMatches.clear()
do {
- bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
+ bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow])
advancedBufferedToRowWithNullFreeJoinKey()
} while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
}
@@ -819,7 +831,7 @@ private abstract class OneSideOuterIterator(
protected[this] val joinedRow: JoinedRow = new JoinedRow()
// Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
- private[this] var bufferIndex: Int = 0
+ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
// This iterator is initialized lazily so there should be no matches initially
assert(smjScanner.getBufferedMatches.length == 0)
@@ -833,7 +845,7 @@ private abstract class OneSideOuterIterator(
* @return whether there are more rows in the stream to consume.
*/
private def advanceStream(): Boolean = {
- bufferIndex = 0
+ rightMatchesIterator = null
if (smjScanner.findNextOuterJoinRows()) {
setStreamSideOutput(smjScanner.getStreamedRow)
if (smjScanner.getBufferedMatches.isEmpty) {
@@ -858,10 +870,13 @@ private abstract class OneSideOuterIterator(
*/
private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
var foundMatch: Boolean = false
- while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
- setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
+ if (rightMatchesIterator == null) {
+ rightMatchesIterator = smjScanner.getBufferedMatches.generateIterator()
+ }
+
+ while (!foundMatch && rightMatchesIterator.hasNext) {
+ setBufferedSideOutput(rightMatchesIterator.next())
foundMatch = boundCondition(joinedRow)
- bufferIndex += 1
}
foundMatch
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
deleted file mode 100644
index ee36c84251..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.window
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
-
-
-/**
- * The interface of row buffer for a partition. In absence of a buffer pool (with locking), the
- * row buffer is used to materialize a partition of rows since we need to repeatedly scan these
- * rows in window function processing.
- */
-private[window] abstract class RowBuffer {
-
- /** Number of rows. */
- def size: Int
-
- /** Return next row in the buffer, null if no more left. */
- def next(): InternalRow
-
- /** Skip the next `n` rows. */
- def skip(n: Int): Unit
-
- /** Return a new RowBuffer that has the same rows. */
- def copy(): RowBuffer
-}
-
-/**
- * A row buffer based on ArrayBuffer (the number of rows is limited).
- */
-private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer {
-
- private[this] var cursor: Int = -1
-
- /** Number of rows. */
- override def size: Int = buffer.length
-
- /** Return next row in the buffer, null if no more left. */
- override def next(): InternalRow = {
- cursor += 1
- if (cursor < buffer.length) {
- buffer(cursor)
- } else {
- null
- }
- }
-
- /** Skip the next `n` rows. */
- override def skip(n: Int): Unit = {
- cursor += n
- }
-
- /** Return a new RowBuffer that has the same rows. */
- override def copy(): RowBuffer = {
- new ArrayRowBuffer(buffer)
- }
-}
-
-/**
- * An external buffer of rows based on UnsafeExternalSorter.
- */
-private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int)
- extends RowBuffer {
-
- private[this] val iter: UnsafeSorterIterator = sorter.getIterator
-
- private[this] val currentRow = new UnsafeRow(numFields)
-
- /** Number of rows. */
- override def size: Int = iter.getNumRecords()
-
- /** Return next row in the buffer, null if no more left. */
- override def next(): InternalRow = {
- if (iter.hasNext) {
- iter.loadNext()
- currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
- currentRow
- } else {
- null
- }
- }
-
- /** Skip the next `n` rows. */
- override def skip(n: Int): Unit = {
- var i = 0
- while (i < n && iter.hasNext) {
- iter.loadNext()
- i += 1
- }
- }
-
- /** Return a new RowBuffer that has the same rows. */
- override def copy(): RowBuffer = {
- new ExternalRowBuffer(sorter, numFields)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index 80b87d5ffa..950a6794a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -20,15 +20,13 @@ package org.apache.spark.sql.execution.window
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
* This class calculates and outputs (windowed) aggregates over the rows in a single (sorted)
@@ -284,6 +282,7 @@ case class WindowExec(
// Unwrap the expressions and factories from the map.
val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+ val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold
// Start processing.
child.execute().mapPartitions { stream =>
@@ -310,10 +309,12 @@ case class WindowExec(
fetchNextRow()
// Manage the current partition.
- val rows = ArrayBuffer.empty[UnsafeRow]
val inputFields = child.output.length
- var sorter: UnsafeExternalSorter = null
- var rowBuffer: RowBuffer = null
+
+ val buffer: ExternalAppendOnlyUnsafeRowArray =
+ new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+ var bufferIterator: Iterator[UnsafeRow] = _
+
val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType))
val frames = factories.map(_(windowFunctionResult))
val numFrames = frames.length
@@ -323,78 +324,43 @@ case class WindowExec(
val currentGroup = nextGroup.copy()
// clear last partition
- if (sorter != null) {
- // the last sorter of this task will be cleaned up via task completion listener
- sorter.cleanupResources()
- sorter = null
- } else {
- rows.clear()
- }
+ buffer.clear()
while (nextRowAvailable && nextGroup == currentGroup) {
- if (sorter == null) {
- rows += nextRow.copy()
-
- if (rows.length >= 4096) {
- // We will not sort the rows, so prefixComparator and recordComparator are null.
- sorter = UnsafeExternalSorter.create(
- TaskContext.get().taskMemoryManager(),
- SparkEnv.get.blockManager,
- SparkEnv.get.serializerManager,
- TaskContext.get(),
- null,
- null,
- 1024,
- SparkEnv.get.memoryManager.pageSizeBytes,
- SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
- UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
- false)
- rows.foreach { r =>
- sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false)
- }
- rows.clear()
- }
- } else {
- sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
- nextRow.getSizeInBytes, 0, false)
- }
+ buffer.add(nextRow)
fetchNextRow()
}
- if (sorter != null) {
- rowBuffer = new ExternalRowBuffer(sorter, inputFields)
- } else {
- rowBuffer = new ArrayRowBuffer(rows)
- }
// Setup the frames.
var i = 0
while (i < numFrames) {
- frames(i).prepare(rowBuffer.copy())
+ frames(i).prepare(buffer)
i += 1
}
// Setup iteration
rowIndex = 0
- rowsSize = rowBuffer.size
+ bufferIterator = buffer.generateIterator()
}
// Iteration
var rowIndex = 0
- var rowsSize = 0L
- override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable
+ override final def hasNext: Boolean =
+ (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable
val join = new JoinedRow
override final def next(): InternalRow = {
// Load the next partition if we need to.
- if (rowIndex >= rowsSize && nextRowAvailable) {
+ if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
fetchNextPartition()
}
- if (rowIndex < rowsSize) {
+ if (bufferIterator.hasNext) {
+ val current = bufferIterator.next()
+
// Get the results for the window frames.
var i = 0
- val current = rowBuffer.next()
while (i < numFrames) {
frames(i).write(rowIndex, current)
i += 1
@@ -406,7 +372,9 @@ case class WindowExec(
// Return the projection.
result(join)
- } else throw new NoSuchElementException
+ } else {
+ throw new NoSuchElementException
+ }
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 70efc0f78d..af2b4fb920 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -22,6 +22,7 @@ import java.util
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
/**
@@ -35,7 +36,7 @@ private[window] abstract class WindowFunctionFrame {
*
* @param rows to calculate the frame results for.
*/
- def prepare(rows: RowBuffer): Unit
+ def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit
/**
* Write the current results to the target row.
@@ -43,6 +44,12 @@ private[window] abstract class WindowFunctionFrame {
def write(index: Int, current: InternalRow): Unit
}
+object WindowFunctionFrame {
+ def getNextOrNull(iterator: Iterator[UnsafeRow]): UnsafeRow = {
+ if (iterator.hasNext) iterator.next() else null
+ }
+}
+
/**
* The offset window frame calculates frames containing LEAD/LAG statements.
*
@@ -65,7 +72,12 @@ private[window] final class OffsetWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+ /**
+ * An iterator over the [[input]]
+ */
+ private[this] var inputIterator: Iterator[UnsafeRow] = _
/** Index of the input row currently used for output. */
private[this] var inputIndex = 0
@@ -103,20 +115,21 @@ private[window] final class OffsetWindowFunctionFrame(
newMutableProjection(boundExpressions, Nil).target(target)
}
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
+ inputIterator = input.generateIterator()
// drain the first few rows if offset is larger than zero
inputIndex = 0
while (inputIndex < offset) {
- input.next()
+ if (inputIterator.hasNext) inputIterator.next()
inputIndex += 1
}
inputIndex = offset
}
override def write(index: Int, current: InternalRow): Unit = {
- if (inputIndex >= 0 && inputIndex < input.size) {
- val r = input.next()
+ if (inputIndex >= 0 && inputIndex < input.length) {
+ val r = WindowFunctionFrame.getNextOrNull(inputIterator)
projection(r)
} else {
// Use default values since the offset row does not exist.
@@ -143,7 +156,12 @@ private[window] final class SlidingWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+ /**
+ * An iterator over the [[input]]
+ */
+ private[this] var inputIterator: Iterator[UnsafeRow] = _
/** The next row from `input`. */
private[this] var nextRow: InternalRow = null
@@ -164,9 +182,10 @@ private[window] final class SlidingWindowFunctionFrame(
private[this] var inputLowIndex = 0
/** Prepare the frame for calculating a new partition. Reset all variables. */
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
- nextRow = rows.next()
+ inputIterator = input.generateIterator()
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputHighIndex = 0
inputLowIndex = 0
buffer.clear()
@@ -180,7 +199,7 @@ private[window] final class SlidingWindowFunctionFrame(
// the output row upper bound.
while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
buffer.add(nextRow.copy())
- nextRow = input.next()
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputHighIndex += 1
bufferUpdated = true
}
@@ -195,7 +214,7 @@ private[window] final class SlidingWindowFunctionFrame(
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
- processor.initialize(input.size)
+ processor.initialize(input.length)
val iter = buffer.iterator()
while (iter.hasNext) {
processor.update(iter.next())
@@ -222,13 +241,12 @@ private[window] final class UnboundedWindowFunctionFrame(
extends WindowFunctionFrame {
/** Prepare the frame for calculating a new partition. Process all rows eagerly. */
- override def prepare(rows: RowBuffer): Unit = {
- val size = rows.size
- processor.initialize(size)
- var i = 0
- while (i < size) {
- processor.update(rows.next())
- i += 1
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
+ processor.initialize(rows.length)
+
+ val iterator = rows.generateIterator()
+ while (iterator.hasNext) {
+ processor.update(iterator.next())
}
}
@@ -261,7 +279,12 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+ /**
+ * An iterator over the [[input]]
+ */
+ private[this] var inputIterator: Iterator[UnsafeRow] = _
/** The next row from `input`. */
private[this] var nextRow: InternalRow = null
@@ -273,11 +296,15 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
private[this] var inputIndex = 0
/** Prepare the frame for calculating a new partition. */
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
- nextRow = rows.next()
inputIndex = 0
- processor.initialize(input.size)
+ inputIterator = input.generateIterator()
+ if (inputIterator.hasNext) {
+ nextRow = inputIterator.next()
+ }
+
+ processor.initialize(input.length)
}
/** Write the frame columns for the current row to the given target row. */
@@ -288,7 +315,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
// the output row upper bound.
while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
processor.update(nextRow)
- nextRow = input.next()
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputIndex += 1
bufferUpdated = true
}
@@ -323,7 +350,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
/**
* Index of the first input row with a value equal to or greater than the lower bound of the
@@ -332,7 +359,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
private[this] var inputIndex = 0
/** Prepare the frame for calculating a new partition. */
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
inputIndex = 0
}
@@ -341,25 +368,25 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
override def write(index: Int, current: InternalRow): Unit = {
var bufferUpdated = index == 0
- // Duplicate the input to have a new iterator
- val tmp = input.copy()
-
- // Drop all rows from the buffer for which the input row value is smaller than
+ // Ignore all the rows from the buffer for which the input row value is smaller than
// the output row lower bound.
- tmp.skip(inputIndex)
- var nextRow = tmp.next()
+ val iterator = input.generateIterator(startIndex = inputIndex)
+
+ var nextRow = WindowFunctionFrame.getNextOrNull(iterator)
while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) {
- nextRow = tmp.next()
inputIndex += 1
bufferUpdated = true
+ nextRow = WindowFunctionFrame.getNextOrNull(iterator)
}
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
- processor.initialize(input.size)
- while (nextRow != null) {
+ processor.initialize(input.length)
+ if (nextRow != null) {
processor.update(nextRow)
- nextRow = tmp.next()
+ }
+ while (iterator.hasNext) {
+ processor.update(iterator.next())
}
processor.evaluate(target)
}
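The UnboundedFollowingWindowFunctionFrame change above replaces the old copy-and-skip pattern on RowBuffer with generateIterator(startIndex = inputIndex). A tiny sketch of that access pattern, reusing the fake-task setup and the makeRow helper from the earlier example (both are illustrative, not part of this commit):

// Assumes the same SparkContext / fake TaskContext setup and makeRow helper as the
// earlier sketch; the values are arbitrary.
val partition = new ExternalAppendOnlyUnsafeRowArray(100)
(1L to 10L).foreach(v => partition.add(makeRow(v)))

// Start scanning at row index 4 (0-based). For an in-memory buffer this indexes directly;
// for a spilled array the iterator fast-forwards by loading and discarding the first 4 rows.
val fromLowerBound = partition.generateIterator(startIndex = 4)
while (fromLowerBound.hasNext) {
  println(fromLowerBound.next().getLong(0)) // prints 5 through 10
}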
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 2e006735d1..1a66aa85f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import scala.collection.mutable.ListBuffer
import scala.language.existentials
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
-
+import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
class JoinSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -604,4 +605,137 @@ class JoinSuite extends QueryTest with SharedSQLContext {
cartesianQueries.foreach(checkCartesianDetection)
}
+
+ test("test SortMergeJoin (without spill)") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
+ "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> Int.MaxValue.toString) {
+
+ assertNotSpilled(sparkContext, "inner join") {
+ checkAnswer(
+ sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"),
+ Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil
+ )
+ }
+
+ val expected = new ListBuffer[Row]()
+ expected.append(
+ Row(1, "1", 1, 1), Row(1, "1", 1, 2),
+ Row(2, "2", 2, 1), Row(2, "2", 2, 2),
+ Row(3, "3", 3, 1), Row(3, "3", 3, 2)
+ )
+ for (i <- 4 to 100) {
+ expected.append(Row(i, i.toString, null, null))
+ }
+
+ assertNotSpilled(sparkContext, "left outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT
+ | big.key, big.value, small.a, small.b
+ |FROM
+ | testData big
+ |LEFT OUTER JOIN
+ | testData2 small
+ |ON
+ | big.key = small.a
+ """.stripMargin),
+ expected
+ )
+ }
+
+ assertNotSpilled(sparkContext, "right outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT
+ | big.key, big.value, small.a, small.b
+ |FROM
+ | testData2 small
+ |RIGHT OUTER JOIN
+ | testData big
+ |ON
+ | big.key = small.a
+ """.stripMargin),
+ expected
+ )
+ }
+ }
+ }
+
+ test("test SortMergeJoin (with spill)") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1",
+ "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") {
+
+ assertSpilled(sparkContext, "inner join") {
+ checkAnswer(
+ sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"),
+ Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil
+ )
+ }
+
+ val expected = new ListBuffer[Row]()
+ expected.append(
+ Row(1, "1", 1, 1), Row(1, "1", 1, 2),
+ Row(2, "2", 2, 1), Row(2, "2", 2, 2),
+ Row(3, "3", 3, 1), Row(3, "3", 3, 2)
+ )
+ for (i <- 4 to 100) {
+ expected.append(Row(i, i.toString, null, null))
+ }
+
+ assertSpilled(sparkContext, "left outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT
+ | big.key, big.value, small.a, small.b
+ |FROM
+ | testData big
+ |LEFT OUTER JOIN
+ | testData2 small
+ |ON
+ | big.key = small.a
+ """.stripMargin),
+ expected
+ )
+ }
+
+ assertSpilled(sparkContext, "right outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT
+ | big.key, big.value, small.a, small.b
+ |FROM
+ | testData2 small
+ |RIGHT OUTER JOIN
+ | testData big
+ |ON
+ | big.key = small.a
+ """.stripMargin),
+ expected
+ )
+ }
+
+ // FULL OUTER JOIN still does not use [[ExternalAppendOnlyUnsafeRowArray]]
+ // so should not cause any spill
+ assertNotSpilled(sparkContext, "full outer join") {
+ checkAnswer(
+ sql(
+ """
+ |SELECT
+ | big.key, big.value, small.a, small.b
+ |FROM
+ | testData2 small
+ |FULL OUTER JOIN
+ | testData big
+ |ON
+ | big.key = small.a
+ """.stripMargin),
+ expected
+ )
+ }
+ }
+ }
}
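The tests above drive spilling purely through configuration. For reference, a hedged sketch of how these thresholds could be tuned outside of tests, assuming a SparkSession named `spark` is in scope: only the sortMergeJoinExec key is spelled out in this diff; the window and cartesian-product keys are assumed to follow the same naming pattern as the new conf accessors (windowExecBufferSpillThreshold, cartesianProductExecBufferSpillThreshold) and should be checked against SQLConf.

// Confirmed by the tests above:
spark.conf.set("spark.sql.sortMergeJoinExec.buffer.spill.threshold", "4096")

// Assumed analogues for the other operators touched by this change (verify in SQLConf):
spark.conf.set("spark.sql.windowExec.buffer.spill.threshold", "4096")
spark.conf.set("spark.sql.cartesianProductExec.buffer.spill.threshold", "4096")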
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
new file mode 100644
index 0000000000..00c5f2550c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.memory.MemoryTestingUtils
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.Benchmark
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+object ExternalAppendOnlyUnsafeRowArrayBenchmark {
+
+ def testAgainstRawArrayBuffer(numSpillThreshold: Int, numRows: Int, iterations: Int): Unit = {
+ val random = new java.util.Random()
+ val rows = (1 to numRows).map(_ => {
+ val row = new UnsafeRow(1)
+ row.pointTo(new Array[Byte](64), 16)
+ row.setLong(0, random.nextLong())
+ row
+ })
+
+ val benchmark = new Benchmark(s"Array with $numRows rows", iterations * numRows)
+
+ // Internally, `ExternalAppendOnlyUnsafeRowArray` will create an
+ // in-memory buffer of size `numSpillThreshold`. This will mimic that
+ val initialSize =
+ Math.min(
+ ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer,
+ numSpillThreshold)
+
+ benchmark.addCase("ArrayBuffer") { _: Int =>
+ var sum = 0L
+ for (_ <- 0L until iterations) {
+ val array = new ArrayBuffer[UnsafeRow](initialSize)
+
+ // Internally, `ExternalAppendOnlyUnsafeRowArray` will create a
+ // copy of the row. This will mimic that
+ rows.foreach(x => array += x.copy())
+
+ var i = 0
+ val n = array.length
+ while (i < n) {
+ sum = sum + array(i).getLong(0)
+ i += 1
+ }
+ array.clear()
+ }
+ }
+
+ benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
+ var sum = 0L
+ for (_ <- 0L until iterations) {
+ val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+ rows.foreach(x => array.add(x))
+
+ val iterator = array.generateIterator()
+ while (iterator.hasNext) {
+ sum = sum + iterator.next().getLong(0)
+ }
+ array.clear()
+ }
+ }
+
+ val conf = new SparkConf(false)
+ // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+ // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+ conf.set("spark.serializer.objectStreamReset", "1")
+ conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+
+ val sc = new SparkContext("local", "test", conf)
+ val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+ TaskContext.setTaskContext(taskContext)
+ benchmark.run()
+ sc.stop()
+ }
+
+ def testAgainstRawUnsafeExternalSorter(
+ numSpillThreshold: Int,
+ numRows: Int,
+ iterations: Int): Unit = {
+
+ val random = new java.util.Random()
+ val rows = (1 to numRows).map(_ => {
+ val row = new UnsafeRow(1)
+ row.pointTo(new Array[Byte](64), 16)
+ row.setLong(0, random.nextLong())
+ row
+ })
+
+ val benchmark = new Benchmark(s"Spilling with $numRows rows", iterations * numRows)
+
+ benchmark.addCase("UnsafeExternalSorter") { _: Int =>
+ var sum = 0L
+ for (_ <- 0L until iterations) {
+ val array = UnsafeExternalSorter.create(
+ TaskContext.get().taskMemoryManager(),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ TaskContext.get(),
+ null,
+ null,
+ 1024,
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ numSpillThreshold,
+ false)
+
+ rows.foreach(x =>
+ array.insertRecord(
+ x.getBaseObject,
+ x.getBaseOffset,
+ x.getSizeInBytes,
+ 0,
+ false))
+
+ val unsafeRow = new UnsafeRow(1)
+ val iter = array.getIterator
+ while (iter.hasNext) {
+ iter.loadNext()
+ unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
+ sum = sum + unsafeRow.getLong(0)
+ }
+ array.cleanupResources()
+ }
+ }
+
+ benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int =>
+ var sum = 0L
+ for (_ <- 0L until iterations) {
+ val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold)
+ rows.foreach(x => array.add(x))
+
+ val iterator = array.generateIterator()
+ while (iterator.hasNext) {
+ sum = sum + iterator.next().getLong(0)
+ }
+ array.clear()
+ }
+ }
+
+ val conf = new SparkConf(false)
+ // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+ // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+ conf.set("spark.serializer.objectStreamReset", "1")
+ conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+
+ val sc = new SparkContext("local", "test", conf)
+ val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+ TaskContext.setTaskContext(taskContext)
+ benchmark.run()
+ sc.stop()
+ }
+
+ def main(args: Array[String]): Unit = {
+
+ // ========================================================================================= //
+ // WITHOUT SPILL
+ // ========================================================================================= //
+
+ val spillThreshold = 100 * 1000
+
+ /*
+ Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+ Array with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ ArrayBuffer 7821 / 7941 33.5 29.8 1.0X
+ ExternalAppendOnlyUnsafeRowArray 8798 / 8819 29.8 33.6 0.9X
+ */
+ testAgainstRawArrayBuffer(spillThreshold, 1000, 1 << 18)
+
+ /*
+ Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+ Array with 30000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ ArrayBuffer 19200 / 19206 25.6 39.1 1.0X
+ ExternalAppendOnlyUnsafeRowArray 19558 / 19562 25.1 39.8 1.0X
+ */
+ testAgainstRawArrayBuffer(spillThreshold, 30 * 1000, 1 << 14)
+
+ /*
+ Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+ Array with 100000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ ArrayBuffer 5949 / 6028 17.2 58.1 1.0X
+ ExternalAppendOnlyUnsafeRowArray 6078 / 6138 16.8 59.4 1.0X
+ */
+ testAgainstRawArrayBuffer(spillThreshold, 100 * 1000, 1 << 10)
+
+ // ========================================================================================= //
+ // WITH SPILL
+ // ========================================================================================= //
+
+ /*
+ Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+ Spilling with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ UnsafeExternalSorter 9239 / 9470 28.4 35.2 1.0X
+ ExternalAppendOnlyUnsafeRowArray 8857 / 8909 29.6 33.8 1.0X
+ */
+ testAgainstRawUnsafeExternalSorter(100 * 1000, 1000, 1 << 18)
+
+ /*
+ Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz
+
+ Spilling with 10000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ ------------------------------------------------------------------------------------------------
+ UnsafeExternalSorter 4 / 5 39.3 25.5 1.0X
+ ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X
+ */
+ testAgainstRawUnsafeExternalSorter(
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt, 10 * 1000, 1 << 4)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
new file mode 100644
index 0000000000..53c4163994
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.ConcurrentModificationException
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark._
+import org.apache.spark.memory.MemoryTestingUtils
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSparkContext {
+ private val random = new java.util.Random()
+ private var taskContext: TaskContext = _
+
+ override def afterAll(): Unit = TaskContext.unset()
+
+ private def withExternalArray(spillThreshold: Int)
+ (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = {
+ sc = new SparkContext("local", "test", new SparkConf(false))
+
+ taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get)
+ TaskContext.setTaskContext(taskContext)
+
+ val array = new ExternalAppendOnlyUnsafeRowArray(
+ taskContext.taskMemoryManager(),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ taskContext,
+ 1024,
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ spillThreshold)
+ try f(array) finally {
+ array.clear()
+ }
+ }
+
+ private def insertRow(array: ExternalAppendOnlyUnsafeRowArray): Long = {
+ val valueInserted = random.nextLong()
+
+ val row = new UnsafeRow(1)
+ row.pointTo(new Array[Byte](64), 16)
+ row.setLong(0, valueInserted)
+ array.add(row)
+ valueInserted
+ }
+
+ private def checkIfValueExists(iterator: Iterator[UnsafeRow], expectedValue: Long): Unit = {
+ assert(iterator.hasNext)
+ val actualRow = iterator.next()
+ assert(actualRow.getLong(0) == expectedValue)
+ assert(actualRow.getSizeInBytes == 16)
+ }
+
+ private def validateData(
+ array: ExternalAppendOnlyUnsafeRowArray,
+ expectedValues: ArrayBuffer[Long]): Iterator[UnsafeRow] = {
+ val iterator = array.generateIterator()
+ for (value <- expectedValues) {
+ checkIfValueExists(iterator, value)
+ }
+
+ assert(!iterator.hasNext)
+ iterator
+ }
+
+ private def populateRows(
+ array: ExternalAppendOnlyUnsafeRowArray,
+ numRowsToBePopulated: Int): ArrayBuffer[Long] = {
+ val populatedValues = new ArrayBuffer[Long]
+ populateRows(array, numRowsToBePopulated, populatedValues)
+ }
+
+ private def populateRows(
+ array: ExternalAppendOnlyUnsafeRowArray,
+ numRowsToBePopulated: Int,
+ populatedValues: ArrayBuffer[Long]): ArrayBuffer[Long] = {
+ for (_ <- 0 until numRowsToBePopulated) {
+ populatedValues.append(insertRow(array))
+ }
+ populatedValues
+ }
+
+ private def getNumBytesSpilled: Long = {
+ TaskContext.get().taskMetrics().memoryBytesSpilled
+ }
+
+ private def assertNoSpill(): Unit = {
+ assert(getNumBytesSpilled == 0)
+ }
+
+ private def assertSpill(): Unit = {
+ assert(getNumBytesSpilled > 0)
+ }
+
+ test("insert rows less than the spillThreshold") {
+ val spillThreshold = 100
+ withExternalArray(spillThreshold) { array =>
+ assert(array.isEmpty)
+
+ val expectedValues = populateRows(array, 1)
+ assert(!array.isEmpty)
+ assert(array.length == 1)
+
+ val iterator1 = validateData(array, expectedValues)
+
+ // Add more rows (but not enough to trigger the switch to [[UnsafeExternalSorter]])
+ // Verify that NO spill has happened
+ populateRows(array, spillThreshold - 1, expectedValues)
+ assert(array.length == spillThreshold)
+ assertNoSpill()
+
+ val iterator2 = validateData(array, expectedValues)
+
+ assert(!iterator1.hasNext)
+ assert(!iterator2.hasNext)
+ }
+ }
+
+ test("insert rows more than the spillThreshold to force spill") {
+ val spillThreshold = 100
+ withExternalArray(spillThreshold) { array =>
+ val numValuesInserted = 20 * spillThreshold
+
+ assert(array.isEmpty)
+ val expectedValues = populateRows(array, 1)
+ assert(array.length == 1)
+
+ val iterator1 = validateData(array, expectedValues)
+
+ // Populate more rows to trigger spill. Verify that spill has happened
+ populateRows(array, numValuesInserted - 1, expectedValues)
+ assert(array.length == numValuesInserted)
+ assertSpill()
+
+ val iterator2 = validateData(array, expectedValues)
+ assert(!iterator2.hasNext)
+
+ assert(!iterator1.hasNext)
+ intercept[ConcurrentModificationException](iterator1.next())
+ }
+ }
+
+ test("iterator on an empty array should be empty") {
+ withExternalArray(spillThreshold = 10) { array =>
+ val iterator = array.generateIterator()
+ assert(array.isEmpty)
+ assert(array.length == 0)
+ assert(!iterator.hasNext)
+ }
+ }
+
+ test("generate iterator with negative start index") {
+ withExternalArray(spillThreshold = 2) { array =>
+ val exception =
+ intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10))
+
+ assert(exception.getMessage.contains(
+ "Invalid `startIndex` provided for generating iterator over the array")
+ )
+ }
+ }
+
+ test("generate iterator with start index exceeding array's size (without spill)") {
+ val spillThreshold = 2
+ withExternalArray(spillThreshold) { array =>
+ populateRows(array, spillThreshold / 2)
+
+ val exception =
+ intercept[ArrayIndexOutOfBoundsException](
+ array.generateIterator(startIndex = spillThreshold * 10))
+ assert(exception.getMessage.contains(
+ "Invalid `startIndex` provided for generating iterator over the array"))
+ }
+ }
+
+ test("generate iterator with start index exceeding array's size (with spill)") {
+ val spillThreshold = 2
+ withExternalArray(spillThreshold) { array =>
+ populateRows(array, spillThreshold * 2)
+
+ val exception =
+ intercept[ArrayIndexOutOfBoundsException](
+ array.generateIterator(startIndex = spillThreshold * 10))
+
+ assert(exception.getMessage.contains(
+ "Invalid `startIndex` provided for generating iterator over the array"))
+ }
+ }
+
+ test("generate iterator with custom start index (without spill)") {
+ val spillThreshold = 10
+ withExternalArray(spillThreshold) { array =>
+ val expectedValues = populateRows(array, spillThreshold)
+ val startIndex = spillThreshold / 2
+ val iterator = array.generateIterator(startIndex = startIndex)
+ for (i <- startIndex until expectedValues.length) {
+ checkIfValueExists(iterator, expectedValues(i))
+ }
+ }
+ }
+
+ test("generate iterator with custom start index (with spill)") {
+ val spillThreshold = 10
+ withExternalArray(spillThreshold) { array =>
+ val expectedValues = populateRows(array, spillThreshold * 10)
+ val startIndex = spillThreshold * 2
+ val iterator = array.generateIterator(startIndex = startIndex)
+ for (i <- startIndex until expectedValues.length) {
+ checkIfValueExists(iterator, expectedValues(i))
+ }
+ }
+ }
+
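+ // The array invalidates outstanding iterators on every structural change (add or
+ // clear); stale iterators are expected to fail fast with
+ // ConcurrentModificationException, mirroring java.util's fail-fast iterators.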
+ test("test iterator invalidation (without spill)") {
+ withExternalArray(spillThreshold = 10) { array =>
+ // Insert 2 rows and consume only the first row from the iterator
+ populateRows(array, 2)
+
+ var iterator = array.generateIterator()
+ assert(iterator.hasNext)
+ iterator.next()
+
+ // Adding more row(s) should invalidate any old iterators
+ populateRows(array, 1)
+ assert(!iterator.hasNext)
+ intercept[ConcurrentModificationException](iterator.next())
+
+ // Clearing the array should also invalidate any old iterators
+ iterator = array.generateIterator()
+ assert(iterator.hasNext)
+ iterator.next()
+
+ array.clear()
+ assert(!iterator.hasNext)
+ intercept[ConcurrentModificationException](iterator.next())
+ }
+ }
+
+ test("test iterator invalidation (with spill)") {
+ val spillThreshold = 10
+ withExternalArray(spillThreshold) { array =>
+ // Populate enough rows so that a spill happens
+ populateRows(array, spillThreshold * 2)
+ assertSpill()
+
+ var iterator = array.generateIterator()
+ assert(iterator.hasNext)
+ iterator.next()
+
+ // Adding more row(s) should invalidate any old iterators
+ populateRows(array, 1)
+ assert(!iterator.hasNext)
+ intercept[ConcurrentModificationException](iterator.next())
+
+ // Clearing the array should also invalidate any old iterators
+ iterator = array.generateIterator()
+ assert(iterator.hasNext)
+ iterator.next()
+
+ array.clear()
+ assert(!iterator.hasNext)
+ intercept[ConcurrentModificationException](iterator.next())
+ }
+ }
+
+ test("clear on an empty the array") {
+ withExternalArray(spillThreshold = 2) { array =>
+ val iterator = array.generateIterator()
+ assert(!iterator.hasNext)
+
+ // Multiple calls to clear() should not have any side effect
+ array.clear()
+ array.clear()
+ array.clear()
+ assert(array.isEmpty)
+ assert(array.length == 0)
+
+ // Clearing an empty array should also invalidate any old iterators
+ assert(!iterator.hasNext)
+ intercept[ConcurrentModificationException](iterator.next())
+ }
+ }
+
+ test("clear array (without spill)") {
+ val spillThreshold = 10
+ withExternalArray(spillThreshold) { array =>
+ // Populate rows ... but not enough to trigger spill
+ populateRows(array, spillThreshold / 2)
+ assertNoSpill()
+
+ // Clear the array
+ array.clear()
+ assert(array.isEmpty)
+
+ // Re-populate a few rows so that there is no spill
+ // Verify the data. Verify that there was no spill
+ val expectedValues = populateRows(array, spillThreshold / 3)
+ validateData(array, expectedValues)
+ assertNoSpill()
+
+ // Populate more rows ... still not enough to trigger a spill.
+ // Verify the data. Verify that there was no spill
+ populateRows(array, spillThreshold / 3, expectedValues)
+ validateData(array, expectedValues)
+ assertNoSpill()
+ }
+ }
+
+ test("clear array (with spill)") {
+ val spillThreshold = 10
+ withExternalArray(spillThreshold) { array =>
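+ // Note: memoryBytesSpilled is cumulative for the task, so after clear() the test
+ // compares against the previously recorded value instead of expecting zero.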
+ // Populate enough rows to trigger spill
+ populateRows(array, spillThreshold * 2)
+ val bytesSpilled = getNumBytesSpilled
+ assert(bytesSpilled > 0)
+
+ // Clear the array
+ array.clear()
+ assert(array.isEmpty)
+
+ // Re-populate the array ... but NOT up to the point of spilling.
+ // Verify data. Verify that there was NO "extra" spill
+ val expectedValues = populateRows(array, spillThreshold / 2)
+ validateData(array, expectedValues)
+ assert(getNumBytesSpilled == bytesSpilled)
+
+ // Populate more rows to trigger spill
+ // Verify the data. Verify that there was "extra" spill
+ populateRows(array, spillThreshold * 2, expectedValues)
+ validateData(array, expectedValues)
+ assert(getNumBytesSpilled > bytesSpilled)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
index afd47897ed..52e4f04722 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.TestUtils.assertSpilled
case class WindowData(month: Int, area: String, product: Int)
@@ -412,4 +413,36 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext {
""".stripMargin),
Row(1, 3, null) :: Row(2, null, 4) :: Nil)
}
+
+ test("test with low buffer spill threshold") {
+ val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
+ nums.createOrReplaceTempView("nums")
+
+ val expected =
+ Row(1, 1, 1) ::
+ Row(0, 2, 3) ::
+ Row(1, 3, 6) ::
+ Row(0, 4, 10) ::
+ Row(1, 5, 15) ::
+ Row(0, 6, 21) ::
+ Row(1, 7, 28) ::
+ Row(0, 8, 36) ::
+ Row(1, 9, 45) ::
+ Row(0, 10, 55) :: Nil
+
+ val actual = sql(
+ """
+ |SELECT y, x, sum(x) OVER w1 AS running_sum
+ |FROM nums
+ |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
+ """.stripMargin)
+
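+ // The DataFrame above is lazy; execution happens inside checkAnswer, so the
+ // threshold of 1 set below applies to the query and forces the window buffer
+ // to spill, which assertSpilled then verifies.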
+ withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") {
+ assertSpilled(sparkContext, "test with low buffer spill threshold") {
+ checkAnswer(actual, expected)
+ }
+ }
+
+ spark.catalog.dropTempView("nums")
+ }
}