Diffstat (limited to 'sql/core/src/main/scala/org/apache')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala  243
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala  52
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala  117
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala  115
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala  72
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala  97
6 files changed, 405 insertions, 291 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
new file mode 100644
index 0000000000..458ac4ba36
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.ConcurrentModificationException
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
+
+/**
+ * An append-only array for [[UnsafeRow]]s that spills content to disk when a predefined
+ * threshold of rows is reached.
+ *
+ * Setting the spill threshold involves the following trade-off:
+ *
+ * - If the spill threshold is too high, the in-memory array may occupy more memory than is
+ * available, resulting in OOM.
+ * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes.
+ * This may lead to a performance regression compared to the normal case of using an
+ * [[ArrayBuffer]] or [[Array]].
+ */
+private[sql] class ExternalAppendOnlyUnsafeRowArray(
+ taskMemoryManager: TaskMemoryManager,
+ blockManager: BlockManager,
+ serializerManager: SerializerManager,
+ taskContext: TaskContext,
+ initialSize: Int,
+ pageSizeBytes: Long,
+ numRowsSpillThreshold: Int) extends Logging {
+
+ def this(numRowsSpillThreshold: Int) {
+ this(
+ TaskContext.get().taskMemoryManager(),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ TaskContext.get(),
+ 1024,
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ numRowsSpillThreshold)
+ }
+
+ private val initialSizeOfInMemoryBuffer =
+ Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold)
+
+ private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) {
+ new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer)
+ } else {
+ null
+ }
+
+ private var spillableArray: UnsafeExternalSorter = _
+ private var numRows = 0
+
+ // A counter tracking the total number of modifications made to this array since its creation.
+ // It is used to invalidate iterators when the backing array is modified.
+ private var modificationsCount: Long = 0
+
+ private var numFieldsPerRow = 0
+
+ def length: Int = numRows
+
+ def isEmpty: Boolean = numRows == 0
+
+ /**
+ * Clears resources (e.g. memory) held by the backing storage
+ */
+ def clear(): Unit = {
+ if (spillableArray != null) {
+ // The last `spillableArray` of this task will be cleaned up via task completion listener
+ // inside `UnsafeExternalSorter`
+ spillableArray.cleanupResources()
+ spillableArray = null
+ } else if (inMemoryBuffer != null) {
+ inMemoryBuffer.clear()
+ }
+ numFieldsPerRow = 0
+ numRows = 0
+ modificationsCount += 1
+ }
+
+ def add(unsafeRow: UnsafeRow): Unit = {
+ if (numRows < numRowsSpillThreshold) {
+ inMemoryBuffer += unsafeRow.copy()
+ } else {
+ if (spillableArray == null) {
+ logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " +
+ s"${classOf[UnsafeExternalSorter].getName}")
+
+ // We will not sort the rows, so prefixComparator and recordComparator are null
+ spillableArray = UnsafeExternalSorter.create(
+ taskMemoryManager,
+ blockManager,
+ serializerManager,
+ taskContext,
+ null,
+ null,
+ initialSize,
+ pageSizeBytes,
+ numRowsSpillThreshold,
+ false)
+
+ // populate with existing in-memory buffered rows
+ if (inMemoryBuffer != null) {
+ inMemoryBuffer.foreach(existingUnsafeRow =>
+ spillableArray.insertRecord(
+ existingUnsafeRow.getBaseObject,
+ existingUnsafeRow.getBaseOffset,
+ existingUnsafeRow.getSizeInBytes,
+ 0,
+ false)
+ )
+ inMemoryBuffer.clear()
+ }
+ numFieldsPerRow = unsafeRow.numFields()
+ }
+
+ spillableArray.insertRecord(
+ unsafeRow.getBaseObject,
+ unsafeRow.getBaseOffset,
+ unsafeRow.getSizeInBytes,
+ 0,
+ false)
+ }
+
+ numRows += 1
+ modificationsCount += 1
+ }
+
+ /**
+ * Creates an [[Iterator]] over the current rows in the array, starting from a user-provided index.
+ *
+ * If [[add()]] or [[clear()]] is called on this array after the iterator has been created, the
+ * iterator is invalidated. This prevents clients from assuming they have read all the data
+ * while new rows were being added to this array.
+ */
+ def generateIterator(startIndex: Int): Iterator[UnsafeRow] = {
+ if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) {
+ throw new ArrayIndexOutOfBoundsException(
+ "Invalid `startIndex` provided for generating iterator over the array. " +
+ s"Total elements: $numRows, requested `startIndex`: $startIndex")
+ }
+
+ if (spillableArray == null) {
+ new InMemoryBufferIterator(startIndex)
+ } else {
+ new SpillableArrayIterator(spillableArray.getIterator, numFieldsPerRow, startIndex)
+ }
+ }
+
+ def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0)
+
+ private[this]
+ abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] {
+ private val expectedModificationsCount = modificationsCount
+
+ protected def isModified(): Boolean = expectedModificationsCount != modificationsCount
+
+ protected def throwExceptionIfModified(): Unit = {
+ if (expectedModificationsCount != modificationsCount) {
+ throw new ConcurrentModificationException(
+ s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " +
+ s"since the creation of this Iterator")
+ }
+ }
+ }
+
+ private[this] class InMemoryBufferIterator(startIndex: Int)
+ extends ExternalAppendOnlyUnsafeRowArrayIterator {
+
+ private var currentIndex = startIndex
+
+ override def hasNext(): Boolean = !isModified() && currentIndex < numRows
+
+ override def next(): UnsafeRow = {
+ throwExceptionIfModified()
+ val result = inMemoryBuffer(currentIndex)
+ currentIndex += 1
+ result
+ }
+ }
+
+ private[this] class SpillableArrayIterator(
+ iterator: UnsafeSorterIterator,
+ numFieldPerRow: Int,
+ startIndex: Int)
+ extends ExternalAppendOnlyUnsafeRowArrayIterator {
+
+ private val currentRow = new UnsafeRow(numFieldPerRow)
+
+ def init(): Unit = {
+ var i = 0
+ while (i < startIndex) {
+ if (iterator.hasNext) {
+ iterator.loadNext()
+ } else {
+ throw new ArrayIndexOutOfBoundsException(
+ "Invalid `startIndex` provided for generating iterator over the array. " +
+ s"Total elements: $numRows, requested `startIndex`: $startIndex")
+ }
+ i += 1
+ }
+ }
+
+ // Traverse up to the given [[startIndex]]
+ init()
+
+ override def hasNext(): Boolean = !isModified() && iterator.hasNext
+
+ override def next(): UnsafeRow = {
+ throwExceptionIfModified()
+ iterator.loadNext()
+ currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength)
+ currentRow
+ }
+ }
+}
+
+private[sql] object ExternalAppendOnlyUnsafeRowArray {
+ val DefaultInitialSizeOfInMemoryBuffer = 128
+}
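
For readers skimming the patch, a minimal usage sketch of the API added above (illustrative only, not part of the diff; the function name is hypothetical, and the one-argument constructor assumes it runs inside task code where TaskContext.get() and SparkEnv.get are available):

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray

    def bufferAndReplay(rows: Iterator[UnsafeRow], spillThreshold: Int): Unit = {
      val array = new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
      rows.foreach(array.add)   // buffered in memory until spillThreshold rows, then spilled to disk

      val it = array.generateIterator()   // full scan from index 0
      while (it.hasNext) {
        val row = it.next()   // for spilled data the returned UnsafeRow object is reused across calls
        // ... consume row ...
      }

      // A scan may also start at an arbitrary index (0 <= startIndex <= array.length).
      val tail = array.generateIterator(startIndex = array.length / 2)

      // Any later add() or clear() invalidates existing iterators: calling next() on a stale
      // iterator throws ConcurrentModificationException.
      array.clear()   // releases the in-memory buffer or the spilled storage
    }
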
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 8341fe2ffd..f380986951 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -19,65 +19,39 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.CompletionIterator
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
* An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
* will be much faster than building the right partition for every row in left RDD, it also
* materialize the right RDD (in case of the right RDD is nondeterministic).
*/
-class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
+class UnsafeCartesianRDD(
+ left : RDD[UnsafeRow],
+ right : RDD[UnsafeRow],
+ numFieldsOfRight: Int,
+ spillThreshold: Int)
extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
- // We will not sort the rows, so prefixComparator and recordComparator are null.
- val sorter = UnsafeExternalSorter.create(
- context.taskMemoryManager(),
- SparkEnv.get.blockManager,
- SparkEnv.get.serializerManager,
- context,
- null,
- null,
- 1024,
- SparkEnv.get.memoryManager.pageSizeBytes,
- SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
- UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
- false)
+ val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
val partition = split.asInstanceOf[CartesianPartition]
- for (y <- rdd2.iterator(partition.s2, context)) {
- sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false)
- }
+ rdd2.iterator(partition.s2, context).foreach(rowArray.add)
- // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
- def createIter(): Iterator[UnsafeRow] = {
- val iter = sorter.getIterator
- val unsafeRow = new UnsafeRow(numFieldsOfRight)
- new Iterator[UnsafeRow] {
- override def hasNext: Boolean = {
- iter.hasNext
- }
- override def next(): UnsafeRow = {
- iter.loadNext()
- unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
- unsafeRow
- }
- }
- }
+ // Create an iterator from rowArray
+ def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator()
val resultIter =
for (x <- rdd1.iterator(partition.s1, context);
y <- createIter()) yield (x, y)
CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
- resultIter, sorter.cleanupResources())
+ resultIter, rowArray.clear())
}
}
@@ -97,7 +71,9 @@ case class CartesianProductExec(
val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]]
val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
- val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
+ val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold
+
+ val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold)
pair.mapPartitionsWithIndexInternal { (index, iter) =>
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
val filtered = if (condition.isDefined) {
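
Note the cleanup idiom retained here: CompletionIterator ties rowArray.clear() to the exhaustion of the cartesian product's output, just as it previously tied sorter.cleanupResources(). A minimal sketch of that idiom (illustrative; the helper name is made up, not part of the patch):

    import org.apache.spark.util.CompletionIterator

    // Wraps an iterator so that `cleanup` runs exactly once, after the last element has been
    // consumed; the second argument to CompletionIterator.apply is by-name.
    def withCleanup[T](iter: Iterator[T])(cleanup: => Unit): Iterator[T] =
      CompletionIterator[T, Iterator[T]](iter, cleanup)

Used as withCleanup(resultIter)(rowArray.clear()), it behaves like the wrapping above.
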
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ca9c0ed8ce..bcdc4dcdf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport,
+ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.collection.BitSet
@@ -95,9 +96,13 @@ case class SortMergeJoinExec(
private def createRightKeyGenerator(): Projection =
UnsafeProjection.create(rightKeys, right.output)
+ private def getSpillThreshold: Int = {
+ sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
-
+ val spillThreshold = getSpillThreshold
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
val boundCondition: (InternalRow) => Boolean = {
condition.map { cond =>
@@ -115,39 +120,39 @@ case class SortMergeJoinExec(
case _: InnerLike =>
new RowIterator {
private[this] var currentLeftRow: InternalRow = _
- private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
- private[this] var currentMatchIdx: Int = -1
+ private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
+ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
private[this] val smjScanner = new SortMergeJoinScanner(
createLeftKeyGenerator(),
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
- currentMatchIdx = 0
+ rightMatchesIterator = currentRightMatches.generateIterator()
}
override def advanceNext(): Boolean = {
- while (currentMatchIdx >= 0) {
- if (currentMatchIdx == currentRightMatches.length) {
+ while (rightMatchesIterator != null) {
+ if (!rightMatchesIterator.hasNext) {
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
- currentMatchIdx = 0
+ rightMatchesIterator = currentRightMatches.generateIterator()
} else {
currentRightMatches = null
currentLeftRow = null
- currentMatchIdx = -1
+ rightMatchesIterator = null
return false
}
}
- joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
- currentMatchIdx += 1
+ joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
numOutputRows += 1
return true
@@ -165,7 +170,8 @@ case class SortMergeJoinExec(
bufferedKeyGenerator = createRightKeyGenerator(),
keyOrdering,
streamedIter = RowIterator.fromScala(leftIter),
- bufferedIter = RowIterator.fromScala(rightIter)
+ bufferedIter = RowIterator.fromScala(rightIter),
+ spillThreshold
)
val rightNullRow = new GenericInternalRow(right.output.length)
new LeftOuterIterator(
@@ -177,7 +183,8 @@ case class SortMergeJoinExec(
bufferedKeyGenerator = createLeftKeyGenerator(),
keyOrdering,
streamedIter = RowIterator.fromScala(rightIter),
- bufferedIter = RowIterator.fromScala(leftIter)
+ bufferedIter = RowIterator.fromScala(leftIter),
+ spillThreshold
)
val leftNullRow = new GenericInternalRow(left.output.length)
new RightOuterIterator(
@@ -209,7 +216,8 @@ case class SortMergeJoinExec(
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -217,14 +225,15 @@ case class SortMergeJoinExec(
while (smjScanner.findNextInnerJoinRows()) {
val currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
- var i = 0
- while (i < currentRightMatches.length) {
- joinRow(currentLeftRow, currentRightMatches(i))
- if (boundCondition(joinRow)) {
- numOutputRows += 1
- return true
+ if (currentRightMatches != null && currentRightMatches.length > 0) {
+ val rightMatchesIterator = currentRightMatches.generateIterator()
+ while (rightMatchesIterator.hasNext) {
+ joinRow(currentLeftRow, rightMatchesIterator.next())
+ if (boundCondition(joinRow)) {
+ numOutputRows += 1
+ return true
+ }
}
- i += 1
}
}
false
@@ -241,7 +250,8 @@ case class SortMergeJoinExec(
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -249,17 +259,16 @@ case class SortMergeJoinExec(
while (smjScanner.findNextOuterJoinRows()) {
currentLeftRow = smjScanner.getStreamedRow
val currentRightMatches = smjScanner.getBufferedMatches
- if (currentRightMatches == null) {
+ if (currentRightMatches == null || currentRightMatches.length == 0) {
return true
}
- var i = 0
var found = false
- while (!found && i < currentRightMatches.length) {
- joinRow(currentLeftRow, currentRightMatches(i))
+ val rightMatchesIterator = currentRightMatches.generateIterator()
+ while (!found && rightMatchesIterator.hasNext) {
+ joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
found = true
}
- i += 1
}
if (!found) {
numOutputRows += 1
@@ -281,7 +290,8 @@ case class SortMergeJoinExec(
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -290,14 +300,13 @@ case class SortMergeJoinExec(
currentLeftRow = smjScanner.getStreamedRow
val currentRightMatches = smjScanner.getBufferedMatches
var found = false
- if (currentRightMatches != null) {
- var i = 0
- while (!found && i < currentRightMatches.length) {
- joinRow(currentLeftRow, currentRightMatches(i))
+ if (currentRightMatches != null && currentRightMatches.length > 0) {
+ val rightMatchesIterator = currentRightMatches.generateIterator()
+ while (!found && rightMatchesIterator.hasNext) {
+ joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
found = true
}
- i += 1
}
}
result.setBoolean(0, found)
@@ -376,8 +385,11 @@ case class SortMergeJoinExec(
// A list to hold all matched rows from right side.
val matches = ctx.freshName("matches")
- val clsName = classOf[java.util.ArrayList[InternalRow]].getName
- ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
+ val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
+
+ val spillThreshold = getSpillThreshold
+
+ ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);")
// Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
@@ -428,7 +440,7 @@ case class SortMergeJoinExec(
| }
| $leftRow = null;
| } else {
- | $matches.add($rightRow.copy());
+ | $matches.add((UnsafeRow) $rightRow);
| $rightRow = null;;
| }
| } while ($leftRow != null);
@@ -517,8 +529,7 @@ case class SortMergeJoinExec(
val rightRow = ctx.freshName("rightRow")
val rightVars = createRightVar(ctx, rightRow)
- val size = ctx.freshName("size")
- val i = ctx.freshName("i")
+ val iterator = ctx.freshName("iterator")
val numOutput = metricTerm(ctx, "numOutputRows")
val (beforeLoop, condCheck) = if (condition.isDefined) {
// Split the code of creating variables based on whether it's used by condition or not.
@@ -551,10 +562,10 @@ case class SortMergeJoinExec(
s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
- | int $size = $matches.size();
| ${beforeLoop.trim}
- | for (int $i = 0; $i < $size; $i ++) {
- | InternalRow $rightRow = (InternalRow) $matches.get($i);
+ | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
+ | while ($iterator.hasNext()) {
+ | InternalRow $rightRow = (InternalRow) $iterator.next();
| ${condCheck.trim}
| $numOutput.add(1);
| ${consume(ctx, leftVars ++ rightVars)}
@@ -589,7 +600,8 @@ private[joins] class SortMergeJoinScanner(
bufferedKeyGenerator: Projection,
keyOrdering: Ordering[InternalRow],
streamedIter: RowIterator,
- bufferedIter: RowIterator) {
+ bufferedIter: RowIterator,
+ bufferThreshold: Int) {
private[this] var streamedRow: InternalRow = _
private[this] var streamedRowKey: InternalRow = _
private[this] var bufferedRow: InternalRow = _
@@ -600,7 +612,7 @@ private[joins] class SortMergeJoinScanner(
*/
private[this] var matchJoinKey: InternalRow = _
/** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
- private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+ private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold)
// Initialization (note: do _not_ want to advance streamed here).
advancedBufferedToRowWithNullFreeJoinKey()
@@ -609,7 +621,7 @@ private[joins] class SortMergeJoinScanner(
def getStreamedRow: InternalRow = streamedRow
- def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+ def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches
/**
* Advances both input iterators, stopping when we have found rows with matching join keys.
@@ -755,7 +767,7 @@ private[joins] class SortMergeJoinScanner(
matchJoinKey = streamedRowKey.copy()
bufferedMatches.clear()
do {
- bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
+ bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow])
advancedBufferedToRowWithNullFreeJoinKey()
} while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
}
@@ -819,7 +831,7 @@ private abstract class OneSideOuterIterator(
protected[this] val joinedRow: JoinedRow = new JoinedRow()
// Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
- private[this] var bufferIndex: Int = 0
+ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
// This iterator is initialized lazily so there should be no matches initially
assert(smjScanner.getBufferedMatches.length == 0)
@@ -833,7 +845,7 @@ private abstract class OneSideOuterIterator(
* @return whether there are more rows in the stream to consume.
*/
private def advanceStream(): Boolean = {
- bufferIndex = 0
+ rightMatchesIterator = null
if (smjScanner.findNextOuterJoinRows()) {
setStreamSideOutput(smjScanner.getStreamedRow)
if (smjScanner.getBufferedMatches.isEmpty) {
@@ -858,10 +870,13 @@ private abstract class OneSideOuterIterator(
*/
private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
var foundMatch: Boolean = false
- while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
- setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
+ if (rightMatchesIterator == null) {
+ rightMatchesIterator = smjScanner.getBufferedMatches.generateIterator()
+ }
+
+ while (!foundMatch && rightMatchesIterator.hasNext) {
+ setBufferedSideOutput(rightMatchesIterator.next())
foundMatch = boundCondition(joinedRow)
- bufferIndex += 1
}
foundMatch
}
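
Across these hunks the consumer-side pattern is the same: getBufferedMatches now always returns the same ExternalAppendOnlyUnsafeRowArray instance (never null), so callers test isEmpty/length and pull a fresh iterator per streamed row instead of indexing into an ArrayBuffer. A self-contained sketch of that pattern (the function and its parameters are hypothetical, mirroring the hand-written iterators above):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.JoinedRow
    import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray

    // For one streamed row, scan the buffered matches with a fresh iterator and count the
    // pairs that satisfy the join condition; the buffered array itself is reused (cleared and
    // refilled) by the scanner across streamed rows.
    def countMatches(
        streamedRow: InternalRow,
        bufferedMatches: ExternalAppendOnlyUnsafeRowArray,
        boundCondition: InternalRow => Boolean): Long = {
      val joinRow = new JoinedRow
      var count = 0L
      if (!bufferedMatches.isEmpty) {
        val it = bufferedMatches.generateIterator()
        while (it.hasNext) {
          joinRow(streamedRow, it.next())
          if (boundCondition(joinRow)) count += 1
        }
      }
      count
    }
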
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
deleted file mode 100644
index ee36c84251..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.window
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
-
-
-/**
- * The interface of row buffer for a partition. In absence of a buffer pool (with locking), the
- * row buffer is used to materialize a partition of rows since we need to repeatedly scan these
- * rows in window function processing.
- */
-private[window] abstract class RowBuffer {
-
- /** Number of rows. */
- def size: Int
-
- /** Return next row in the buffer, null if no more left. */
- def next(): InternalRow
-
- /** Skip the next `n` rows. */
- def skip(n: Int): Unit
-
- /** Return a new RowBuffer that has the same rows. */
- def copy(): RowBuffer
-}
-
-/**
- * A row buffer based on ArrayBuffer (the number of rows is limited).
- */
-private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer {
-
- private[this] var cursor: Int = -1
-
- /** Number of rows. */
- override def size: Int = buffer.length
-
- /** Return next row in the buffer, null if no more left. */
- override def next(): InternalRow = {
- cursor += 1
- if (cursor < buffer.length) {
- buffer(cursor)
- } else {
- null
- }
- }
-
- /** Skip the next `n` rows. */
- override def skip(n: Int): Unit = {
- cursor += n
- }
-
- /** Return a new RowBuffer that has the same rows. */
- override def copy(): RowBuffer = {
- new ArrayRowBuffer(buffer)
- }
-}
-
-/**
- * An external buffer of rows based on UnsafeExternalSorter.
- */
-private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int)
- extends RowBuffer {
-
- private[this] val iter: UnsafeSorterIterator = sorter.getIterator
-
- private[this] val currentRow = new UnsafeRow(numFields)
-
- /** Number of rows. */
- override def size: Int = iter.getNumRecords()
-
- /** Return next row in the buffer, null if no more left. */
- override def next(): InternalRow = {
- if (iter.hasNext) {
- iter.loadNext()
- currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
- currentRow
- } else {
- null
- }
- }
-
- /** Skip the next `n` rows. */
- override def skip(n: Int): Unit = {
- var i = 0
- while (i < n && iter.hasNext) {
- iter.loadNext()
- i += 1
- }
- }
-
- /** Return a new RowBuffer that has the same rows. */
- override def copy(): RowBuffer = {
- new ExternalRowBuffer(sorter, numFields)
- }
-}
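
For reviewers tracking this removal, the old RowBuffer operations map onto the new array roughly as sketched below (an equivalence sketch, assuming rows are appended in the same order; the null-on-drain behaviour of RowBuffer.next() is what the WindowFunctionFrame.getNextOrNull helper added further down provides):

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray

    // RowBuffer.size                        -> array.length
    // RowBuffer.copy() + skip(n)            -> array.generateIterator(startIndex = n)
    // RowBuffer.next() (null when drained)  -> emulate explicitly on the iterator:
    def nextOrNull(it: Iterator[UnsafeRow]): UnsafeRow =
      if (it.hasNext) it.next() else null
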
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index 80b87d5ffa..950a6794a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -20,15 +20,13 @@ package org.apache.spark.sql.execution.window
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
* This class calculates and outputs (windowed) aggregates over the rows in a single (sorted)
@@ -284,6 +282,7 @@ case class WindowExec(
// Unwrap the expressions and factories from the map.
val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+ val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold
// Start processing.
child.execute().mapPartitions { stream =>
@@ -310,10 +309,12 @@ case class WindowExec(
fetchNextRow()
// Manage the current partition.
- val rows = ArrayBuffer.empty[UnsafeRow]
val inputFields = child.output.length
- var sorter: UnsafeExternalSorter = null
- var rowBuffer: RowBuffer = null
+
+ val buffer: ExternalAppendOnlyUnsafeRowArray =
+ new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+ var bufferIterator: Iterator[UnsafeRow] = _
+
val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType))
val frames = factories.map(_(windowFunctionResult))
val numFrames = frames.length
@@ -323,78 +324,43 @@ case class WindowExec(
val currentGroup = nextGroup.copy()
// clear last partition
- if (sorter != null) {
- // the last sorter of this task will be cleaned up via task completion listener
- sorter.cleanupResources()
- sorter = null
- } else {
- rows.clear()
- }
+ buffer.clear()
while (nextRowAvailable && nextGroup == currentGroup) {
- if (sorter == null) {
- rows += nextRow.copy()
-
- if (rows.length >= 4096) {
- // We will not sort the rows, so prefixComparator and recordComparator are null.
- sorter = UnsafeExternalSorter.create(
- TaskContext.get().taskMemoryManager(),
- SparkEnv.get.blockManager,
- SparkEnv.get.serializerManager,
- TaskContext.get(),
- null,
- null,
- 1024,
- SparkEnv.get.memoryManager.pageSizeBytes,
- SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
- UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
- false)
- rows.foreach { r =>
- sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false)
- }
- rows.clear()
- }
- } else {
- sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
- nextRow.getSizeInBytes, 0, false)
- }
+ buffer.add(nextRow)
fetchNextRow()
}
- if (sorter != null) {
- rowBuffer = new ExternalRowBuffer(sorter, inputFields)
- } else {
- rowBuffer = new ArrayRowBuffer(rows)
- }
// Setup the frames.
var i = 0
while (i < numFrames) {
- frames(i).prepare(rowBuffer.copy())
+ frames(i).prepare(buffer)
i += 1
}
// Setup iteration
rowIndex = 0
- rowsSize = rowBuffer.size
+ bufferIterator = buffer.generateIterator()
}
// Iteration
var rowIndex = 0
- var rowsSize = 0L
- override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable
+ override final def hasNext: Boolean =
+ (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable
val join = new JoinedRow
override final def next(): InternalRow = {
// Load the next partition if we need to.
- if (rowIndex >= rowsSize && nextRowAvailable) {
+ if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
fetchNextPartition()
}
- if (rowIndex < rowsSize) {
+ if (bufferIterator.hasNext) {
+ val current = bufferIterator.next()
+
// Get the results for the window frames.
var i = 0
- val current = rowBuffer.next()
while (i < numFrames) {
frames(i).write(rowIndex, current)
i += 1
@@ -406,7 +372,9 @@ case class WindowExec(
// Return the projection.
result(join)
- } else throw new NoSuchElementException
+ } else {
+ throw new NoSuchElementException
+ }
}
}
}
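
One behavioural detail of this hunk: frames(i).prepare(buffer) now hands every frame the same shared array instead of an independent rowBuffer.copy(); each frame obtains its own cursor by calling generateIterator() itself (see the WindowFunctionFrame changes below). A tiny sketch of that property (illustrative, assuming the buffer is not mutated while the cursors are in use):

    import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray

    // Each call to generateIterator() returns an independent cursor over the same rows; the
    // cursors only become invalid if the array is mutated (add/clear) afterwards.
    def independentCursors(buffer: ExternalAppendOnlyUnsafeRowArray): Unit = {
      val frameA = buffer.generateIterator()
      val frameB = buffer.generateIterator(startIndex = 0)
      assert(frameA.hasNext == frameB.hasNext)
    }
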
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 70efc0f78d..af2b4fb920 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -22,6 +22,7 @@ import java.util
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
/**
@@ -35,7 +36,7 @@ private[window] abstract class WindowFunctionFrame {
*
* @param rows to calculate the frame results for.
*/
- def prepare(rows: RowBuffer): Unit
+ def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit
/**
* Write the current results to the target row.
@@ -43,6 +44,12 @@ private[window] abstract class WindowFunctionFrame {
def write(index: Int, current: InternalRow): Unit
}
+object WindowFunctionFrame {
+ def getNextOrNull(iterator: Iterator[UnsafeRow]): UnsafeRow = {
+ if (iterator.hasNext) iterator.next() else null
+ }
+}
+
/**
* The offset window frame calculates frames containing LEAD/LAG statements.
*
@@ -65,7 +72,12 @@ private[window] final class OffsetWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+ /**
+ * An iterator over the [[input]]
+ */
+ private[this] var inputIterator: Iterator[UnsafeRow] = _
/** Index of the input row currently used for output. */
private[this] var inputIndex = 0
@@ -103,20 +115,21 @@ private[window] final class OffsetWindowFunctionFrame(
newMutableProjection(boundExpressions, Nil).target(target)
}
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
+ inputIterator = input.generateIterator()
// drain the first few rows if offset is larger than zero
inputIndex = 0
while (inputIndex < offset) {
- input.next()
+ if (inputIterator.hasNext) inputIterator.next()
inputIndex += 1
}
inputIndex = offset
}
override def write(index: Int, current: InternalRow): Unit = {
- if (inputIndex >= 0 && inputIndex < input.size) {
- val r = input.next()
+ if (inputIndex >= 0 && inputIndex < input.length) {
+ val r = WindowFunctionFrame.getNextOrNull(inputIterator)
projection(r)
} else {
// Use default values since the offset row does not exist.
@@ -143,7 +156,12 @@ private[window] final class SlidingWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+ /**
+ * An iterator over the [[input]]
+ */
+ private[this] var inputIterator: Iterator[UnsafeRow] = _
/** The next row from `input`. */
private[this] var nextRow: InternalRow = null
@@ -164,9 +182,10 @@ private[window] final class SlidingWindowFunctionFrame(
private[this] var inputLowIndex = 0
/** Prepare the frame for calculating a new partition. Reset all variables. */
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
- nextRow = rows.next()
+ inputIterator = input.generateIterator()
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputHighIndex = 0
inputLowIndex = 0
buffer.clear()
@@ -180,7 +199,7 @@ private[window] final class SlidingWindowFunctionFrame(
// the output row upper bound.
while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
buffer.add(nextRow.copy())
- nextRow = input.next()
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputHighIndex += 1
bufferUpdated = true
}
@@ -195,7 +214,7 @@ private[window] final class SlidingWindowFunctionFrame(
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
- processor.initialize(input.size)
+ processor.initialize(input.length)
val iter = buffer.iterator()
while (iter.hasNext) {
processor.update(iter.next())
@@ -222,13 +241,12 @@ private[window] final class UnboundedWindowFunctionFrame(
extends WindowFunctionFrame {
/** Prepare the frame for calculating a new partition. Process all rows eagerly. */
- override def prepare(rows: RowBuffer): Unit = {
- val size = rows.size
- processor.initialize(size)
- var i = 0
- while (i < size) {
- processor.update(rows.next())
- i += 1
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
+ processor.initialize(rows.length)
+
+ val iterator = rows.generateIterator()
+ while (iterator.hasNext) {
+ processor.update(iterator.next())
}
}
@@ -261,7 +279,12 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+ /**
+ * An iterator over the [[input]]
+ */
+ private[this] var inputIterator: Iterator[UnsafeRow] = _
/** The next row from `input`. */
private[this] var nextRow: InternalRow = null
@@ -273,11 +296,15 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
private[this] var inputIndex = 0
/** Prepare the frame for calculating a new partition. */
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
- nextRow = rows.next()
inputIndex = 0
- processor.initialize(input.size)
+ inputIterator = input.generateIterator()
+ if (inputIterator.hasNext) {
+ nextRow = inputIterator.next()
+ }
+
+ processor.initialize(input.length)
}
/** Write the frame columns for the current row to the given target row. */
@@ -288,7 +315,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
// the output row upper bound.
while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
processor.update(nextRow)
- nextRow = input.next()
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputIndex += 1
bufferUpdated = true
}
@@ -323,7 +350,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
/**
* Index of the first input row with a value equal to or greater than the lower bound of the
@@ -332,7 +359,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
private[this] var inputIndex = 0
/** Prepare the frame for calculating a new partition. */
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
inputIndex = 0
}
@@ -341,25 +368,25 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
override def write(index: Int, current: InternalRow): Unit = {
var bufferUpdated = index == 0
- // Duplicate the input to have a new iterator
- val tmp = input.copy()
-
- // Drop all rows from the buffer for which the input row value is smaller than
+ // Ignore all the rows from the buffer for which the input row value is smaller than
// the output row lower bound.
- tmp.skip(inputIndex)
- var nextRow = tmp.next()
+ val iterator = input.generateIterator(startIndex = inputIndex)
+
+ var nextRow = WindowFunctionFrame.getNextOrNull(iterator)
while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) {
- nextRow = tmp.next()
inputIndex += 1
bufferUpdated = true
+ nextRow = WindowFunctionFrame.getNextOrNull(iterator)
}
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
- processor.initialize(input.size)
- while (nextRow != null) {
+ processor.initialize(input.length)
+ if (nextRow != null) {
processor.update(nextRow)
- nextRow = tmp.next()
+ }
+ while (iterator.hasNext) {
+ processor.update(iterator.next())
}
processor.evaluate(target)
}