aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/main/scala/org/apache
diff options
context:
space:
mode:
authorTejas Patil <tejasp@fb.com>2017-03-15 20:18:39 +0100
committerHerman van Hovell <hvanhovell@databricks.com>2017-03-15 20:18:39 +0100
commit02c274eaba0a8e7611226e0d4e93d3c36253f4ce (patch)
tree52852e05f5a0b0729a6c92c1d360a6379a52a380 /sql/core/src/main/scala/org/apache
parent7387126f83dc0489eb1df734bfeba705709b7861 (diff)
downloadspark-02c274eaba0a8e7611226e0d4e93d3c36253f4ce.tar.gz
spark-02c274eaba0a8e7611226e0d4e93d3c36253f4ce.tar.bz2
spark-02c274eaba0a8e7611226e0d4e93d3c36253f4ce.zip
[SPARK-13450] Introduce ExternalAppendOnlyUnsafeRowArray. Change CartesianProductExec, SortMergeJoin, WindowExec to use it
## What issue does this PR address ? Jira: https://issues.apache.org/jira/browse/SPARK-13450 In `SortMergeJoinExec`, rows of the right relation having the same value for a join key are buffered in-memory. In case of skew, this causes OOMs (see comments in SPARK-13450 for more details). Heap dump from a failed job confirms this : https://issues.apache.org/jira/secure/attachment/12846382/heap-dump-analysis.png . While its possible to increase the heap size to workaround, Spark should be resilient to such issues as skews can happen arbitrarily. ## Change proposed in this pull request - Introduces `ExternalAppendOnlyUnsafeRowArray` - It holds `UnsafeRow`s in-memory upto a certain threshold. - After the threshold is hit, it switches to `UnsafeExternalSorter` which enables spilling of the rows to disk. It does NOT sort the data. - Allows iterating the array multiple times. However, any alteration to the array (using `add` or `clear`) will invalidate the existing iterator(s) - `WindowExec` was already using `UnsafeExternalSorter` to support spilling. Changed it to use the new array - Changed `SortMergeJoinExec` to use the new array implementation - NOTE: I have not changed FULL OUTER JOIN to use this new array implementation. Changing that will need more surgery and I will rather put up a separate PR for that once this gets in. - Changed `CartesianProductExec` to use the new array implementation #### Note for reviewers The diff can be divided into 3 parts. My motive behind having all the changes in a single PR was to demonstrate that the API is sane and supports 2 use cases. If reviewing as 3 separate PRs would help, I am happy to make the split. ## How was this patch tested ? #### Unit testing - Added unit tests `ExternalAppendOnlyUnsafeRowArray` to validate all its APIs and access patterns - Added unit test for `SortMergeExec` - with and without spill for inner join, left outer join, right outer join to confirm that the spill threshold config behaves as expected and output is as expected. - This PR touches the scanning logic in `SortMergeExec` for _all_ joins (except FULL OUTER JOIN). However, I expect existing test cases to cover that there is no regression in correctness. - Added unit test for `WindowExec` to check behavior of spilling and correctness of results. #### Stress testing - Confirmed that OOM is gone by running against a production job which used to OOM - Since I cannot share details about prod workload externally, created synthetic data to mimic the issue. Ran before and after the fix to demonstrate the issue and query success with this PR Generating the synthetic data ``` ./bin/spark-shell --driver-memory=6G import org.apache.spark.sql._ val hc = SparkSession.builder.master("local").getOrCreate() hc.sql("DROP TABLE IF EXISTS spark_13450_large_table").collect hc.sql("DROP TABLE IF EXISTS spark_13450_one_row_table").collect val df1 = (0 until 1).map(i => ("10", "100", i.toString, (i * 2).toString)).toDF("i", "j", "str1", "str2") df1.write.format("org.apache.spark.sql.hive.orc.OrcFileFormat").bucketBy(100, "i", "j").sortBy("i", "j").saveAsTable("spark_13450_one_row_table") val df2 = (0 until 3000000).map(i => ("10", "100", i.toString, (i * 2).toString)).toDF("i", "j", "str1", "str2") df2.write.format("org.apache.spark.sql.hive.orc.OrcFileFormat").bucketBy(100, "i", "j").sortBy("i", "j").saveAsTable("spark_13450_large_table") ``` Ran this against trunk VS local build with this PR. OOM repros with trunk and with the fix this query runs fine. ``` ./bin/spark-shell --driver-java-options="-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/tmp/spark.driver.heapdump.hprof" import org.apache.spark.sql._ val hc = SparkSession.builder.master("local").getOrCreate() hc.sql("SET spark.sql.autoBroadcastJoinThreshold=1") hc.sql("SET spark.sql.sortMergeJoinExec.buffer.spill.threshold=10000") hc.sql("DROP TABLE IF EXISTS spark_13450_result").collect hc.sql(""" CREATE TABLE spark_13450_result AS SELECT a.i AS a_i, a.j AS a_j, a.str1 AS a_str1, a.str2 AS a_str2, b.i AS b_i, b.j AS b_j, b.str1 AS b_str1, b.str2 AS b_str2 FROM spark_13450_one_row_table a JOIN spark_13450_large_table b ON a.i=b.i AND a.j=b.j """) ``` ## Performance comparison ### Macro-benchmark I ran a SMB join query over two real world tables (2 trillion rows (40 TB) and 6 million rows (120 GB)). Note that this dataset does not have skew so no spill happened. I saw improvement in CPU time by 2-4% over version without this PR. This did not add up as I was expected some regression. I think allocating array of capacity of 128 at the start (instead of starting with default size 16) is the sole reason for the perf. gain : https://github.com/tejasapatil/spark/blob/SPARK-13450_smb_buffer_oom/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala#L43 . I could remove that and rerun, but effectively the change will be deployed in this form and I wanted to see the effect of it over large workload. ### Micro-benchmark Two types of benchmarking can be found in `ExternalAppendOnlyUnsafeRowArrayBenchmark`: [A] Comparing `ExternalAppendOnlyUnsafeRowArray` against raw `ArrayBuffer` when all rows fit in-memory and there is no spill ``` Array with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 7821 / 7941 33.5 29.8 1.0X ExternalAppendOnlyUnsafeRowArray 8798 / 8819 29.8 33.6 0.9X Array with 30000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 19200 / 19206 25.6 39.1 1.0X ExternalAppendOnlyUnsafeRowArray 19558 / 19562 25.1 39.8 1.0X Array with 100000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ ArrayBuffer 5949 / 6028 17.2 58.1 1.0X ExternalAppendOnlyUnsafeRowArray 6078 / 6138 16.8 59.4 1.0X ``` [B] Comparing `ExternalAppendOnlyUnsafeRowArray` against raw `UnsafeExternalSorter` when there is spilling of data ``` Spilling with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ UnsafeExternalSorter 9239 / 9470 28.4 35.2 1.0X ExternalAppendOnlyUnsafeRowArray 8857 / 8909 29.6 33.8 1.0X Spilling with 10000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ UnsafeExternalSorter 4 / 5 39.3 25.5 1.0X ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X ``` Author: Tejas Patil <tejasp@fb.com> Closes #16909 from tejasapatil/SPARK-13450_smb_buffer_oom.
Diffstat (limited to 'sql/core/src/main/scala/org/apache')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala243
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala52
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala117
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala115
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala72
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala97
6 files changed, 405 insertions, 291 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
new file mode 100644
index 0000000000..458ac4ba36
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.ConcurrentModificationException
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
+
+/**
+ * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined
+ * threshold of rows is reached.
+ *
+ * Setting spill threshold faces following trade-off:
+ *
+ * - If the spill threshold is too high, the in-memory array may occupy more memory than is
+ * available, resulting in OOM.
+ * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes.
+ * This may lead to a performance regression compared to the normal case of using an
+ * [[ArrayBuffer]] or [[Array]].
+ */
+private[sql] class ExternalAppendOnlyUnsafeRowArray(
+ taskMemoryManager: TaskMemoryManager,
+ blockManager: BlockManager,
+ serializerManager: SerializerManager,
+ taskContext: TaskContext,
+ initialSize: Int,
+ pageSizeBytes: Long,
+ numRowsSpillThreshold: Int) extends Logging {
+
+ def this(numRowsSpillThreshold: Int) {
+ this(
+ TaskContext.get().taskMemoryManager(),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ TaskContext.get(),
+ 1024,
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ numRowsSpillThreshold)
+ }
+
+ private val initialSizeOfInMemoryBuffer =
+ Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold)
+
+ private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) {
+ new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer)
+ } else {
+ null
+ }
+
+ private var spillableArray: UnsafeExternalSorter = _
+ private var numRows = 0
+
+ // A counter to keep track of total modifications done to this array since its creation.
+ // This helps to invalidate iterators when there are changes done to the backing array.
+ private var modificationsCount: Long = 0
+
+ private var numFieldsPerRow = 0
+
+ def length: Int = numRows
+
+ def isEmpty: Boolean = numRows == 0
+
+ /**
+ * Clears up resources (eg. memory) held by the backing storage
+ */
+ def clear(): Unit = {
+ if (spillableArray != null) {
+ // The last `spillableArray` of this task will be cleaned up via task completion listener
+ // inside `UnsafeExternalSorter`
+ spillableArray.cleanupResources()
+ spillableArray = null
+ } else if (inMemoryBuffer != null) {
+ inMemoryBuffer.clear()
+ }
+ numFieldsPerRow = 0
+ numRows = 0
+ modificationsCount += 1
+ }
+
+ def add(unsafeRow: UnsafeRow): Unit = {
+ if (numRows < numRowsSpillThreshold) {
+ inMemoryBuffer += unsafeRow.copy()
+ } else {
+ if (spillableArray == null) {
+ logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " +
+ s"${classOf[UnsafeExternalSorter].getName}")
+
+ // We will not sort the rows, so prefixComparator and recordComparator are null
+ spillableArray = UnsafeExternalSorter.create(
+ taskMemoryManager,
+ blockManager,
+ serializerManager,
+ taskContext,
+ null,
+ null,
+ initialSize,
+ pageSizeBytes,
+ numRowsSpillThreshold,
+ false)
+
+ // populate with existing in-memory buffered rows
+ if (inMemoryBuffer != null) {
+ inMemoryBuffer.foreach(existingUnsafeRow =>
+ spillableArray.insertRecord(
+ existingUnsafeRow.getBaseObject,
+ existingUnsafeRow.getBaseOffset,
+ existingUnsafeRow.getSizeInBytes,
+ 0,
+ false)
+ )
+ inMemoryBuffer.clear()
+ }
+ numFieldsPerRow = unsafeRow.numFields()
+ }
+
+ spillableArray.insertRecord(
+ unsafeRow.getBaseObject,
+ unsafeRow.getBaseOffset,
+ unsafeRow.getSizeInBytes,
+ 0,
+ false)
+ }
+
+ numRows += 1
+ modificationsCount += 1
+ }
+
+ /**
+ * Creates an [[Iterator]] for the current rows in the array starting from a user provided index
+ *
+ * If there are subsequent [[add()]] or [[clear()]] calls made on this array after creation of
+ * the iterator, then the iterator is invalidated thus saving clients from thinking that they
+ * have read all the data while there were new rows added to this array.
+ */
+ def generateIterator(startIndex: Int): Iterator[UnsafeRow] = {
+ if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) {
+ throw new ArrayIndexOutOfBoundsException(
+ "Invalid `startIndex` provided for generating iterator over the array. " +
+ s"Total elements: $numRows, requested `startIndex`: $startIndex")
+ }
+
+ if (spillableArray == null) {
+ new InMemoryBufferIterator(startIndex)
+ } else {
+ new SpillableArrayIterator(spillableArray.getIterator, numFieldsPerRow, startIndex)
+ }
+ }
+
+ def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0)
+
+ private[this]
+ abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] {
+ private val expectedModificationsCount = modificationsCount
+
+ protected def isModified(): Boolean = expectedModificationsCount != modificationsCount
+
+ protected def throwExceptionIfModified(): Unit = {
+ if (expectedModificationsCount != modificationsCount) {
+ throw new ConcurrentModificationException(
+ s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " +
+ s"since the creation of this Iterator")
+ }
+ }
+ }
+
+ private[this] class InMemoryBufferIterator(startIndex: Int)
+ extends ExternalAppendOnlyUnsafeRowArrayIterator {
+
+ private var currentIndex = startIndex
+
+ override def hasNext(): Boolean = !isModified() && currentIndex < numRows
+
+ override def next(): UnsafeRow = {
+ throwExceptionIfModified()
+ val result = inMemoryBuffer(currentIndex)
+ currentIndex += 1
+ result
+ }
+ }
+
+ private[this] class SpillableArrayIterator(
+ iterator: UnsafeSorterIterator,
+ numFieldPerRow: Int,
+ startIndex: Int)
+ extends ExternalAppendOnlyUnsafeRowArrayIterator {
+
+ private val currentRow = new UnsafeRow(numFieldPerRow)
+
+ def init(): Unit = {
+ var i = 0
+ while (i < startIndex) {
+ if (iterator.hasNext) {
+ iterator.loadNext()
+ } else {
+ throw new ArrayIndexOutOfBoundsException(
+ "Invalid `startIndex` provided for generating iterator over the array. " +
+ s"Total elements: $numRows, requested `startIndex`: $startIndex")
+ }
+ i += 1
+ }
+ }
+
+ // Traverse upto the given [[startIndex]]
+ init()
+
+ override def hasNext(): Boolean = !isModified() && iterator.hasNext
+
+ override def next(): UnsafeRow = {
+ throwExceptionIfModified()
+ iterator.loadNext()
+ currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength)
+ currentRow
+ }
+ }
+}
+
+private[sql] object ExternalAppendOnlyUnsafeRowArray {
+ val DefaultInitialSizeOfInMemoryBuffer = 128
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 8341fe2ffd..f380986951 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -19,65 +19,39 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark._
import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, ExternalAppendOnlyUnsafeRowArray, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.CompletionIterator
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
* An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
* will be much faster than building the right partition for every row in left RDD, it also
* materialize the right RDD (in case of the right RDD is nondeterministic).
*/
-class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
+class UnsafeCartesianRDD(
+ left : RDD[UnsafeRow],
+ right : RDD[UnsafeRow],
+ numFieldsOfRight: Int,
+ spillThreshold: Int)
extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
- // We will not sort the rows, so prefixComparator and recordComparator are null.
- val sorter = UnsafeExternalSorter.create(
- context.taskMemoryManager(),
- SparkEnv.get.blockManager,
- SparkEnv.get.serializerManager,
- context,
- null,
- null,
- 1024,
- SparkEnv.get.memoryManager.pageSizeBytes,
- SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
- UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
- false)
+ val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
val partition = split.asInstanceOf[CartesianPartition]
- for (y <- rdd2.iterator(partition.s2, context)) {
- sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false)
- }
+ rdd2.iterator(partition.s2, context).foreach(rowArray.add)
- // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
- def createIter(): Iterator[UnsafeRow] = {
- val iter = sorter.getIterator
- val unsafeRow = new UnsafeRow(numFieldsOfRight)
- new Iterator[UnsafeRow] {
- override def hasNext: Boolean = {
- iter.hasNext
- }
- override def next(): UnsafeRow = {
- iter.loadNext()
- unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
- unsafeRow
- }
- }
- }
+ // Create an iterator from rowArray
+ def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator()
val resultIter =
for (x <- rdd1.iterator(partition.s1, context);
y <- createIter()) yield (x, y)
CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
- resultIter, sorter.cleanupResources())
+ resultIter, rowArray.clear())
}
}
@@ -97,7 +71,9 @@ case class CartesianProductExec(
val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]]
val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
- val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
+ val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold
+
+ val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold)
pair.mapPartitionsWithIndexInternal { (index, iter) =>
val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
val filtered = if (condition.isDefined) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ca9c0ed8ce..bcdc4dcdf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport,
+ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.util.collection.BitSet
@@ -95,9 +96,13 @@ case class SortMergeJoinExec(
private def createRightKeyGenerator(): Projection =
UnsafeProjection.create(rightKeys, right.output)
+ private def getSpillThreshold: Int = {
+ sqlContext.conf.sortMergeJoinExecBufferSpillThreshold
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
-
+ val spillThreshold = getSpillThreshold
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
val boundCondition: (InternalRow) => Boolean = {
condition.map { cond =>
@@ -115,39 +120,39 @@ case class SortMergeJoinExec(
case _: InnerLike =>
new RowIterator {
private[this] var currentLeftRow: InternalRow = _
- private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
- private[this] var currentMatchIdx: Int = -1
+ private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
+ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
private[this] val smjScanner = new SortMergeJoinScanner(
createLeftKeyGenerator(),
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
- currentMatchIdx = 0
+ rightMatchesIterator = currentRightMatches.generateIterator()
}
override def advanceNext(): Boolean = {
- while (currentMatchIdx >= 0) {
- if (currentMatchIdx == currentRightMatches.length) {
+ while (rightMatchesIterator != null) {
+ if (!rightMatchesIterator.hasNext) {
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
- currentMatchIdx = 0
+ rightMatchesIterator = currentRightMatches.generateIterator()
} else {
currentRightMatches = null
currentLeftRow = null
- currentMatchIdx = -1
+ rightMatchesIterator = null
return false
}
}
- joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
- currentMatchIdx += 1
+ joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
numOutputRows += 1
return true
@@ -165,7 +170,8 @@ case class SortMergeJoinExec(
bufferedKeyGenerator = createRightKeyGenerator(),
keyOrdering,
streamedIter = RowIterator.fromScala(leftIter),
- bufferedIter = RowIterator.fromScala(rightIter)
+ bufferedIter = RowIterator.fromScala(rightIter),
+ spillThreshold
)
val rightNullRow = new GenericInternalRow(right.output.length)
new LeftOuterIterator(
@@ -177,7 +183,8 @@ case class SortMergeJoinExec(
bufferedKeyGenerator = createLeftKeyGenerator(),
keyOrdering,
streamedIter = RowIterator.fromScala(rightIter),
- bufferedIter = RowIterator.fromScala(leftIter)
+ bufferedIter = RowIterator.fromScala(leftIter),
+ spillThreshold
)
val leftNullRow = new GenericInternalRow(left.output.length)
new RightOuterIterator(
@@ -209,7 +216,8 @@ case class SortMergeJoinExec(
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -217,14 +225,15 @@ case class SortMergeJoinExec(
while (smjScanner.findNextInnerJoinRows()) {
val currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
- var i = 0
- while (i < currentRightMatches.length) {
- joinRow(currentLeftRow, currentRightMatches(i))
- if (boundCondition(joinRow)) {
- numOutputRows += 1
- return true
+ if (currentRightMatches != null && currentRightMatches.length > 0) {
+ val rightMatchesIterator = currentRightMatches.generateIterator()
+ while (rightMatchesIterator.hasNext) {
+ joinRow(currentLeftRow, rightMatchesIterator.next())
+ if (boundCondition(joinRow)) {
+ numOutputRows += 1
+ return true
+ }
}
- i += 1
}
}
false
@@ -241,7 +250,8 @@ case class SortMergeJoinExec(
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -249,17 +259,16 @@ case class SortMergeJoinExec(
while (smjScanner.findNextOuterJoinRows()) {
currentLeftRow = smjScanner.getStreamedRow
val currentRightMatches = smjScanner.getBufferedMatches
- if (currentRightMatches == null) {
+ if (currentRightMatches == null || currentRightMatches.length == 0) {
return true
}
- var i = 0
var found = false
- while (!found && i < currentRightMatches.length) {
- joinRow(currentLeftRow, currentRightMatches(i))
+ val rightMatchesIterator = currentRightMatches.generateIterator()
+ while (!found && rightMatchesIterator.hasNext) {
+ joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
found = true
}
- i += 1
}
if (!found) {
numOutputRows += 1
@@ -281,7 +290,8 @@ case class SortMergeJoinExec(
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
- RowIterator.fromScala(rightIter)
+ RowIterator.fromScala(rightIter),
+ spillThreshold
)
private[this] val joinRow = new JoinedRow
@@ -290,14 +300,13 @@ case class SortMergeJoinExec(
currentLeftRow = smjScanner.getStreamedRow
val currentRightMatches = smjScanner.getBufferedMatches
var found = false
- if (currentRightMatches != null) {
- var i = 0
- while (!found && i < currentRightMatches.length) {
- joinRow(currentLeftRow, currentRightMatches(i))
+ if (currentRightMatches != null && currentRightMatches.length > 0) {
+ val rightMatchesIterator = currentRightMatches.generateIterator()
+ while (!found && rightMatchesIterator.hasNext) {
+ joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
found = true
}
- i += 1
}
}
result.setBoolean(0, found)
@@ -376,8 +385,11 @@ case class SortMergeJoinExec(
// A list to hold all matched rows from right side.
val matches = ctx.freshName("matches")
- val clsName = classOf[java.util.ArrayList[InternalRow]].getName
- ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
+ val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName
+
+ val spillThreshold = getSpillThreshold
+
+ ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);")
// Copy the left keys as class members so they could be used in next function call.
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
@@ -428,7 +440,7 @@ case class SortMergeJoinExec(
| }
| $leftRow = null;
| } else {
- | $matches.add($rightRow.copy());
+ | $matches.add((UnsafeRow) $rightRow);
| $rightRow = null;;
| }
| } while ($leftRow != null);
@@ -517,8 +529,7 @@ case class SortMergeJoinExec(
val rightRow = ctx.freshName("rightRow")
val rightVars = createRightVar(ctx, rightRow)
- val size = ctx.freshName("size")
- val i = ctx.freshName("i")
+ val iterator = ctx.freshName("iterator")
val numOutput = metricTerm(ctx, "numOutputRows")
val (beforeLoop, condCheck) = if (condition.isDefined) {
// Split the code of creating variables based on whether it's used by condition or not.
@@ -551,10 +562,10 @@ case class SortMergeJoinExec(
s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
- | int $size = $matches.size();
| ${beforeLoop.trim}
- | for (int $i = 0; $i < $size; $i ++) {
- | InternalRow $rightRow = (InternalRow) $matches.get($i);
+ | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator();
+ | while ($iterator.hasNext()) {
+ | InternalRow $rightRow = (InternalRow) $iterator.next();
| ${condCheck.trim}
| $numOutput.add(1);
| ${consume(ctx, leftVars ++ rightVars)}
@@ -589,7 +600,8 @@ private[joins] class SortMergeJoinScanner(
bufferedKeyGenerator: Projection,
keyOrdering: Ordering[InternalRow],
streamedIter: RowIterator,
- bufferedIter: RowIterator) {
+ bufferedIter: RowIterator,
+ bufferThreshold: Int) {
private[this] var streamedRow: InternalRow = _
private[this] var streamedRowKey: InternalRow = _
private[this] var bufferedRow: InternalRow = _
@@ -600,7 +612,7 @@ private[joins] class SortMergeJoinScanner(
*/
private[this] var matchJoinKey: InternalRow = _
/** Buffered rows from the buffered side of the join. This is empty if there are no matches. */
- private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+ private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold)
// Initialization (note: do _not_ want to advance streamed here).
advancedBufferedToRowWithNullFreeJoinKey()
@@ -609,7 +621,7 @@ private[joins] class SortMergeJoinScanner(
def getStreamedRow: InternalRow = streamedRow
- def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches
+ def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches
/**
* Advances both input iterators, stopping when we have found rows with matching join keys.
@@ -755,7 +767,7 @@ private[joins] class SortMergeJoinScanner(
matchJoinKey = streamedRowKey.copy()
bufferedMatches.clear()
do {
- bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
+ bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow])
advancedBufferedToRowWithNullFreeJoinKey()
} while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
}
@@ -819,7 +831,7 @@ private abstract class OneSideOuterIterator(
protected[this] val joinedRow: JoinedRow = new JoinedRow()
// Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
- private[this] var bufferIndex: Int = 0
+ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
// This iterator is initialized lazily so there should be no matches initially
assert(smjScanner.getBufferedMatches.length == 0)
@@ -833,7 +845,7 @@ private abstract class OneSideOuterIterator(
* @return whether there are more rows in the stream to consume.
*/
private def advanceStream(): Boolean = {
- bufferIndex = 0
+ rightMatchesIterator = null
if (smjScanner.findNextOuterJoinRows()) {
setStreamSideOutput(smjScanner.getStreamedRow)
if (smjScanner.getBufferedMatches.isEmpty) {
@@ -858,10 +870,13 @@ private abstract class OneSideOuterIterator(
*/
private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
var foundMatch: Boolean = false
- while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
- setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
+ if (rightMatchesIterator == null) {
+ rightMatchesIterator = smjScanner.getBufferedMatches.generateIterator()
+ }
+
+ while (!foundMatch && rightMatchesIterator.hasNext) {
+ setBufferedSideOutput(rightMatchesIterator.next())
foundMatch = boundCondition(joinedRow)
- bufferIndex += 1
}
foundMatch
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
deleted file mode 100644
index ee36c84251..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.window
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
-
-
-/**
- * The interface of row buffer for a partition. In absence of a buffer pool (with locking), the
- * row buffer is used to materialize a partition of rows since we need to repeatedly scan these
- * rows in window function processing.
- */
-private[window] abstract class RowBuffer {
-
- /** Number of rows. */
- def size: Int
-
- /** Return next row in the buffer, null if no more left. */
- def next(): InternalRow
-
- /** Skip the next `n` rows. */
- def skip(n: Int): Unit
-
- /** Return a new RowBuffer that has the same rows. */
- def copy(): RowBuffer
-}
-
-/**
- * A row buffer based on ArrayBuffer (the number of rows is limited).
- */
-private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer {
-
- private[this] var cursor: Int = -1
-
- /** Number of rows. */
- override def size: Int = buffer.length
-
- /** Return next row in the buffer, null if no more left. */
- override def next(): InternalRow = {
- cursor += 1
- if (cursor < buffer.length) {
- buffer(cursor)
- } else {
- null
- }
- }
-
- /** Skip the next `n` rows. */
- override def skip(n: Int): Unit = {
- cursor += n
- }
-
- /** Return a new RowBuffer that has the same rows. */
- override def copy(): RowBuffer = {
- new ArrayRowBuffer(buffer)
- }
-}
-
-/**
- * An external buffer of rows based on UnsafeExternalSorter.
- */
-private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int)
- extends RowBuffer {
-
- private[this] val iter: UnsafeSorterIterator = sorter.getIterator
-
- private[this] val currentRow = new UnsafeRow(numFields)
-
- /** Number of rows. */
- override def size: Int = iter.getNumRecords()
-
- /** Return next row in the buffer, null if no more left. */
- override def next(): InternalRow = {
- if (iter.hasNext) {
- iter.loadNext()
- currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
- currentRow
- } else {
- null
- }
- }
-
- /** Skip the next `n` rows. */
- override def skip(n: Int): Unit = {
- var i = 0
- while (i < n && iter.hasNext) {
- iter.loadNext()
- i += 1
- }
- }
-
- /** Return a new RowBuffer that has the same rows. */
- override def copy(): RowBuffer = {
- new ExternalRowBuffer(sorter, numFields)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index 80b87d5ffa..950a6794a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -20,15 +20,13 @@ package org.apache.spark.sql.execution.window
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
* This class calculates and outputs (windowed) aggregates over the rows in a single (sorted)
@@ -284,6 +282,7 @@ case class WindowExec(
// Unwrap the expressions and factories from the map.
val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+ val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold
// Start processing.
child.execute().mapPartitions { stream =>
@@ -310,10 +309,12 @@ case class WindowExec(
fetchNextRow()
// Manage the current partition.
- val rows = ArrayBuffer.empty[UnsafeRow]
val inputFields = child.output.length
- var sorter: UnsafeExternalSorter = null
- var rowBuffer: RowBuffer = null
+
+ val buffer: ExternalAppendOnlyUnsafeRowArray =
+ new ExternalAppendOnlyUnsafeRowArray(spillThreshold)
+ var bufferIterator: Iterator[UnsafeRow] = _
+
val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType))
val frames = factories.map(_(windowFunctionResult))
val numFrames = frames.length
@@ -323,78 +324,43 @@ case class WindowExec(
val currentGroup = nextGroup.copy()
// clear last partition
- if (sorter != null) {
- // the last sorter of this task will be cleaned up via task completion listener
- sorter.cleanupResources()
- sorter = null
- } else {
- rows.clear()
- }
+ buffer.clear()
while (nextRowAvailable && nextGroup == currentGroup) {
- if (sorter == null) {
- rows += nextRow.copy()
-
- if (rows.length >= 4096) {
- // We will not sort the rows, so prefixComparator and recordComparator are null.
- sorter = UnsafeExternalSorter.create(
- TaskContext.get().taskMemoryManager(),
- SparkEnv.get.blockManager,
- SparkEnv.get.serializerManager,
- TaskContext.get(),
- null,
- null,
- 1024,
- SparkEnv.get.memoryManager.pageSizeBytes,
- SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
- UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
- false)
- rows.foreach { r =>
- sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false)
- }
- rows.clear()
- }
- } else {
- sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
- nextRow.getSizeInBytes, 0, false)
- }
+ buffer.add(nextRow)
fetchNextRow()
}
- if (sorter != null) {
- rowBuffer = new ExternalRowBuffer(sorter, inputFields)
- } else {
- rowBuffer = new ArrayRowBuffer(rows)
- }
// Setup the frames.
var i = 0
while (i < numFrames) {
- frames(i).prepare(rowBuffer.copy())
+ frames(i).prepare(buffer)
i += 1
}
// Setup iteration
rowIndex = 0
- rowsSize = rowBuffer.size
+ bufferIterator = buffer.generateIterator()
}
// Iteration
var rowIndex = 0
- var rowsSize = 0L
- override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable
+ override final def hasNext: Boolean =
+ (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable
val join = new JoinedRow
override final def next(): InternalRow = {
// Load the next partition if we need to.
- if (rowIndex >= rowsSize && nextRowAvailable) {
+ if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
fetchNextPartition()
}
- if (rowIndex < rowsSize) {
+ if (bufferIterator.hasNext) {
+ val current = bufferIterator.next()
+
// Get the results for the window frames.
var i = 0
- val current = rowBuffer.next()
while (i < numFrames) {
frames(i).write(rowIndex, current)
i += 1
@@ -406,7 +372,9 @@ case class WindowExec(
// Return the projection.
result(join)
- } else throw new NoSuchElementException
+ } else {
+ throw new NoSuchElementException
+ }
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 70efc0f78d..af2b4fb920 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -22,6 +22,7 @@ import java.util
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
/**
@@ -35,7 +36,7 @@ private[window] abstract class WindowFunctionFrame {
*
* @param rows to calculate the frame results for.
*/
- def prepare(rows: RowBuffer): Unit
+ def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit
/**
* Write the current results to the target row.
@@ -43,6 +44,12 @@ private[window] abstract class WindowFunctionFrame {
def write(index: Int, current: InternalRow): Unit
}
+object WindowFunctionFrame {
+ def getNextOrNull(iterator: Iterator[UnsafeRow]): UnsafeRow = {
+ if (iterator.hasNext) iterator.next() else null
+ }
+}
+
/**
* The offset window frame calculates frames containing LEAD/LAG statements.
*
@@ -65,7 +72,12 @@ private[window] final class OffsetWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+ /**
+ * An iterator over the [[input]]
+ */
+ private[this] var inputIterator: Iterator[UnsafeRow] = _
/** Index of the input row currently used for output. */
private[this] var inputIndex = 0
@@ -103,20 +115,21 @@ private[window] final class OffsetWindowFunctionFrame(
newMutableProjection(boundExpressions, Nil).target(target)
}
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
+ inputIterator = input.generateIterator()
// drain the first few rows if offset is larger than zero
inputIndex = 0
while (inputIndex < offset) {
- input.next()
+ if (inputIterator.hasNext) inputIterator.next()
inputIndex += 1
}
inputIndex = offset
}
override def write(index: Int, current: InternalRow): Unit = {
- if (inputIndex >= 0 && inputIndex < input.size) {
- val r = input.next()
+ if (inputIndex >= 0 && inputIndex < input.length) {
+ val r = WindowFunctionFrame.getNextOrNull(inputIterator)
projection(r)
} else {
// Use default values since the offset row does not exist.
@@ -143,7 +156,12 @@ private[window] final class SlidingWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+ /**
+ * An iterator over the [[input]]
+ */
+ private[this] var inputIterator: Iterator[UnsafeRow] = _
/** The next row from `input`. */
private[this] var nextRow: InternalRow = null
@@ -164,9 +182,10 @@ private[window] final class SlidingWindowFunctionFrame(
private[this] var inputLowIndex = 0
/** Prepare the frame for calculating a new partition. Reset all variables. */
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
- nextRow = rows.next()
+ inputIterator = input.generateIterator()
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputHighIndex = 0
inputLowIndex = 0
buffer.clear()
@@ -180,7 +199,7 @@ private[window] final class SlidingWindowFunctionFrame(
// the output row upper bound.
while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
buffer.add(nextRow.copy())
- nextRow = input.next()
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputHighIndex += 1
bufferUpdated = true
}
@@ -195,7 +214,7 @@ private[window] final class SlidingWindowFunctionFrame(
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
- processor.initialize(input.size)
+ processor.initialize(input.length)
val iter = buffer.iterator()
while (iter.hasNext) {
processor.update(iter.next())
@@ -222,13 +241,12 @@ private[window] final class UnboundedWindowFunctionFrame(
extends WindowFunctionFrame {
/** Prepare the frame for calculating a new partition. Process all rows eagerly. */
- override def prepare(rows: RowBuffer): Unit = {
- val size = rows.size
- processor.initialize(size)
- var i = 0
- while (i < size) {
- processor.update(rows.next())
- i += 1
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
+ processor.initialize(rows.length)
+
+ val iterator = rows.generateIterator()
+ while (iterator.hasNext) {
+ processor.update(iterator.next())
}
}
@@ -261,7 +279,12 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
+
+ /**
+ * An iterator over the [[input]]
+ */
+ private[this] var inputIterator: Iterator[UnsafeRow] = _
/** The next row from `input`. */
private[this] var nextRow: InternalRow = null
@@ -273,11 +296,15 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
private[this] var inputIndex = 0
/** Prepare the frame for calculating a new partition. */
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
- nextRow = rows.next()
inputIndex = 0
- processor.initialize(input.size)
+ inputIterator = input.generateIterator()
+ if (inputIterator.hasNext) {
+ nextRow = inputIterator.next()
+ }
+
+ processor.initialize(input.length)
}
/** Write the frame columns for the current row to the given target row. */
@@ -288,7 +315,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
// the output row upper bound.
while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
processor.update(nextRow)
- nextRow = input.next()
+ nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputIndex += 1
bufferUpdated = true
}
@@ -323,7 +350,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
extends WindowFunctionFrame {
/** Rows of the partition currently being processed. */
- private[this] var input: RowBuffer = null
+ private[this] var input: ExternalAppendOnlyUnsafeRowArray = null
/**
* Index of the first input row with a value equal to or greater than the lower bound of the
@@ -332,7 +359,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
private[this] var inputIndex = 0
/** Prepare the frame for calculating a new partition. */
- override def prepare(rows: RowBuffer): Unit = {
+ override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
inputIndex = 0
}
@@ -341,25 +368,25 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
override def write(index: Int, current: InternalRow): Unit = {
var bufferUpdated = index == 0
- // Duplicate the input to have a new iterator
- val tmp = input.copy()
-
- // Drop all rows from the buffer for which the input row value is smaller than
+ // Ignore all the rows from the buffer for which the input row value is smaller than
// the output row lower bound.
- tmp.skip(inputIndex)
- var nextRow = tmp.next()
+ val iterator = input.generateIterator(startIndex = inputIndex)
+
+ var nextRow = WindowFunctionFrame.getNextOrNull(iterator)
while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) {
- nextRow = tmp.next()
inputIndex += 1
bufferUpdated = true
+ nextRow = WindowFunctionFrame.getNextOrNull(iterator)
}
// Only recalculate and update when the buffer changes.
if (bufferUpdated) {
- processor.initialize(input.size)
- while (nextRow != null) {
+ processor.initialize(input.length)
+ if (nextRow != null) {
processor.update(nextRow)
- nextRow = tmp.next()
+ }
+ while (iterator.hasNext) {
+ processor.update(iterator.next())
}
processor.evaluate(target)
}