 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java    |  21
 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java    |  18
 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java    |   2
 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java |   7
 core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java |   8
 sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala                          | 314
 6 files changed, 276 insertions(+), 94 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 77d0b70bb8..68dc0c6d41 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -45,7 +45,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
+ @Nullable
private final PrefixComparator prefixComparator;
+ @Nullable
private final RecordComparator recordComparator;
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
@@ -431,7 +433,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
this.upstream = inMemIterator;
- this.numRecords = inMemIterator.numRecordsLeft();
+ this.numRecords = inMemIterator.getNumRecords();
+ }
+
+ public int getNumRecords() {
+ return numRecords;
}
public long spill() throws IOException {
@@ -558,14 +564,24 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
private final Queue<UnsafeSorterIterator> iterators;
private UnsafeSorterIterator current;
+ private int numRecords;
public ChainedIterator(Queue<UnsafeSorterIterator> iterators) {
assert iterators.size() > 0;
+ this.numRecords = 0;
+ for (UnsafeSorterIterator iter: iterators) {
+ this.numRecords += iter.getNumRecords();
+ }
this.iterators = iterators;
this.current = iterators.remove();
}
@Override
+ public int getNumRecords() {
+ return numRecords;
+ }
+
+ @Override
public boolean hasNext() {
while (!current.hasNext() && !iterators.isEmpty()) {
current = iterators.remove();
@@ -575,6 +591,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
@Override
public void loadNext() throws IOException {
+ while (!current.hasNext() && !iterators.isEmpty()) {
+ current = iterators.remove();
+ }
current.loadNext();
}
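
The getNumRecords() additions above make the total record count available on every
UnsafeSorterIterator, and the extra loop in loadNext() lets ChainedIterator advance past
sub-iterators that turn out to be empty. A minimal Scala sketch of the usage this enables
(illustrative only, not taken from the patch; `sorter` stands for any populated
UnsafeExternalSorter):

    val iter = sorter.getIterator     // may be a ChainedIterator over memory and spills
    val total = iter.getNumRecords    // sum of the record counts of the chained parts
    var i = 0
    while (i < total) {
      iter.loadNext()                 // skips exhausted sub-iterators before loading
      // the record is now at iter.getBaseObject / iter.getBaseOffset,
      // with length iter.getRecordLength
      i += 1
    }
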
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index b7ab45675e..f71b8d154c 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -19,6 +19,8 @@ package org.apache.spark.util.collection.unsafe.sort;
import java.util.Comparator;
+import org.apache.avro.reflect.Nullable;
+
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
@@ -66,7 +68,9 @@ public final class UnsafeInMemorySorter {
private final MemoryConsumer consumer;
private final TaskMemoryManager memoryManager;
+ @Nullable
private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
+ @Nullable
private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
/**
@@ -98,10 +102,11 @@ public final class UnsafeInMemorySorter {
LongArray array) {
this.consumer = consumer;
this.memoryManager = memoryManager;
- this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
if (recordComparator != null) {
+ this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
} else {
+ this.sorter = null;
this.sortComparator = null;
}
this.array = array;
@@ -190,12 +195,13 @@ public final class UnsafeInMemorySorter {
}
@Override
- public boolean hasNext() {
- return position / 2 < numRecords;
+ public int getNumRecords() {
+ return numRecords;
}
- public int numRecordsLeft() {
- return numRecords - position / 2;
+ @Override
+ public boolean hasNext() {
+ return position / 2 < numRecords;
}
@Override
@@ -227,7 +233,7 @@ public final class UnsafeInMemorySorter {
* {@code next()} will return the same mutable object.
*/
public SortedIterator getSortedIterator() {
- if (sortComparator != null) {
+ if (sorter != null) {
sorter.sort(array, 0, pos / 2, sortComparator);
}
return new SortedIterator(pos / 2);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
index 16ac2e8d82..1b3167fcc2 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
@@ -32,4 +32,6 @@ public abstract class UnsafeSorterIterator {
public abstract int getRecordLength();
public abstract long getKeyPrefix();
+
+ public abstract int getNumRecords();
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index 3874a9f9cb..ceb59352af 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -23,6 +23,7 @@ import java.util.PriorityQueue;
final class UnsafeSorterSpillMerger {
+ private int numRecords = 0;
private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
public UnsafeSorterSpillMerger(
@@ -59,6 +60,7 @@ final class UnsafeSorterSpillMerger {
// priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator.
spillReader.loadNext();
priorityQueue.add(spillReader);
+ numRecords += spillReader.getNumRecords();
}
}
@@ -68,6 +70,11 @@ final class UnsafeSorterSpillMerger {
private UnsafeSorterIterator spillReader;
@Override
+ public int getNumRecords() {
+ return numRecords;
+ }
+
+ @Override
public boolean hasNext() {
return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index dcb13e6581..20ee1c8eb0 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -38,6 +38,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
// Variables that change with every record read:
private int recordLength;
private long keyPrefix;
+ private int numRecords;
private int numRecordsRemaining;
private byte[] arr = new byte[1024 * 1024];
@@ -53,7 +54,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
try {
this.in = blockManager.wrapForCompression(blockId, bs);
this.din = new DataInputStream(this.in);
- numRecordsRemaining = din.readInt();
+ numRecords = numRecordsRemaining = din.readInt();
} catch (IOException e) {
Closeables.close(bs, /* swallowIOException = */ true);
throw e;
@@ -61,6 +62,11 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen
}
@Override
+ public int getNumRecords() {
+ return numRecords;
+ }
+
+ @Override
public boolean hasNext() {
return (numRecordsRemaining > 0);
}
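
In the spill reader, numRecords keeps the total read from the stream header while
numRecordsRemaining still counts down, so getNumRecords() stays constant as the iterator
is drained. A small sketch of that invariant (illustrative; `reader` stands for a freshly
created UnsafeSorterIterator):

    val total = reader.getNumRecords  // fixed when the iterator is constructed
    var consumed = 0
    while (reader.hasNext) {
      reader.loadNext()
      consumed += 1
    }
    assert(consumed == total)         // remaining count reaches zero; the total does not change
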
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 89b17c8245..be885397a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.util
+
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -26,6 +28,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
+import org.apache.spark.{SparkEnv, TaskContext}
/**
* This class calculates and outputs (windowed) aggregates over the rows in a single (sorted)
@@ -283,23 +287,26 @@ case class Window(
val grouping = UnsafeProjection.create(partitionSpec, child.output)
// Manage the stream and the grouping.
- var nextRow: InternalRow = EmptyRow
- var nextGroup: InternalRow = EmptyRow
+ var nextRow: UnsafeRow = null
+ var nextGroup: UnsafeRow = null
var nextRowAvailable: Boolean = false
private[this] def fetchNextRow() {
nextRowAvailable = stream.hasNext
if (nextRowAvailable) {
- nextRow = stream.next()
+ nextRow = stream.next().asInstanceOf[UnsafeRow]
nextGroup = grouping(nextRow)
} else {
- nextRow = EmptyRow
- nextGroup = EmptyRow
+ nextRow = null
+ nextGroup = null
}
}
fetchNextRow()
// Manage the current partition.
- val rows = ArrayBuffer.empty[InternalRow]
+ val rows = ArrayBuffer.empty[UnsafeRow]
+ val inputFields = child.output.length
+ var sorter: UnsafeExternalSorter = null
+ var rowBuffer: RowBuffer = null
val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType))
val frames = factories.map(_(windowFunctionResult))
val numFrames = frames.length
@@ -307,27 +314,63 @@ case class Window(
// Collect all the rows in the current partition.
// Before we start to fetch new input rows, make a copy of nextGroup.
val currentGroup = nextGroup.copy()
- rows.clear()
+
+ // clear last partition
+ if (sorter != null) {
+ // the last sorter of this task will be cleaned up via task completion listener
+ sorter.cleanupResources()
+ sorter = null
+ } else {
+ rows.clear()
+ }
+
while (nextRowAvailable && nextGroup == currentGroup) {
- rows += nextRow.copy()
+ if (sorter == null) {
+ rows += nextRow.copy()
+
+ if (rows.length >= 4096) {
+ // We will not sort the rows, so prefixComparator and recordComparator are null.
+ sorter = UnsafeExternalSorter.create(
+ TaskContext.get().taskMemoryManager(),
+ SparkEnv.get.blockManager,
+ TaskContext.get(),
+ null,
+ null,
+ 1024,
+ SparkEnv.get.memoryManager.pageSizeBytes)
+ rows.foreach { r =>
+ sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0)
+ }
+ rows.clear()
+ }
+ } else {
+ sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
+ nextRow.getSizeInBytes, 0)
+ }
fetchNextRow()
}
+ if (sorter != null) {
+ rowBuffer = new ExternalRowBuffer(sorter, inputFields)
+ } else {
+ rowBuffer = new ArrayRowBuffer(rows)
+ }
// Setup the frames.
var i = 0
while (i < numFrames) {
- frames(i).prepare(rows)
+ frames(i).prepare(rowBuffer.copy())
i += 1
}
// Setup iteration
rowIndex = 0
- rowsSize = rows.size
+ rowsSize = rowBuffer.size()
}
// Iteration
var rowIndex = 0
- var rowsSize = 0
+ var rowsSize = 0L
+
override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable
val join = new JoinedRow
@@ -340,13 +383,14 @@ case class Window(
if (rowIndex < rowsSize) {
// Get the results for the window frames.
var i = 0
+ val current = rowBuffer.next()
while (i < numFrames) {
- frames(i).write()
+ frames(i).write(rowIndex, current)
i += 1
}
// 'Merge' the input row with the window function result
- join(rows(rowIndex), windowFunctionResult)
+ join(current, windowFunctionResult)
rowIndex += 1
// Return the projection.
@@ -362,14 +406,18 @@ case class Window(
* Function for comparing boundary values.
*/
private[execution] abstract class BoundOrdering {
- def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int
+ def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int
}
/**
* Compare the input index to the bound of the output index.
*/
private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering {
- override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int =
+ override def compare(
+ inputRow: InternalRow,
+ inputIndex: Int,
+ outputRow: InternalRow,
+ outputIndex: Int): Int =
inputIndex - (outputIndex + offset)
}
@@ -380,8 +428,100 @@ private[execution] final case class RangeBoundOrdering(
ordering: Ordering[InternalRow],
current: Projection,
bound: Projection) extends BoundOrdering {
- override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int =
- ordering.compare(current(input(inputIndex)), bound(input(outputIndex)))
+ override def compare(
+ inputRow: InternalRow,
+ inputIndex: Int,
+ outputRow: InternalRow,
+ outputIndex: Int): Int =
+ ordering.compare(current(inputRow), bound(outputRow))
+}
+
+/**
+ * The interface of a row buffer for a partition.
+ */
+private[execution] abstract class RowBuffer {
+
+ /** Number of rows. */
+ def size(): Int
+
+ /** Return next row in the buffer, null if no more left. */
+ def next(): InternalRow
+
+ /** Skip the next `n` rows. */
+ def skip(n: Int): Unit
+
+ /** Return a new RowBuffer that has the same rows. */
+ def copy(): RowBuffer
+}
+
+/**
+ * A row buffer based on ArrayBuffer (the number of rows is limited)
+ */
+private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer {
+
+ private[this] var cursor: Int = -1
+
+ /** Number of rows. */
+ def size(): Int = buffer.length
+
+ /** Return next row in the buffer, null if no more left. */
+ def next(): InternalRow = {
+ cursor += 1
+ if (cursor < buffer.length) {
+ buffer(cursor)
+ } else {
+ null
+ }
+ }
+
+ /** Skip the next `n` rows. */
+ def skip(n: Int): Unit = {
+ cursor += n
+ }
+
+ /** Return a new RowBuffer that has the same rows. */
+ def copy(): RowBuffer = {
+ new ArrayRowBuffer(buffer)
+ }
+}
+
+/**
+ * An external buffer of rows based on UnsafeExternalSorter
+ */
+private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int)
+ extends RowBuffer {
+
+ private[this] val iter: UnsafeSorterIterator = sorter.getIterator
+
+ private[this] val currentRow = new UnsafeRow(numFields)
+
+ /** Number of rows. */
+ def size(): Int = iter.getNumRecords()
+
+ /** Return next row in the buffer, null if no more left. */
+ def next(): InternalRow = {
+ if (iter.hasNext) {
+ iter.loadNext()
+ currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
+ currentRow
+ } else {
+ null
+ }
+ }
+
+ /** Skip the next `n` rows. */
+ def skip(n: Int): Unit = {
+ var i = 0
+ while (i < n && iter.hasNext) {
+ iter.loadNext()
+ i += 1
+ }
+ }
+
+ /** Return a new RowBuffer that has the same rows. */
+ def copy(): RowBuffer = {
+ new ExternalRowBuffer(sorter, numFields)
+ }
}
/**
@@ -395,12 +535,12 @@ private[execution] abstract class WindowFunctionFrame {
*
* @param rows to calculate the frame results for.
*/
- def prepare(rows: ArrayBuffer[InternalRow]): Unit
+ def prepare(rows: RowBuffer): Unit
/**
* Write the current results to the target row.
*/
- def write(): Unit
+ def write(index: Int, current: InternalRow): Unit
}
/**
@@ -421,14 +561,11 @@ private[execution] final class OffsetWindowFunctionFrame(
offset: Int) extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: ArrayBuffer[InternalRow] = null
+ private[this] var input: RowBuffer = null
/** Index of the input row currently used for output. */
private[this] var inputIndex = 0
- /** Index of the current output row. */
- private[this] var outputIndex = 0
-
/** Row used when there is no valid input. */
private[this] val emptyRow = new GenericInternalRow(inputSchema.size)
@@ -463,22 +600,26 @@ private[execution] final class OffsetWindowFunctionFrame(
newMutableProjection(boundExpressions, Nil)().target(target)
}
- override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+ override def prepare(rows: RowBuffer): Unit = {
input = rows
+ // drain the first few rows if offset is larger than zero
+ inputIndex = 0
+ while (inputIndex < offset) {
+ input.next()
+ inputIndex += 1
+ }
inputIndex = offset
- outputIndex = 0
}
- override def write(): Unit = {
- val size = input.size
- if (inputIndex >= 0 && inputIndex < size) {
- join(input(inputIndex), input(outputIndex))
+ override def write(index: Int, current: InternalRow): Unit = {
+ if (inputIndex >= 0 && inputIndex < input.size) {
+ val r = input.next()
+ join(r, current)
} else {
- join(emptyRow, input(outputIndex))
+ join(emptyRow, current)
}
projection(join)
inputIndex += 1
- outputIndex += 1
}
}
@@ -498,7 +639,13 @@ private[execution] final class SlidingWindowFunctionFrame(
ubound: BoundOrdering) extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: ArrayBuffer[InternalRow] = null
+ private[this] var input: RowBuffer = null
+
+ /** The next row from `input`. */
+ private[this] var nextRow: InternalRow = null
+
+ /** The rows within current sliding window. */
+ private[this] val buffer = new util.ArrayDeque[InternalRow]()
/** Index of the first input row with a value greater than the upper bound of the current
* output row. */
@@ -508,33 +655,32 @@ private[execution] final class SlidingWindowFunctionFrame(
* current output row. */
private[this] var inputLowIndex = 0
- /** Index of the row we are currently writing. */
- private[this] var outputIndex = 0
-
/** Prepare the frame for calculating a new partition. Reset all variables. */
- override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+ override def prepare(rows: RowBuffer): Unit = {
input = rows
+ nextRow = rows.next()
inputHighIndex = 0
inputLowIndex = 0
- outputIndex = 0
+ buffer.clear()
}
/** Write the frame columns for the current row to the given target row. */
- override def write(): Unit = {
- var bufferUpdated = outputIndex == 0
+ override def write(index: Int, current: InternalRow): Unit = {
+ var bufferUpdated = index == 0
// Add all rows to the buffer for which the input row value is equal to or less than
// the output row upper bound.
- while (inputHighIndex < input.size &&
- ubound.compare(input, inputHighIndex, outputIndex) <= 0) {
+ while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
+ buffer.add(nextRow.copy())
+ nextRow = input.next()
inputHighIndex += 1
bufferUpdated = true
}
// Drop all rows from the buffer for which the input row value is smaller than
// the output row lower bound.
- while (inputLowIndex < inputHighIndex &&
- lbound.compare(input, inputLowIndex, outputIndex) < 0) {
+ while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) {
+ buffer.remove()
inputLowIndex += 1
bufferUpdated = true
}
@@ -542,12 +688,12 @@ private[execution] final class SlidingWindowFunctionFrame(
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
processor.initialize(input.size)
- processor.update(input, inputLowIndex, inputHighIndex)
+ val iter = buffer.iterator()
+ while (iter.hasNext) {
+ processor.update(iter.next())
+ }
processor.evaluate(target)
}
-
- // Move to the next row.
- outputIndex += 1
}
}
@@ -567,13 +713,18 @@ private[execution] final class UnboundedWindowFunctionFrame(
processor: AggregateProcessor) extends WindowFunctionFrame {
/** Prepare the frame for calculating a new partition. Process all rows eagerly. */
- override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
- processor.initialize(rows.size)
- processor.update(rows, 0, rows.size)
+ override def prepare(rows: RowBuffer): Unit = {
+ val size = rows.size()
+ processor.initialize(size)
+ var i = 0
+ while (i < size) {
+ processor.update(rows.next())
+ i += 1
+ }
}
/** Write the frame columns for the current row to the given target row. */
- override def write(): Unit = {
+ override def write(index: Int, current: InternalRow): Unit = {
// Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate
// for each row.
processor.evaluate(target)
@@ -600,31 +751,32 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame(
ubound: BoundOrdering) extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: ArrayBuffer[InternalRow] = null
+ private[this] var input: RowBuffer = null
+
+ /** The next row from `input`. */
+ private[this] var nextRow: InternalRow = null
/** Index of the first input row with a value greater than the upper bound of the current
* output row. */
private[this] var inputIndex = 0
- /** Index of the row we are currently writing. */
- private[this] var outputIndex = 0
-
/** Prepare the frame for calculating a new partition. */
- override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+ override def prepare(rows: RowBuffer): Unit = {
input = rows
+ nextRow = rows.next()
inputIndex = 0
- outputIndex = 0
processor.initialize(input.size)
}
/** Write the frame columns for the current row to the given target row. */
- override def write(): Unit = {
- var bufferUpdated = outputIndex == 0
+ override def write(index: Int, current: InternalRow): Unit = {
+ var bufferUpdated = index == 0
// Add all rows to the aggregates for which the input row value is equal to or less than
// the output row upper bound.
- while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) {
- processor.update(input(inputIndex))
+ while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
+ processor.update(nextRow)
+ nextRow = input.next()
inputIndex += 1
bufferUpdated = true
}
@@ -633,9 +785,6 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame(
if (bufferUpdated) {
processor.evaluate(target)
}
-
- // Move to the next row.
- outputIndex += 1
}
}
@@ -661,29 +810,31 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame(
lbound: BoundOrdering) extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: ArrayBuffer[InternalRow] = null
+ private[this] var input: RowBuffer = null
/** Index of the first input row with a value equal to or greater than the lower bound of the
* current output row. */
private[this] var inputIndex = 0
- /** Index of the row we are currently writing. */
- private[this] var outputIndex = 0
-
/** Prepare the frame for calculating a new partition. */
- override def prepare(rows: ArrayBuffer[InternalRow]): Unit = {
+ override def prepare(rows: RowBuffer): Unit = {
input = rows
inputIndex = 0
- outputIndex = 0
}
/** Write the frame columns for the current row to the given target row. */
- override def write(): Unit = {
- var bufferUpdated = outputIndex == 0
+ override def write(index: Int, current: InternalRow): Unit = {
+ var bufferUpdated = index == 0
+
+ // Duplicate the input to have a new iterator
+ val tmp = input.copy()
// Drop all rows from the buffer for which the input row value is smaller than
// the output row lower bound.
- while (inputIndex < input.size && lbound.compare(input, inputIndex, outputIndex) < 0) {
+ tmp.skip(inputIndex)
+ var nextRow = tmp.next()
+ while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) {
+ nextRow = tmp.next()
inputIndex += 1
bufferUpdated = true
}
@@ -691,12 +842,12 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame(
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
processor.initialize(input.size)
- processor.update(input, inputIndex, input.size)
+ while (nextRow != null) {
+ processor.update(nextRow)
+ nextRow = tmp.next()
+ }
processor.evaluate(target)
}
-
- // Move to the next row.
- outputIndex += 1
}
}
@@ -825,15 +976,6 @@ private[execution] final class AggregateProcessor(
}
}
- /** Bulk update the given buffer. */
- def update(input: ArrayBuffer[InternalRow], begin: Int, end: Int): Unit = {
- var i = begin
- while (i < end) {
- update(input(i))
- i += 1
- }
- }
-
/** Evaluate buffer. */
def evaluate(target: MutableRow): Unit =
evaluateProjection.target(target)(buffer)
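
With the changes above, a frame no longer indexes into a shared ArrayBuffer: prepare()
receives a RowBuffer (which may be backed by a spillable UnsafeExternalSorter) and
write() receives the current output row together with its index. As an illustration
(not part of the patch), a minimal frame written against the new contract could look
like the following; the class name, the `target` MutableRow and the `ordinal` are
assumptions made for the example:

    private[execution] final class RunningCountFrame(
        target: MutableRow,
        ordinal: Int) extends WindowFunctionFrame {

      private[this] var count = 0L

      /** Nothing to pre-compute for a running count; the RowBuffer is left untouched. */
      override def prepare(rows: RowBuffer): Unit = {
        count = 0L
      }

      /** Called once per output row, in row order; writes the count of rows seen so far. */
      override def write(index: Int, current: InternalRow): Unit = {
        count += 1
        target.setLong(ordinal, count)
      }
    }
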