diff options
Diffstat (limited to 'sql/core/src')
10 files changed, 1157 insertions, 292 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala new file mode 100644 index 0000000000..458ac4ba36 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.ConcurrentModificationException + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} + +/** + * An append-only array for [[UnsafeRow]]s that spills content to disk when there a predefined + * threshold of rows is reached. + * + * Setting spill threshold faces following trade-off: + * + * - If the spill threshold is too high, the in-memory array may occupy more memory than is + * available, resulting in OOM. + * - If the spill threshold is too low, we spill frequently and incur unnecessary disk writes. + * This may lead to a performance regression compared to the normal case of using an + * [[ArrayBuffer]] or [[Array]]. + */ +private[sql] class ExternalAppendOnlyUnsafeRowArray( + taskMemoryManager: TaskMemoryManager, + blockManager: BlockManager, + serializerManager: SerializerManager, + taskContext: TaskContext, + initialSize: Int, + pageSizeBytes: Long, + numRowsSpillThreshold: Int) extends Logging { + + def this(numRowsSpillThreshold: Int) { + this( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numRowsSpillThreshold) + } + + private val initialSizeOfInMemoryBuffer = + Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsSpillThreshold) + + private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) { + new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer) + } else { + null + } + + private var spillableArray: UnsafeExternalSorter = _ + private var numRows = 0 + + // A counter to keep track of total modifications done to this array since its creation. + // This helps to invalidate iterators when there are changes done to the backing array. + private var modificationsCount: Long = 0 + + private var numFieldsPerRow = 0 + + def length: Int = numRows + + def isEmpty: Boolean = numRows == 0 + + /** + * Clears up resources (eg. memory) held by the backing storage + */ + def clear(): Unit = { + if (spillableArray != null) { + // The last `spillableArray` of this task will be cleaned up via task completion listener + // inside `UnsafeExternalSorter` + spillableArray.cleanupResources() + spillableArray = null + } else if (inMemoryBuffer != null) { + inMemoryBuffer.clear() + } + numFieldsPerRow = 0 + numRows = 0 + modificationsCount += 1 + } + + def add(unsafeRow: UnsafeRow): Unit = { + if (numRows < numRowsSpillThreshold) { + inMemoryBuffer += unsafeRow.copy() + } else { + if (spillableArray == null) { + logInfo(s"Reached spill threshold of $numRowsSpillThreshold rows, switching to " + + s"${classOf[UnsafeExternalSorter].getName}") + + // We will not sort the rows, so prefixComparator and recordComparator are null + spillableArray = UnsafeExternalSorter.create( + taskMemoryManager, + blockManager, + serializerManager, + taskContext, + null, + null, + initialSize, + pageSizeBytes, + numRowsSpillThreshold, + false) + + // populate with existing in-memory buffered rows + if (inMemoryBuffer != null) { + inMemoryBuffer.foreach(existingUnsafeRow => + spillableArray.insertRecord( + existingUnsafeRow.getBaseObject, + existingUnsafeRow.getBaseOffset, + existingUnsafeRow.getSizeInBytes, + 0, + false) + ) + inMemoryBuffer.clear() + } + numFieldsPerRow = unsafeRow.numFields() + } + + spillableArray.insertRecord( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset, + unsafeRow.getSizeInBytes, + 0, + false) + } + + numRows += 1 + modificationsCount += 1 + } + + /** + * Creates an [[Iterator]] for the current rows in the array starting from a user provided index + * + * If there are subsequent [[add()]] or [[clear()]] calls made on this array after creation of + * the iterator, then the iterator is invalidated thus saving clients from thinking that they + * have read all the data while there were new rows added to this array. + */ + def generateIterator(startIndex: Int): Iterator[UnsafeRow] = { + if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) { + throw new ArrayIndexOutOfBoundsException( + "Invalid `startIndex` provided for generating iterator over the array. " + + s"Total elements: $numRows, requested `startIndex`: $startIndex") + } + + if (spillableArray == null) { + new InMemoryBufferIterator(startIndex) + } else { + new SpillableArrayIterator(spillableArray.getIterator, numFieldsPerRow, startIndex) + } + } + + def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0) + + private[this] + abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] { + private val expectedModificationsCount = modificationsCount + + protected def isModified(): Boolean = expectedModificationsCount != modificationsCount + + protected def throwExceptionIfModified(): Unit = { + if (expectedModificationsCount != modificationsCount) { + throw new ConcurrentModificationException( + s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " + + s"since the creation of this Iterator") + } + } + } + + private[this] class InMemoryBufferIterator(startIndex: Int) + extends ExternalAppendOnlyUnsafeRowArrayIterator { + + private var currentIndex = startIndex + + override def hasNext(): Boolean = !isModified() && currentIndex < numRows + + override def next(): UnsafeRow = { + throwExceptionIfModified() + val result = inMemoryBuffer(currentIndex) + currentIndex += 1 + result + } + } + + private[this] class SpillableArrayIterator( + iterator: UnsafeSorterIterator, + numFieldPerRow: Int, + startIndex: Int) + extends ExternalAppendOnlyUnsafeRowArrayIterator { + + private val currentRow = new UnsafeRow(numFieldPerRow) + + def init(): Unit = { + var i = 0 + while (i < startIndex) { + if (iterator.hasNext) { + iterator.loadNext() + } else { + throw new ArrayIndexOutOfBoundsException( + "Invalid `startIndex` provided for generating iterator over the array. " + + s"Total elements: $numRows, requested `startIndex`: $startIndex") + } + i += 1 + } + } + + // Traverse upto the given [[startIndex]] + init() + + override def hasNext(): Boolean = !isModified() && iterator.hasNext + + override def next(): UnsafeRow = { + throwExceptionIfModified() + iterator.loadNext() + currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength) + currentRow + } + } +} + +private[sql] object ExternalAppendOnlyUnsafeRowArray { + val DefaultInitialSizeOfInMemoryBuffer = 128 +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 8341fe2ffd..f380986951 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -19,65 +19,39 @@ package org.apache.spark.sql.execution.joins import org.apache.spark._ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} +import org.apache.spark.sql.execution.{BinaryExecNode, ExternalAppendOnlyUnsafeRowArray, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.CompletionIterator -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, * will be much faster than building the right partition for every row in left RDD, it also * materialize the right RDD (in case of the right RDD is nondeterministic). */ -class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) +class UnsafeCartesianRDD( + left : RDD[UnsafeRow], + right : RDD[UnsafeRow], + numFieldsOfRight: Int, + spillThreshold: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - // We will not sort the rows, so prefixComparator and recordComparator are null. - val sorter = UnsafeExternalSorter.create( - context.taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - context, - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), - false) + val rowArray = new ExternalAppendOnlyUnsafeRowArray(spillThreshold) val partition = split.asInstanceOf[CartesianPartition] - for (y <- rdd2.iterator(partition.s2, context)) { - sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false) - } + rdd2.iterator(partition.s2, context).foreach(rowArray.add) - // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow] - def createIter(): Iterator[UnsafeRow] = { - val iter = sorter.getIterator - val unsafeRow = new UnsafeRow(numFieldsOfRight) - new Iterator[UnsafeRow] { - override def hasNext: Boolean = { - iter.hasNext - } - override def next(): UnsafeRow = { - iter.loadNext() - unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - unsafeRow - } - } - } + // Create an iterator from rowArray + def createIter(): Iterator[UnsafeRow] = rowArray.generateIterator() val resultIter = for (x <- rdd1.iterator(partition.s1, context); y <- createIter()) yield (x, y) CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]]( - resultIter, sorter.cleanupResources()) + resultIter, rowArray.clear()) } } @@ -97,7 +71,9 @@ case class CartesianProductExec( val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]] val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]] - val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size) + val spillThreshold = sqlContext.conf.cartesianProductExecBufferSpillThreshold + + val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size, spillThreshold) pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ca9c0ed8ce..bcdc4dcdf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, +ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet @@ -95,9 +96,13 @@ case class SortMergeJoinExec( private def createRightKeyGenerator(): Projection = UnsafeProjection.create(rightKeys, right.output) + private def getSpillThreshold: Int = { + sqlContext.conf.sortMergeJoinExecBufferSpillThreshold + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - + val spillThreshold = getSpillThreshold left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => @@ -115,39 +120,39 @@ case class SortMergeJoinExec( case _: InnerLike => new RowIterator { private[this] var currentLeftRow: InternalRow = _ - private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _ - private[this] var currentMatchIdx: Int = -1 + private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ + private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null private[this] val smjScanner = new SortMergeJoinScanner( createLeftKeyGenerator(), createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 + rightMatchesIterator = currentRightMatches.generateIterator() } override def advanceNext(): Boolean = { - while (currentMatchIdx >= 0) { - if (currentMatchIdx == currentRightMatches.length) { + while (rightMatchesIterator != null) { + if (!rightMatchesIterator.hasNext) { if (smjScanner.findNextInnerJoinRows()) { currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - currentMatchIdx = 0 + rightMatchesIterator = currentRightMatches.generateIterator() } else { currentRightMatches = null currentLeftRow = null - currentMatchIdx = -1 + rightMatchesIterator = null return false } } - joinRow(currentLeftRow, currentRightMatches(currentMatchIdx)) - currentMatchIdx += 1 + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { numOutputRows += 1 return true @@ -165,7 +170,8 @@ case class SortMergeJoinExec( bufferedKeyGenerator = createRightKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(leftIter), - bufferedIter = RowIterator.fromScala(rightIter) + bufferedIter = RowIterator.fromScala(rightIter), + spillThreshold ) val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator( @@ -177,7 +183,8 @@ case class SortMergeJoinExec( bufferedKeyGenerator = createLeftKeyGenerator(), keyOrdering, streamedIter = RowIterator.fromScala(rightIter), - bufferedIter = RowIterator.fromScala(leftIter) + bufferedIter = RowIterator.fromScala(leftIter), + spillThreshold ) val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator( @@ -209,7 +216,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -217,14 +225,15 @@ case class SortMergeJoinExec( while (smjScanner.findNextInnerJoinRows()) { val currentRightMatches = smjScanner.getBufferedMatches currentLeftRow = smjScanner.getStreamedRow - var i = 0 - while (i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) - if (boundCondition(joinRow)) { - numOutputRows += 1 - return true + if (currentRightMatches != null && currentRightMatches.length > 0) { + val rightMatchesIterator = currentRightMatches.generateIterator() + while (rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) + if (boundCondition(joinRow)) { + numOutputRows += 1 + return true + } } - i += 1 } } false @@ -241,7 +250,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -249,17 +259,16 @@ case class SortMergeJoinExec( while (smjScanner.findNextOuterJoinRows()) { currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches - if (currentRightMatches == null) { + if (currentRightMatches == null || currentRightMatches.length == 0) { return true } - var i = 0 var found = false - while (!found && i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) + val rightMatchesIterator = currentRightMatches.generateIterator() + while (!found && rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { found = true } - i += 1 } if (!found) { numOutputRows += 1 @@ -281,7 +290,8 @@ case class SortMergeJoinExec( createRightKeyGenerator(), keyOrdering, RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter) + RowIterator.fromScala(rightIter), + spillThreshold ) private[this] val joinRow = new JoinedRow @@ -290,14 +300,13 @@ case class SortMergeJoinExec( currentLeftRow = smjScanner.getStreamedRow val currentRightMatches = smjScanner.getBufferedMatches var found = false - if (currentRightMatches != null) { - var i = 0 - while (!found && i < currentRightMatches.length) { - joinRow(currentLeftRow, currentRightMatches(i)) + if (currentRightMatches != null && currentRightMatches.length > 0) { + val rightMatchesIterator = currentRightMatches.generateIterator() + while (!found && rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) if (boundCondition(joinRow)) { found = true } - i += 1 } } result.setBoolean(0, found) @@ -376,8 +385,11 @@ case class SortMergeJoinExec( // A list to hold all matched rows from right side. val matches = ctx.freshName("matches") - val clsName = classOf[java.util.ArrayList[InternalRow]].getName - ctx.addMutableState(clsName, matches, s"$matches = new $clsName();") + val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName + + val spillThreshold = getSpillThreshold + + ctx.addMutableState(clsName, matches, s"$matches = new $clsName($spillThreshold);") // Copy the left keys as class members so they could be used in next function call. val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -428,7 +440,7 @@ case class SortMergeJoinExec( | } | $leftRow = null; | } else { - | $matches.add($rightRow.copy()); + | $matches.add((UnsafeRow) $rightRow); | $rightRow = null;; | } | } while ($leftRow != null); @@ -517,8 +529,7 @@ case class SortMergeJoinExec( val rightRow = ctx.freshName("rightRow") val rightVars = createRightVar(ctx, rightRow) - val size = ctx.freshName("size") - val i = ctx.freshName("i") + val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") val (beforeLoop, condCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. @@ -551,10 +562,10 @@ case class SortMergeJoinExec( s""" |while (findNextInnerJoinRows($leftInput, $rightInput)) { - | int $size = $matches.size(); | ${beforeLoop.trim} - | for (int $i = 0; $i < $size; $i ++) { - | InternalRow $rightRow = (InternalRow) $matches.get($i); + | scala.collection.Iterator<UnsafeRow> $iterator = $matches.generateIterator(); + | while ($iterator.hasNext()) { + | InternalRow $rightRow = (InternalRow) $iterator.next(); | ${condCheck.trim} | $numOutput.add(1); | ${consume(ctx, leftVars ++ rightVars)} @@ -589,7 +600,8 @@ private[joins] class SortMergeJoinScanner( bufferedKeyGenerator: Projection, keyOrdering: Ordering[InternalRow], streamedIter: RowIterator, - bufferedIter: RowIterator) { + bufferedIter: RowIterator, + bufferThreshold: Int) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -600,7 +612,7 @@ private[joins] class SortMergeJoinScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ - private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] val bufferedMatches = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() @@ -609,7 +621,7 @@ private[joins] class SortMergeJoinScanner( def getStreamedRow: InternalRow = streamedRow - def getBufferedMatches: ArrayBuffer[InternalRow] = bufferedMatches + def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches /** * Advances both input iterators, stopping when we have found rows with matching join keys. @@ -755,7 +767,7 @@ private[joins] class SortMergeJoinScanner( matchJoinKey = streamedRowKey.copy() bufferedMatches.clear() do { - bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them + bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) advancedBufferedToRowWithNullFreeJoinKey() } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) } @@ -819,7 +831,7 @@ private abstract class OneSideOuterIterator( protected[this] val joinedRow: JoinedRow = new JoinedRow() // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row - private[this] var bufferIndex: Int = 0 + private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null // This iterator is initialized lazily so there should be no matches initially assert(smjScanner.getBufferedMatches.length == 0) @@ -833,7 +845,7 @@ private abstract class OneSideOuterIterator( * @return whether there are more rows in the stream to consume. */ private def advanceStream(): Boolean = { - bufferIndex = 0 + rightMatchesIterator = null if (smjScanner.findNextOuterJoinRows()) { setStreamSideOutput(smjScanner.getStreamedRow) if (smjScanner.getBufferedMatches.isEmpty) { @@ -858,10 +870,13 @@ private abstract class OneSideOuterIterator( */ private def advanceBufferUntilBoundConditionSatisfied(): Boolean = { var foundMatch: Boolean = false - while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) { - setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex)) + if (rightMatchesIterator == null) { + rightMatchesIterator = smjScanner.getBufferedMatches.generateIterator() + } + + while (!foundMatch && rightMatchesIterator.hasNext) { + setBufferedSideOutput(rightMatchesIterator.next()) foundMatch = boundCondition(joinedRow) - bufferIndex += 1 } foundMatch } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala deleted file mode 100644 index ee36c84251..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/RowBuffer.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.window - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} - - -/** - * The interface of row buffer for a partition. In absence of a buffer pool (with locking), the - * row buffer is used to materialize a partition of rows since we need to repeatedly scan these - * rows in window function processing. - */ -private[window] abstract class RowBuffer { - - /** Number of rows. */ - def size: Int - - /** Return next row in the buffer, null if no more left. */ - def next(): InternalRow - - /** Skip the next `n` rows. */ - def skip(n: Int): Unit - - /** Return a new RowBuffer that has the same rows. */ - def copy(): RowBuffer -} - -/** - * A row buffer based on ArrayBuffer (the number of rows is limited). - */ -private[window] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { - - private[this] var cursor: Int = -1 - - /** Number of rows. */ - override def size: Int = buffer.length - - /** Return next row in the buffer, null if no more left. */ - override def next(): InternalRow = { - cursor += 1 - if (cursor < buffer.length) { - buffer(cursor) - } else { - null - } - } - - /** Skip the next `n` rows. */ - override def skip(n: Int): Unit = { - cursor += n - } - - /** Return a new RowBuffer that has the same rows. */ - override def copy(): RowBuffer = { - new ArrayRowBuffer(buffer) - } -} - -/** - * An external buffer of rows based on UnsafeExternalSorter. - */ -private[window] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) - extends RowBuffer { - - private[this] val iter: UnsafeSorterIterator = sorter.getIterator - - private[this] val currentRow = new UnsafeRow(numFields) - - /** Number of rows. */ - override def size: Int = iter.getNumRecords() - - /** Return next row in the buffer, null if no more left. */ - override def next(): InternalRow = { - if (iter.hasNext) { - iter.loadNext() - currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) - currentRow - } else { - null - } - } - - /** Skip the next `n` rows. */ - override def skip(n: Int): Unit = { - var i = 0 - while (i < n && iter.hasNext) { - iter.loadNext() - i += 1 - } - } - - /** Return a new RowBuffer that has the same rows. */ - override def copy(): RowBuffer = { - new ExternalRowBuffer(sorter, numFields) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 80b87d5ffa..950a6794a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql.execution.window import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.IntegerType -import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -284,6 +282,7 @@ case class WindowExec( // Unwrap the expressions and factories from the map. val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray + val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold // Start processing. child.execute().mapPartitions { stream => @@ -310,10 +309,12 @@ case class WindowExec( fetchNextRow() // Manage the current partition. - val rows = ArrayBuffer.empty[UnsafeRow] val inputFields = child.output.length - var sorter: UnsafeExternalSorter = null - var rowBuffer: RowBuffer = null + + val buffer: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(spillThreshold) + var bufferIterator: Iterator[UnsafeRow] = _ + val windowFunctionResult = new SpecificInternalRow(expressions.map(_.dataType)) val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length @@ -323,78 +324,43 @@ case class WindowExec( val currentGroup = nextGroup.copy() // clear last partition - if (sorter != null) { - // the last sorter of this task will be cleaned up via task completion listener - sorter.cleanupResources() - sorter = null - } else { - rows.clear() - } + buffer.clear() while (nextRowAvailable && nextGroup == currentGroup) { - if (sorter == null) { - rows += nextRow.copy() - - if (rows.length >= 4096) { - // We will not sort the rows, so prefixComparator and recordComparator are null. - sorter = UnsafeExternalSorter.create( - TaskContext.get().taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get(), - null, - null, - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", - UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD), - false) - rows.foreach { r => - sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false) - } - rows.clear() - } - } else { - sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, - nextRow.getSizeInBytes, 0, false) - } + buffer.add(nextRow) fetchNextRow() } - if (sorter != null) { - rowBuffer = new ExternalRowBuffer(sorter, inputFields) - } else { - rowBuffer = new ArrayRowBuffer(rows) - } // Setup the frames. var i = 0 while (i < numFrames) { - frames(i).prepare(rowBuffer.copy()) + frames(i).prepare(buffer) i += 1 } // Setup iteration rowIndex = 0 - rowsSize = rowBuffer.size + bufferIterator = buffer.generateIterator() } // Iteration var rowIndex = 0 - var rowsSize = 0L - override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable + override final def hasNext: Boolean = + (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable val join = new JoinedRow override final def next(): InternalRow = { // Load the next partition if we need to. - if (rowIndex >= rowsSize && nextRowAvailable) { + if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) { fetchNextPartition() } - if (rowIndex < rowsSize) { + if (bufferIterator.hasNext) { + val current = bufferIterator.next() + // Get the results for the window frames. var i = 0 - val current = rowBuffer.next() while (i < numFrames) { frames(i).write(rowIndex, current) i += 1 @@ -406,7 +372,9 @@ case class WindowExec( // Return the projection. result(join) - } else throw new NoSuchElementException + } else { + throw new NoSuchElementException + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 70efc0f78d..af2b4fb920 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -22,6 +22,7 @@ import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray /** @@ -35,7 +36,7 @@ private[window] abstract class WindowFunctionFrame { * * @param rows to calculate the frame results for. */ - def prepare(rows: RowBuffer): Unit + def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit /** * Write the current results to the target row. @@ -43,6 +44,12 @@ private[window] abstract class WindowFunctionFrame { def write(index: Int, current: InternalRow): Unit } +object WindowFunctionFrame { + def getNextOrNull(iterator: Iterator[UnsafeRow]): UnsafeRow = { + if (iterator.hasNext) iterator.next() else null + } +} + /** * The offset window frame calculates frames containing LEAD/LAG statements. * @@ -65,7 +72,12 @@ private[window] final class OffsetWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** Index of the input row currently used for output. */ private[this] var inputIndex = 0 @@ -103,20 +115,21 @@ private[window] final class OffsetWindowFunctionFrame( newMutableProjection(boundExpressions, Nil).target(target) } - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows + inputIterator = input.generateIterator() // drain the first few rows if offset is larger than zero inputIndex = 0 while (inputIndex < offset) { - input.next() + if (inputIterator.hasNext) inputIterator.next() inputIndex += 1 } inputIndex = offset } override def write(index: Int, current: InternalRow): Unit = { - if (inputIndex >= 0 && inputIndex < input.size) { - val r = input.next() + if (inputIndex >= 0 && inputIndex < input.length) { + val r = WindowFunctionFrame.getNextOrNull(inputIterator) projection(r) } else { // Use default values since the offset row does not exist. @@ -143,7 +156,12 @@ private[window] final class SlidingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** The next row from `input`. */ private[this] var nextRow: InternalRow = null @@ -164,9 +182,10 @@ private[window] final class SlidingWindowFunctionFrame( private[this] var inputLowIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. */ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows - nextRow = rows.next() + inputIterator = input.generateIterator() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputHighIndex = 0 inputLowIndex = 0 buffer.clear() @@ -180,7 +199,7 @@ private[window] final class SlidingWindowFunctionFrame( // the output row upper bound. while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { buffer.add(nextRow.copy()) - nextRow = input.next() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputHighIndex += 1 bufferUpdated = true } @@ -195,7 +214,7 @@ private[window] final class SlidingWindowFunctionFrame( // Only recalculate and update when the buffer changes. if (bufferUpdated) { - processor.initialize(input.size) + processor.initialize(input.length) val iter = buffer.iterator() while (iter.hasNext) { processor.update(iter.next()) @@ -222,13 +241,12 @@ private[window] final class UnboundedWindowFunctionFrame( extends WindowFunctionFrame { /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: RowBuffer): Unit = { - val size = rows.size - processor.initialize(size) - var i = 0 - while (i < size) { - processor.update(rows.next()) - i += 1 + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { + processor.initialize(rows.length) + + val iterator = rows.generateIterator() + while (iterator.hasNext) { + processor.update(iterator.next()) } } @@ -261,7 +279,12 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null + + /** + * An iterator over the [[input]] + */ + private[this] var inputIterator: Iterator[UnsafeRow] = _ /** The next row from `input`. */ private[this] var nextRow: InternalRow = null @@ -273,11 +296,15 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows - nextRow = rows.next() inputIndex = 0 - processor.initialize(input.size) + inputIterator = input.generateIterator() + if (inputIterator.hasNext) { + nextRow = inputIterator.next() + } + + processor.initialize(input.length) } /** Write the frame columns for the current row to the given target row. */ @@ -288,7 +315,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame( // the output row upper bound. while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { processor.update(nextRow) - nextRow = input.next() + nextRow = WindowFunctionFrame.getNextOrNull(inputIterator) inputIndex += 1 bufferUpdated = true } @@ -323,7 +350,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: RowBuffer = null + private[this] var input: ExternalAppendOnlyUnsafeRowArray = null /** * Index of the first input row with a value equal to or greater than the lower bound of the @@ -332,7 +359,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: RowBuffer): Unit = { + override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = { input = rows inputIndex = 0 } @@ -341,25 +368,25 @@ private[window] final class UnboundedFollowingWindowFunctionFrame( override def write(index: Int, current: InternalRow): Unit = { var bufferUpdated = index == 0 - // Duplicate the input to have a new iterator - val tmp = input.copy() - - // Drop all rows from the buffer for which the input row value is smaller than + // Ignore all the rows from the buffer for which the input row value is smaller than // the output row lower bound. - tmp.skip(inputIndex) - var nextRow = tmp.next() + val iterator = input.generateIterator(startIndex = inputIndex) + + var nextRow = WindowFunctionFrame.getNextOrNull(iterator) while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) { - nextRow = tmp.next() inputIndex += 1 bufferUpdated = true + nextRow = WindowFunctionFrame.getNextOrNull(iterator) } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - processor.initialize(input.size) - while (nextRow != null) { + processor.initialize(input.length) + if (nextRow != null) { processor.update(nextRow) - nextRow = tmp.next() + } + while (iterator.hasNext) { + processor.update(iterator.next()) } processor.evaluate(target) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2e006735d1..1a66aa85f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.mutable.ListBuffer import scala.language.existentials import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation @@ -24,7 +25,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext - +import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} class JoinSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -604,4 +605,137 @@ class JoinSuite extends QueryTest with SharedSQLContext { cartesianQueries.foreach(checkCartesianDetection) } + + test("test SortMergeJoin (without spill)") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> Int.MaxValue.toString) { + + assertNotSpilled(sparkContext, "inner join") { + checkAnswer( + sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"), + Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil + ) + } + + val expected = new ListBuffer[Row]() + expected.append( + Row(1, "1", 1, 1), Row(1, "1", 1, 2), + Row(2, "2", 2, 1), Row(2, "2", 2, 2), + Row(3, "3", 3, 1), Row(3, "3", 3, 2) + ) + for (i <- 4 to 100) { + expected.append(Row(i, i.toString, null, null)) + } + + assertNotSpilled(sparkContext, "left outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData big + |LEFT OUTER JOIN + | testData2 small + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + assertNotSpilled(sparkContext, "right outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |RIGHT OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + } + } + + test("test SortMergeJoin (with spill)") { + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "0") { + + assertSpilled(sparkContext, "inner join") { + checkAnswer( + sql("SELECT * FROM testData JOIN testData2 ON key = a where key = 2"), + Row(2, "2", 2, 1) :: Row(2, "2", 2, 2) :: Nil + ) + } + + val expected = new ListBuffer[Row]() + expected.append( + Row(1, "1", 1, 1), Row(1, "1", 1, 2), + Row(2, "2", 2, 1), Row(2, "2", 2, 2), + Row(3, "3", 3, 1), Row(3, "3", 3, 2) + ) + for (i <- 4 to 100) { + expected.append(Row(i, i.toString, null, null)) + } + + assertSpilled(sparkContext, "left outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData big + |LEFT OUTER JOIN + | testData2 small + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + assertSpilled(sparkContext, "right outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |RIGHT OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + + // FULL OUTER JOIN still does not use [[ExternalAppendOnlyUnsafeRowArray]] + // so should not cause any spill + assertNotSpilled(sparkContext, "full outer join") { + checkAnswer( + sql( + """ + |SELECT + | big.key, big.value, small.a, small.b + |FROM + | testData2 small + |FULL OUTER JOIN + | testData big + |ON + | big.key = small.a + """.stripMargin), + expected + ) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala new file mode 100644 index 0000000000..00c5f2550c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskContext} +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.util.Benchmark +import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter + +object ExternalAppendOnlyUnsafeRowArrayBenchmark { + + def testAgainstRawArrayBuffer(numSpillThreshold: Int, numRows: Int, iterations: Int): Unit = { + val random = new java.util.Random() + val rows = (1 to numRows).map(_ => { + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, random.nextLong()) + row + }) + + val benchmark = new Benchmark(s"Array with $numRows rows", iterations * numRows) + + // Internally, `ExternalAppendOnlyUnsafeRowArray` will create an + // in-memory buffer of size `numSpillThreshold`. This will mimic that + val initialSize = + Math.min( + ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, + numSpillThreshold) + + benchmark.addCase("ArrayBuffer") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ArrayBuffer[UnsafeRow](initialSize) + + // Internally, `ExternalAppendOnlyUnsafeRowArray` will create a + // copy of the row. This will mimic that + rows.foreach(x => array += x.copy()) + + var i = 0 + val n = array.length + while (i < n) { + sum = sum + array(i).getLong(0) + i += 1 + } + array.clear() + } + } + + benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + rows.foreach(x => array.add(x)) + + val iterator = array.generateIterator() + while (iterator.hasNext) { + sum = sum + iterator.next().getLong(0) + } + array.clear() + } + } + + val conf = new SparkConf(false) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + + val sc = new SparkContext("local", "test", conf) + val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + benchmark.run() + sc.stop() + } + + def testAgainstRawUnsafeExternalSorter( + numSpillThreshold: Int, + numRows: Int, + iterations: Int): Unit = { + + val random = new java.util.Random() + val rows = (1 to numRows).map(_ => { + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, random.nextLong()) + row + }) + + val benchmark = new Benchmark(s"Spilling with $numRows rows", iterations * numRows) + + benchmark.addCase("UnsafeExternalSorter") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = UnsafeExternalSorter.create( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numSpillThreshold, + false) + + rows.foreach(x => + array.insertRecord( + x.getBaseObject, + x.getBaseOffset, + x.getSizeInBytes, + 0, + false)) + + val unsafeRow = new UnsafeRow(1) + val iter = array.getIterator + while (iter.hasNext) { + iter.loadNext() + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + sum = sum + unsafeRow.getLong(0) + } + array.cleanupResources() + } + } + + benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => + var sum = 0L + for (_ <- 0L until iterations) { + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold) + rows.foreach(x => array.add(x)) + + val iterator = array.generateIterator() + while (iterator.hasNext) { + sum = sum + iterator.next().getLong(0) + } + array.clear() + } + } + + val conf = new SparkConf(false) + // Make the Java serializer write a reset instruction (TC_RESET) after each object to test + // for a bug we had with bytes written past the last object in a batch (SPARK-2792) + conf.set("spark.serializer.objectStreamReset", "1") + conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") + + val sc = new SparkContext("local", "test", conf) + val taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + benchmark.run() + sc.stop() + } + + def main(args: Array[String]): Unit = { + + // ========================================================================================= // + // WITHOUT SPILL + // ========================================================================================= // + + val spillThreshold = 100 * 1000 + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 7821 / 7941 33.5 29.8 1.0X + ExternalAppendOnlyUnsafeRowArray 8798 / 8819 29.8 33.6 0.9X + */ + testAgainstRawArrayBuffer(spillThreshold, 1000, 1 << 18) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 30000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 19200 / 19206 25.6 39.1 1.0X + ExternalAppendOnlyUnsafeRowArray 19558 / 19562 25.1 39.8 1.0X + */ + testAgainstRawArrayBuffer(spillThreshold, 30 * 1000, 1 << 14) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Array with 100000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + ArrayBuffer 5949 / 6028 17.2 58.1 1.0X + ExternalAppendOnlyUnsafeRowArray 6078 / 6138 16.8 59.4 1.0X + */ + testAgainstRawArrayBuffer(spillThreshold, 100 * 1000, 1 << 10) + + // ========================================================================================= // + // WITH SPILL + // ========================================================================================= // + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Spilling with 1000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + UnsafeExternalSorter 9239 / 9470 28.4 35.2 1.0X + ExternalAppendOnlyUnsafeRowArray 8857 / 8909 29.6 33.8 1.0X + */ + testAgainstRawUnsafeExternalSorter(100 * 1000, 1000, 1 << 18) + + /* + Intel(R) Core(TM) i7-6920HQ CPU @ 2.90GHz + + Spilling with 10000 rows: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + UnsafeExternalSorter 4 / 5 39.3 25.5 1.0X + ExternalAppendOnlyUnsafeRowArray 5 / 6 29.8 33.5 0.8X + */ + testAgainstRawUnsafeExternalSorter( + UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt, 10 * 1000, 1 << 4) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala new file mode 100644 index 0000000000..53c4163994 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.ConcurrentModificationException + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark._ +import org.apache.spark.memory.MemoryTestingUtils +import org.apache.spark.sql.catalyst.expressions.UnsafeRow + +class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSparkContext { + private val random = new java.util.Random() + private var taskContext: TaskContext = _ + + override def afterAll(): Unit = TaskContext.unset() + + private def withExternalArray(spillThreshold: Int) + (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { + sc = new SparkContext("local", "test", new SparkConf(false)) + + taskContext = MemoryTestingUtils.fakeTaskContext(SparkEnv.get) + TaskContext.setTaskContext(taskContext) + + val array = new ExternalAppendOnlyUnsafeRowArray( + taskContext.taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + taskContext, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + spillThreshold) + try f(array) finally { + array.clear() + } + } + + private def insertRow(array: ExternalAppendOnlyUnsafeRowArray): Long = { + val valueInserted = random.nextLong() + + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](64), 16) + row.setLong(0, valueInserted) + array.add(row) + valueInserted + } + + private def checkIfValueExists(iterator: Iterator[UnsafeRow], expectedValue: Long): Unit = { + assert(iterator.hasNext) + val actualRow = iterator.next() + assert(actualRow.getLong(0) == expectedValue) + assert(actualRow.getSizeInBytes == 16) + } + + private def validateData( + array: ExternalAppendOnlyUnsafeRowArray, + expectedValues: ArrayBuffer[Long]): Iterator[UnsafeRow] = { + val iterator = array.generateIterator() + for (value <- expectedValues) { + checkIfValueExists(iterator, value) + } + + assert(!iterator.hasNext) + iterator + } + + private def populateRows( + array: ExternalAppendOnlyUnsafeRowArray, + numRowsToBePopulated: Int): ArrayBuffer[Long] = { + val populatedValues = new ArrayBuffer[Long] + populateRows(array, numRowsToBePopulated, populatedValues) + } + + private def populateRows( + array: ExternalAppendOnlyUnsafeRowArray, + numRowsToBePopulated: Int, + populatedValues: ArrayBuffer[Long]): ArrayBuffer[Long] = { + for (_ <- 0 until numRowsToBePopulated) { + populatedValues.append(insertRow(array)) + } + populatedValues + } + + private def getNumBytesSpilled: Long = { + TaskContext.get().taskMetrics().memoryBytesSpilled + } + + private def assertNoSpill(): Unit = { + assert(getNumBytesSpilled == 0) + } + + private def assertSpill(): Unit = { + assert(getNumBytesSpilled > 0) + } + + test("insert rows less than the spillThreshold") { + val spillThreshold = 100 + withExternalArray(spillThreshold) { array => + assert(array.isEmpty) + + val expectedValues = populateRows(array, 1) + assert(!array.isEmpty) + assert(array.length == 1) + + val iterator1 = validateData(array, expectedValues) + + // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) + // Verify that NO spill has happened + populateRows(array, spillThreshold - 1, expectedValues) + assert(array.length == spillThreshold) + assertNoSpill() + + val iterator2 = validateData(array, expectedValues) + + assert(!iterator1.hasNext) + assert(!iterator2.hasNext) + } + } + + test("insert rows more than the spillThreshold to force spill") { + val spillThreshold = 100 + withExternalArray(spillThreshold) { array => + val numValuesInserted = 20 * spillThreshold + + assert(array.isEmpty) + val expectedValues = populateRows(array, 1) + assert(array.length == 1) + + val iterator1 = validateData(array, expectedValues) + + // Populate more rows to trigger spill. Verify that spill has happened + populateRows(array, numValuesInserted - 1, expectedValues) + assert(array.length == numValuesInserted) + assertSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + } + + test("iterator on an empty array should be empty") { + withExternalArray(spillThreshold = 10) { array => + val iterator = array.generateIterator() + assert(array.isEmpty) + assert(array.length == 0) + assert(!iterator.hasNext) + } + } + + test("generate iterator with negative start index") { + withExternalArray(spillThreshold = 2) { array => + val exception = + intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) + + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array") + ) + } + } + + test("generate iterator with start index exceeding array's size (without spill)") { + val spillThreshold = 2 + withExternalArray(spillThreshold) { array => + populateRows(array, spillThreshold / 2) + + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) + } + } + + test("generate iterator with start index exceeding array's size (with spill)") { + val spillThreshold = 2 + withExternalArray(spillThreshold) { array => + populateRows(array, spillThreshold * 2) + + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) + + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) + } + } + + test("generate iterator with custom start index (without spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + val expectedValues = populateRows(array, spillThreshold) + val startIndex = spillThreshold / 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } + } + } + + test("generate iterator with custom start index (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + val expectedValues = populateRows(array, spillThreshold * 10) + val startIndex = spillThreshold * 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } + } + } + + test("test iterator invalidation (without spill)") { + withExternalArray(spillThreshold = 10) { array => + // insert 2 rows, iterate until the first row + populateRows(array, 2) + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("test iterator invalidation (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate enough rows so that spill has happens + populateRows(array, spillThreshold * 2) + assertSpill() + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("clear on an empty the array") { + withExternalArray(spillThreshold = 2) { array => + val iterator = array.generateIterator() + assert(!iterator.hasNext) + + // multiple clear'ing should not have an side-effect + array.clear() + array.clear() + array.clear() + assert(array.isEmpty) + assert(array.length == 0) + + // Clearing an empty array should also invalidate any old iterators + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + } + + test("clear array (without spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate rows ... but not enough to trigger spill + populateRows(array, spillThreshold / 2) + assertNoSpill() + + // Clear the array + array.clear() + assert(array.isEmpty) + + // Re-populate few rows so that there is no spill + // Verify the data. Verify that there was no spill + val expectedValues = populateRows(array, spillThreshold / 3) + validateData(array, expectedValues) + assertNoSpill() + + // Populate more rows .. enough to not trigger a spill. + // Verify the data. Verify that there was no spill + populateRows(array, spillThreshold / 3, expectedValues) + validateData(array, expectedValues) + assertNoSpill() + } + } + + test("clear array (with spill)") { + val spillThreshold = 10 + withExternalArray(spillThreshold) { array => + // Populate enough rows to trigger spill + populateRows(array, spillThreshold * 2) + val bytesSpilled = getNumBytesSpilled + assert(bytesSpilled > 0) + + // Clear the array + array.clear() + assert(array.isEmpty) + + // Re-populate the array ... but NOT upto the point that there is spill. + // Verify data. Verify that there was NO "extra" spill + val expectedValues = populateRows(array, spillThreshold / 2) + validateData(array, expectedValues) + assert(getNumBytesSpilled == bytesSpilled) + + // Populate more rows to trigger spill + // Verify the data. Verify that there was "extra" spill + populateRows(array, spillThreshold * 2, expectedValues) + validateData(array, expectedValues) + assert(getNumBytesSpilled > bytesSpilled) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala index afd47897ed..52e4f04722 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLWindowFunctionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.TestUtils.assertSpilled case class WindowData(month: Int, area: String, product: Int) @@ -412,4 +413,36 @@ class SQLWindowFunctionSuite extends QueryTest with SharedSQLContext { """.stripMargin), Row(1, 3, null) :: Row(2, null, 4) :: Nil) } + + test("test with low buffer spill threshold") { + val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y") + nums.createOrReplaceTempView("nums") + + val expected = + Row(1, 1, 1) :: + Row(0, 2, 3) :: + Row(1, 3, 6) :: + Row(0, 4, 10) :: + Row(1, 5, 15) :: + Row(0, 6, 21) :: + Row(1, 7, 28) :: + Row(0, 8, 36) :: + Row(1, 9, 45) :: + Row(0, 10, 55) :: Nil + + val actual = sql( + """ + |SELECT y, x, sum(x) OVER w1 AS running_sum + |FROM nums + |WINDOW w1 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDiNG AND CURRENT RoW) + """.stripMargin) + + withSQLConf("spark.sql.windowExec.buffer.spill.threshold" -> "1") { + assertSpilled(sparkContext, "test with low buffer spill threshold") { + checkAnswer(actual, expected) + } + } + + spark.catalog.dropTempView("nums") + } } |