author    Reynold Xin <rxin@databricks.com>  2015-08-02 12:32:14 -0700
committer Josh Rosen <joshrosen@databricks.com>  2015-08-02 12:32:14 -0700
commit    2e981b7bfa9dec93fdcf25f3e7220cd6aaba744f (patch)
tree      f7458ae297d36bba1acf21fd08169defef6c2ef8
parent    66924ffa6bdb8e0df1b90b789cb7ad443377e729 (diff)
[SPARK-9531] [SQL] UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter
This pull request adds a destructAndCreateExternalSorter method to
UnsafeFixedWidthAggregationMap. The new method does the following:

1. Creates a new external sorter, UnsafeKVExternalSorter
2. Adds all the data into an in-memory sorter and sorts it
3. Spills the sorted in-memory data to disk

This method can be used to fall back to sort-based aggregation when under
memory pressure. The pull request also includes memory-accounting fixes from
JoshRosen.

TODOs (that can be done in follow-up PRs):
- [x] Address Josh's feedback from #7849
- [x] More documentation and test cases
- [x] Make sure we are doing memory accounting correctly, with test cases
      (e.g. did we release the memory in BytesToBytesMap twice?)
- [ ] Look harder at possible memory leaks and exception handling
- [ ] Randomized tester for the KV sorter as well as the aggregation map

Author: Reynold Xin <rxin@databricks.com>
Author: Josh Rosen <joshrosen@databricks.com>

Closes #7860 from rxin/kvsorter and squashes the following commits:

986a58c [Reynold Xin] Bug fix.
599317c [Reynold Xin] Style fix and slightly more compact code.
fe7bd4e [Reynold Xin] Bug fixes.
fd71bef [Reynold Xin] Merge remote-tracking branch 'josh/large-records-in-sql-sorter' into kvsorter-with-josh-fix
3efae38 [Reynold Xin] More fixes and documentation.
45f1b09 [Josh Rosen] Ensure that spill files are cleaned up
f6a9bd3 [Reynold Xin] Josh feedback.
9be8139 [Reynold Xin] Remove testSpillFrequency.
7cbe759 [Reynold Xin] [SPARK-9531][SQL] UnsafeFixedWidthAggregationMap.destructAndCreateExternalSorter.
ae4a8af [Josh Rosen] Detect leaked unsafe memory in UnsafeExternalSorterSuite.
52f9b06 [Josh Rosen] Detect ShuffleMemoryManager leaks in UnsafeExternalSorter.
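To show how these pieces are meant to fit together, here is a minimal sketch of the fallback path, assuming a populated map and pre-projected UnsafeRow inputs; the helper name and arguments are illustrative, not part of the patch, and it must run inside a Spark task so that TaskContext.get() is non-null.

import java.io.IOException;

import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap;
import org.apache.spark.sql.execution.UnsafeKVExternalSorter;
import org.apache.spark.unsafe.KVIterator;

class SortFallbackSketch {
  static void fallBackToSort(
      UnsafeFixedWidthAggregationMap map,
      UnsafeRow extraKey,
      UnsafeRow extraValue) throws IOException {
    // Destroys the map: its records are sorted in place and spilled to disk.
    UnsafeKVExternalSorter sorter = map.destructAndCreateExternalSorter();
    // The sorter still accepts new records after the map is gone.
    sorter.insertKV(extraKey, extraValue);
    KVIterator<UnsafeRow, UnsafeRow> iter = sorter.sortedIterator();
    while (iter.next()) {
      // Records come back in key-sorted order; equal keys are now adjacent,
      // so a sort-based aggregate can merge them in a single pass.
      UnsafeRow key = iter.getKey();
      UnsafeRow value = iter.getValue();
    }
  }
}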
-rw-r--r-- core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 32
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java | 197
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java | 4
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java | 3
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java | 4
-rw-r--r-- core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java | 7
-rw-r--r-- core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java | 65
-rw-r--r-- sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 3
-rw-r--r-- sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java | 9
-rw-r--r-- sql/core/pom.xml | 5
-rw-r--r-- sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java | 103
-rw-r--r-- sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java | 236
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala | 33
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala | 4
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala | 51
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala | 124
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala | 158
17 files changed, 823 insertions(+), 215 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index cf222b7272..01a66084e9 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -39,14 +39,22 @@ import org.apache.spark.unsafe.memory.*;
/**
* An append-only hash map where keys and values are contiguous regions of bytes.
- * <p>
+ *
* This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
* which is guaranteed to exhaust the space.
- * <p>
+ *
* The map can support up to 2^29 keys. If the key cardinality is higher than this, you should
* probably be using sorting instead of hashing for better cache locality.
- * <p>
- * This class is not thread safe.
+ *
+ * The key and values under the hood are stored together, in the following format:
+ * Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4
+ * Bytes 4 to 8: len(k)
+ * Bytes 8 to 8 + len(k): key data
+ * Bytes 8 + len(k) to 8 + len(k) + len(v): value data
+ *
+ * This means that the first four bytes store the entire record (key + value) length. This format
+ * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
+ * so we can pass records from this map directly into the sorter to sort records in place.
*/
public final class BytesToBytesMap {
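To make the record format documented in the javadoc above concrete, here is a small self-contained sketch that emulates the same layout with a heap ByteBuffer; that buffer is an assumption for illustration only, since the real code reads raw pages via PlatformDependent/Unsafe accessors.

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class RecordLayoutSketch {
  public static void main(String[] args) {
    byte[] key = {1, 2, 3};       // len(k) = 3
    byte[] value = {9, 8, 7, 6};  // len(v) = 4

    ByteBuffer page = ByteBuffer.allocate(64).order(ByteOrder.nativeOrder());
    page.putInt(key.length + value.length + 4); // bytes 0 to 4: len(k) + len(v) + 4
    page.putInt(key.length);                    // bytes 4 to 8: len(k)
    page.put(key);                              // bytes 8 to 8 + len(k): key data
    page.put(value);                            // next len(v) bytes: value data

    page.flip();
    // The first word alone tells a sorter how many bytes the whole record
    // occupies, which is why the map and UnsafeExternalSorter can share
    // pages without re-parsing key and value lengths.
    int totalLength = page.getInt();               // 11
    int keyLength = page.getInt();                 // 3
    int valueLength = totalLength - keyLength - 4; // 4
    System.out.println(totalLength + " " + keyLength + " " + valueLength);
  }
}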
@@ -253,7 +261,7 @@ public final class BytesToBytesMap {
totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage);
}
loc.with(currentPage, offsetInPage);
- offsetInPage += 8 + totalLength;
+ offsetInPage += 4 + totalLength;
currentRecordNumber++;
return loc;
}
@@ -366,7 +374,7 @@ public final class BytesToBytesMap {
position += 4;
keyLength = PlatformDependent.UNSAFE.getInt(page, position);
position += 4;
- valueLength = totalLength - keyLength;
+ valueLength = totalLength - keyLength - 4;
keyMemoryLocation.setObjAndOffset(page, position);
@@ -565,7 +573,7 @@ public final class BytesToBytesMap {
insertCursor += valueLengthBytes; // word used to store the value size
PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset,
- keyLengthBytes + valueLengthBytes);
+ keyLengthBytes + valueLengthBytes + 4);
PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
// Copy the key
PlatformDependent.copyMemory(
@@ -620,7 +628,7 @@ public final class BytesToBytesMap {
* Free all allocated memory associated with this map, including the storage for keys and values
* as well as the hash map array itself.
*
- * This method is idempotent.
+ * This method is idempotent and can be called multiple times.
*/
public void free() {
longArray = null;
@@ -639,6 +647,14 @@ public final class BytesToBytesMap {
return taskMemoryManager;
}
+ public ShuffleMemoryManager getShuffleMemoryManager() {
+ return shuffleMemoryManager;
+ }
+
+ public long getPageSizeBytes() {
+ return pageSizeBytes;
+ }
+
/** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
public long getTotalMemoryConsumption() {
long totalDataPagesSize = 0L;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index c05f2c332e..b984301cbb 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -17,9 +17,12 @@
package org.apache.spark.util.collection.unsafe.sort;
+import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
+import javax.annotation.Nullable;
+
import scala.runtime.AbstractFunction0;
import scala.runtime.BoxedUnit;
@@ -27,7 +30,6 @@ import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.ShuffleMemoryManager;
@@ -48,7 +50,7 @@ public final class UnsafeExternalSorter {
private final PrefixComparator prefixComparator;
private final RecordComparator recordComparator;
private final int initialSize;
- private final TaskMemoryManager memoryManager;
+ private final TaskMemoryManager taskMemoryManager;
private final ShuffleMemoryManager shuffleMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
@@ -63,26 +65,57 @@ public final class UnsafeExternalSorter {
* this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
* itself).
*/
- private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+ private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>();
+
+ private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
// These variables are reset after spilling:
- private UnsafeInMemorySorter sorter;
+ private UnsafeInMemorySorter inMemSorter;
+ // Whether the in-mem sorter is created internally, or passed in from outside.
+ // If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
+ private boolean isInMemSorterExternal = false;
private MemoryBlock currentPage = null;
private long currentPagePosition = -1;
private long freeSpaceInCurrentPage = 0;
- private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+ public static UnsafeExternalSorter createWithExistingInMemorySorter(
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ int initialSize,
+ long pageSizeBytes,
+ UnsafeInMemorySorter inMemorySorter) throws IOException {
+ return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+ taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
+ }
- public UnsafeExternalSorter(
- TaskMemoryManager memoryManager,
+ public static UnsafeExternalSorter create(
+ TaskMemoryManager taskMemoryManager,
ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
- SparkConf conf) throws IOException {
- this.memoryManager = memoryManager;
+ long pageSizeBytes) throws IOException {
+ return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+ taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
+ }
+
+ private UnsafeExternalSorter(
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ BlockManager blockManager,
+ TaskContext taskContext,
+ RecordComparator recordComparator,
+ PrefixComparator prefixComparator,
+ int initialSize,
+ long pageSizeBytes,
+ @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
+ this.taskMemoryManager = taskMemoryManager;
this.shuffleMemoryManager = shuffleMemoryManager;
this.blockManager = blockManager;
this.taskContext = taskContext;
@@ -90,9 +123,18 @@ public final class UnsafeExternalSorter {
this.prefixComparator = prefixComparator;
this.initialSize = initialSize;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
- this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
- this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
- initializeForWriting();
+ // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+ this.fileBufferSizeBytes = 32 * 1024;
+ // this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m");
+ this.pageSizeBytes = pageSizeBytes;
+ this.writeMetrics = new ShuffleWriteMetrics();
+
+ if (existingInMemorySorter == null) {
+ initializeForWriting();
+ } else {
+ this.isInMemSorterExternal = true;
+ this.inMemSorter = existingInMemorySorter;
+ }
// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
// the end of the task. This is necessary to avoid memory leaks in when the downstream operator
@@ -100,6 +142,7 @@ public final class UnsafeExternalSorter {
taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() {
@Override
public BoxedUnit apply() {
+ deleteSpillFiles();
freeMemory();
return null;
}
@@ -114,22 +157,31 @@ public final class UnsafeExternalSorter {
*/
private void initializeForWriting() throws IOException {
this.writeMetrics = new ShuffleWriteMetrics();
- // TODO: move this sizing calculation logic into a static method of sorter:
- final long memoryRequested = initialSize * 8L * 2;
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
- if (memoryAcquired != memoryRequested) {
+ final long pointerArrayMemory =
+ UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize);
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pointerArrayMemory);
+ if (memoryAcquired != pointerArrayMemory) {
shuffleMemoryManager.release(memoryAcquired);
- throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+ throw new IOException("Could not acquire " + pointerArrayMemory + " bytes of memory");
}
- this.sorter =
- new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize);
+ this.inMemSorter =
+ new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
+ this.isInMemSorterExternal = false;
}
/**
- * Sort and spill the current records in response to memory pressure.
+ * Marks the current page as no-more-space-available, and as a result, either allocate a
+ * new page or spill when we see the next record.
*/
@VisibleForTesting
+ public void closeCurrentPage() {
+ freeSpaceInCurrentPage = 0;
+ }
+
+ /**
+ * Sort and spill the current records in response to memory pressure.
+ */
public void spill() throws IOException {
logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
@@ -139,9 +191,9 @@ public final class UnsafeExternalSorter {
final UnsafeSorterSpillWriter spillWriter =
new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
- sorter.numRecords());
+ inMemSorter.numRecords());
spillWriters.add(spillWriter);
- final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator();
+ final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator();
while (sortedRecords.hasNext()) {
sortedRecords.loadNext();
final Object baseObject = sortedRecords.getBaseObject();
@@ -150,20 +202,24 @@ public final class UnsafeExternalSorter {
spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
}
spillWriter.close();
- final long sorterMemoryUsage = sorter.getMemoryUsage();
- sorter = null;
- shuffleMemoryManager.release(sorterMemoryUsage);
final long spillSize = freeMemory();
+ // Note that this is more-or-less going to be a multiple of the page size, so wasted space in
+ // pages will currently be counted as memory spilled even though that space isn't actually
+ // written to disk. This also counts the space needed to store the sorter's pointer array.
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
initializeForWriting();
}
+ /**
+ * Return the total memory usage of this sorter, including the data pages and the sorter's pointer
+ * array.
+ */
private long getMemoryUsage() {
long totalPageSize = 0;
for (MemoryBlock page : allocatedPages) {
totalPageSize += page.size();
}
- return sorter.getMemoryUsage() + totalPageSize;
+ return inMemSorter.getMemoryUsage() + totalPageSize;
}
@VisibleForTesting
@@ -171,13 +227,26 @@ public final class UnsafeExternalSorter {
return allocatedPages.size();
}
+ /**
+ * Free this sorter's in-memory data structures, including its data pages and pointer array.
+ *
+ * @return the number of bytes freed.
+ */
public long freeMemory() {
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
- memoryManager.freePage(block);
+ taskMemoryManager.freePage(block);
shuffleMemoryManager.release(block.size());
memoryFreed += block.size();
}
+ if (inMemSorter != null) {
+ if (!isInMemSorterExternal) {
+ long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+ memoryFreed += sorterMemoryUsage;
+ shuffleMemoryManager.release(sorterMemoryUsage);
+ }
+ inMemSorter = null;
+ }
allocatedPages.clear();
currentPage = null;
currentPagePosition = -1;
@@ -186,6 +255,20 @@ public final class UnsafeExternalSorter {
}
/**
+ * Deletes any spill files created by this sorter.
+ */
+ public void deleteSpillFiles() {
+ for (UnsafeSorterSpillWriter spill : spillWriters) {
+ File file = spill.getFile();
+ if (file != null && file.exists()) {
+ if (!file.delete()) {
+ logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+ };
+ }
+ }
+ }
+
+ /**
* Checks whether there is enough space to insert a new record into the sorter.
*
* @param requiredSpace the required space in the data page, in bytes, including space for storing
@@ -195,7 +278,7 @@ public final class UnsafeExternalSorter {
*/
private boolean haveSpaceForRecord(int requiredSpace) {
assert (requiredSpace > 0);
- return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+ return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
}
/**
@@ -210,16 +293,16 @@ public final class UnsafeExternalSorter {
// TODO: merge these steps to first calculate total memory requirements for this insert,
// then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
// data page.
- if (!sorter.hasSpaceForAnotherRecord()) {
+ if (!inMemSorter.hasSpaceForAnotherRecord()) {
logger.debug("Attempting to expand sort pointer array");
- final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+ final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
if (memoryAcquired < memoryToGrowPointerArray) {
shuffleMemoryManager.release(memoryAcquired);
spill();
} else {
- sorter.expandPointerArray();
+ inMemSorter.expandPointerArray();
shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
}
}
@@ -236,7 +319,9 @@ public final class UnsafeExternalSorter {
} else {
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryAcquired < pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquired);
+ if (memoryAcquired > 0) {
+ shuffleMemoryManager.release(memoryAcquired);
+ }
spill();
final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
if (memoryAcquiredAfterSpilling != pageSizeBytes) {
@@ -244,7 +329,7 @@ public final class UnsafeExternalSorter {
throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
}
}
- currentPage = memoryManager.allocatePage(pageSizeBytes);
+ currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
currentPagePosition = currentPage.getBaseOffset();
freeSpaceInCurrentPage = pageSizeBytes;
allocatedPages.add(currentPage);
@@ -267,7 +352,7 @@ public final class UnsafeExternalSorter {
}
final long recordAddress =
- memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+ taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
final Object dataPageBaseObject = currentPage.getBaseObject();
PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
currentPagePosition += 4;
@@ -279,26 +364,48 @@ public final class UnsafeExternalSorter {
lengthInBytes);
currentPagePosition += lengthInBytes;
freeSpaceInCurrentPage -= totalSpaceRequired;
- sorter.insertRecord(recordAddress, prefix);
+ inMemSorter.insertRecord(recordAddress, prefix);
}
/**
- * Write a record to the sorter. The record is broken down into two different parts, and
+ * Write a key-value record to the sorter. The key and value will be put together in-memory,
+ * using the following format:
*
+ * record length (4 bytes), key length (4 bytes), key data, value data
+ *
+ * record length = key length + value length + 4
*/
- public void insertRecord(
- Object recordBaseObject1,
- long recordBaseOffset1,
- int lengthInBytes1,
- Object recordBaseObject2,
- long recordBaseOffset2,
- int lengthInBytes2,
- long prefix) throws IOException {
+ public void insertKVRecord(
+ Object keyBaseObj, long keyOffset, int keyLen,
+ Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException {
+ final int totalSpaceRequired = keyLen + valueLen + 4 + 4;
+ if (!haveSpaceForRecord(totalSpaceRequired)) {
+ allocateSpaceForRecord(totalSpaceRequired);
+ }
+
+ final long recordAddress =
+ taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+ final Object dataPageBaseObject = currentPage.getBaseObject();
+ PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen + valueLen + 4);
+ currentPagePosition += 4;
+ PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen);
+ currentPagePosition += 4;
+
+ PlatformDependent.copyMemory(
+ keyBaseObj, keyOffset, dataPageBaseObject, currentPagePosition, keyLen);
+ currentPagePosition += keyLen;
+
+ PlatformDependent.copyMemory(
+ valueBaseObj, valueOffset, dataPageBaseObject, currentPagePosition, valueLen);
+ currentPagePosition += valueLen;
+
+ freeSpaceInCurrentPage -= totalSpaceRequired;
+ inMemSorter.insertRecord(recordAddress, prefix);
}
public UnsafeSorterIterator getSortedIterator() throws IOException {
- final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator();
+ final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator();
int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
if (spillWriters.isEmpty()) {
return inMemoryIterator;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index fc34ad9cff..3131465391 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -100,6 +100,10 @@ public final class UnsafeInMemorySorter {
return pointerArray.length * 8L;
}
+ static long getMemoryRequirementsForPointerArray(long numEntries) {
+ return numEntries * 2L * 8L;
+ }
+
public boolean hasSpaceForAnotherRecord() {
return pointerArrayInsertPosition + 2 < pointerArray.length;
}
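As a quick sanity check on the getMemoryRequirementsForPointerArray helper added above (a standalone sketch, not part of the patch): each record consumes two 8-byte words in the pointer array, one for the encoded record address and one for its sort prefix, which is also where the "4096 * 16" figure in the new tests comes from.

public class PointerArraySizeSketch {
  public static void main(String[] args) {
    long initialSize = 4096; // the initialSize the sorters in this patch default to
    // Two 8-byte words per record: encoded page/offset address + key prefix.
    long pointerArrayBytes = initialSize * 2L * 8L;
    System.out.println(pointerArrayBytes); // 65536 bytes, i.e. 4096 * 16
  }
}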
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 29e9e0f30f..ca1ccedc93 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -31,6 +31,7 @@ import org.apache.spark.unsafe.PlatformDependent;
*/
final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+ private final File file;
private InputStream in;
private DataInputStream din;
@@ -48,6 +49,7 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
File file,
BlockId blockId) throws IOException {
assert (file.length() > 0);
+ this.file = file;
final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
this.in = blockManager.wrapForCompression(blockId, bs);
this.din = new DataInputStream(this.in);
@@ -71,6 +73,7 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
numRecordsRemaining--;
if (numRecordsRemaining == 0) {
in.close();
+ file.delete();
in = null;
din = null;
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 71eed29563..44cf6c756d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -140,6 +140,10 @@ final class UnsafeSorterSpillWriter {
writeBuffer = null;
}
+ public File getFile() {
+ return file;
+ }
+
public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
return new UnsafeSorterSpillReader(blockManager, file, blockId);
}
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 70f8ca4d21..dbb7c662d7 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -67,12 +67,11 @@ public abstract class AbstractBytesToBytesMapSuite {
@After
public void tearDown() {
- if (taskMemoryManager != null) {
+ Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
+ if (shuffleMemoryManager != null) {
long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
- Assert.assertEquals(0, taskMemoryManager.cleanUpAllAllocatedMemory());
- Assert.assertEquals(0, leakedShuffleMemory);
shuffleMemoryManager = null;
- taskMemoryManager = null;
+ Assert.assertEquals(0L, leakedShuffleMemory);
}
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 0e391b7512..52fa8bcd57 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -20,12 +20,14 @@ package org.apache.spark.util.collection.unsafe.sort;
import java.io.File;
import java.io.InputStream;
import java.io.OutputStream;
+import java.util.LinkedList;
import java.util.UUID;
import scala.Tuple2;
import scala.Tuple2$;
import scala.runtime.AbstractFunction1;
+import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mock;
@@ -33,7 +35,6 @@ import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.AdditionalAnswers.returnsSecondArg;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;
@@ -53,7 +54,8 @@ import org.apache.spark.util.Utils;
public class UnsafeExternalSorterSuite {
- final TaskMemoryManager memoryManager =
+ final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+ final TaskMemoryManager taskMemoryManager =
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
final PrefixComparator prefixComparator = new PrefixComparator() {
@@ -75,13 +77,15 @@ public class UnsafeExternalSorterSuite {
}
};
- @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+ ShuffleMemoryManager shuffleMemoryManager;
@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
File tempDir;
+ private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m");
+
private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
@Override
public OutputStream apply(OutputStream stream) {
@@ -93,15 +97,17 @@ public class UnsafeExternalSorterSuite {
public void setUp() {
MockitoAnnotations.initMocks(this);
tempDir = new File(Utils.createTempDir$default$1());
+ shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE);
+ spillFilesCreated.clear();
taskContext = mock(TaskContext.class);
when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
- when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
@Override
public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
File file = File.createTempFile("spillFile", ".spill", tempDir);
+ spillFilesCreated.add(file);
return Tuple2$.MODULE$.apply(blockId, file);
}
});
@@ -130,6 +136,24 @@ public class UnsafeExternalSorterSuite {
.then(returnsSecondArg());
}
+ @After
+ public void tearDown() {
+ long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+ if (shuffleMemoryManager != null) {
+ long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
+ shuffleMemoryManager = null;
+ assertEquals(0L, leakedShuffleMemory);
+ }
+ assertEquals(0, leakedUnsafeMemory);
+ }
+
+ private void assertSpillFilesWereCleanedUp() {
+ for (File spillFile : spillFilesCreated) {
+ assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+ spillFile.exists());
+ }
+ }
+
private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
final int[] arr = new int[] { value };
sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
@@ -138,15 +162,15 @@ public class UnsafeExternalSorterSuite {
@Test
public void testSortingOnlyByPrefix() throws Exception {
- final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
- memoryManager,
+ final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+ taskMemoryManager,
shuffleMemoryManager,
blockManager,
taskContext,
recordComparator,
prefixComparator,
- 1024,
- new SparkConf());
+ /* initialSize */ 1024,
+ pageSizeBytes);
insertNumber(sorter, 5);
insertNumber(sorter, 1);
@@ -165,22 +189,22 @@ public class UnsafeExternalSorterSuite {
// TODO: read rest of value.
}
- // TODO: test for cleanup:
- // assert(tempDir.isEmpty)
+ sorter.freeMemory();
+ assertSpillFilesWereCleanedUp();
}
@Test
public void testSortingEmptyArrays() throws Exception {
- final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
- memoryManager,
+ final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+ taskMemoryManager,
shuffleMemoryManager,
blockManager,
taskContext,
recordComparator,
prefixComparator,
- 1024,
- new SparkConf());
+ /* initialSize */ 1024,
+ pageSizeBytes);
sorter.insertRecord(null, 0, 0, 0);
sorter.insertRecord(null, 0, 0, 0);
@@ -197,25 +221,30 @@ public class UnsafeExternalSorterSuite {
assertEquals(0, iter.getKeyPrefix());
assertEquals(0, iter.getRecordLength());
}
+
+ sorter.freeMemory();
+ assertSpillFilesWereCleanedUp();
}
@Test
public void testFillingPage() throws Exception {
- final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
- memoryManager,
+
+ final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+ taskMemoryManager,
shuffleMemoryManager,
blockManager,
taskContext,
recordComparator,
prefixComparator,
- 1024,
- new SparkConf());
+ /* initialSize */ 1024,
+ pageSizeBytes);
byte[] record = new byte[16];
while (sorter.getNumberOfAllocatedPages() < 2) {
sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0);
}
sorter.freeMemory();
+ assertSpillFilesWereCleanedUp();
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 1b475b2492..b4fc0b7b70 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -507,7 +507,8 @@ public final class UnsafeRow extends MutableRow {
public String toString() {
StringBuilder build = new StringBuilder("[");
for (int i = 0; i < sizeInBytes; i += 8) {
- build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i));
+ build.append(java.lang.Long.toHexString(
+ PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)));
build.append(',');
}
build.append(']');
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 68c49feae9..5e4c6232c9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -59,20 +59,21 @@ final class UnsafeExternalRowSorter {
StructType schema,
Ordering<InternalRow> ordering,
PrefixComparator prefixComparator,
- PrefixComputer prefixComputer) throws IOException {
+ PrefixComputer prefixComputer,
+ long pageSizeBytes) throws IOException {
this.schema = schema;
this.prefixComputer = prefixComputer;
final SparkEnv sparkEnv = SparkEnv.get();
final TaskContext taskContext = TaskContext.get();
- sorter = new UnsafeExternalSorter(
+ sorter = UnsafeExternalSorter.create(
taskContext.taskMemoryManager(),
sparkEnv.shuffleMemoryManager(),
sparkEnv.blockManager(),
taskContext,
new RowComparator(ordering, schema.length()),
prefixComparator,
- 4096,
- sparkEnv.conf()
+ /* initialSize */ 4096,
+ pageSizeBytes
);
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index be0966641b..349007789f 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -106,6 +106,11 @@
<artifactId>parquet-avro</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-core</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index a0a8dd5154..9e2c9334a7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -19,24 +19,18 @@ package org.apache.spark.sql.execution;
import java.io.IOException;
+import org.apache.spark.SparkEnv;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
-import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
-import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
-import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
-import org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter;
-import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
/**
* Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -215,7 +209,7 @@ public final class UnsafeFixedWidthAggregationMap {
}
/**
- * Free the unsafe memory associated with this map.
+ * Free the memory associated with this map. This is idempotent and can be called multiple times.
*/
public void free() {
map.free();
@@ -233,92 +227,17 @@ public final class UnsafeFixedWidthAggregationMap {
}
/**
- * Sorts the key, value data in this map in place, and return them as an iterator.
+ * Sorts the map's records in place, spill them to disk, and returns an [[UnsafeKVExternalSorter]]
+ * that can be used to insert more records to do external sorting.
*
* The only memory that is allocated is the address/prefix array, 16 bytes per record.
+ *
+ * Note that this destroys the map, and as a result, the map cannot be used anymore after this.
*/
- public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() {
- int numElements = map.numElements();
- final int numKeyFields = groupingKeySchema.size();
- TaskMemoryManager memoryManager = map.getTaskMemoryManager();
-
- UnsafeExternalRowSorter.PrefixComputer prefixComp =
- SortPrefixUtils.createPrefixGenerator(groupingKeySchema);
- PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(groupingKeySchema);
-
- final BaseOrdering ordering = GenerateOrdering.create(groupingKeySchema);
- RecordComparator recordComparator = new RecordComparator() {
- private final UnsafeRow row1 = new UnsafeRow();
- private final UnsafeRow row2 = new UnsafeRow();
-
- @Override
- public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
- row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
- row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
- return ordering.compare(row1, row2);
- }
- };
-
- // Insert the records into the in-memory sorter.
- final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
- memoryManager, recordComparator, prefixComparator, numElements);
-
- BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
- UnsafeRow row = new UnsafeRow();
- while (iter.hasNext()) {
- final BytesToBytesMap.Location loc = iter.next();
- final Object baseObject = loc.getKeyAddress().getBaseObject();
- final long baseOffset = loc.getKeyAddress().getBaseOffset();
-
- // Get encoded memory address
- MemoryBlock page = loc.getMemoryPage();
- long address = memoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
-
- // Compute prefix
- row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
- final long prefix = prefixComp.computePrefix(row);
-
- sorter.insertRecord(address, prefix);
- }
-
- // Return the sorted result as an iterator.
- return new KVIterator<UnsafeRow, UnsafeRow>() {
-
- private UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
- private final UnsafeRow key = new UnsafeRow();
- private final UnsafeRow value = new UnsafeRow();
- private int numValueFields = aggregationBufferSchema.size();
-
- @Override
- public boolean next() throws IOException {
- if (sortedIterator.hasNext()) {
- sortedIterator.loadNext();
- Object baseObj = sortedIterator.getBaseObject();
- long recordOffset = sortedIterator.getBaseOffset();
- int recordLen = sortedIterator.getRecordLength();
- int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
- key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
- value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, recordLen - keyLen);
- return true;
- } else {
- return false;
- }
- }
-
- @Override
- public UnsafeRow getKey() {
- return key;
- }
-
- @Override
- public UnsafeRow getValue() {
- return value;
- }
-
- @Override
- public void close() {
- // Do nothing
- }
- };
+ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
+ UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter(
+ groupingKeySchema, aggregationBufferSchema,
+ SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map);
+ return sorter;
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
new file mode 100644
index 0000000000..f6b0176863
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+
+import javax.annotation.Nullable;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.KVIterator;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.sort.*;
+
+/**
+ * A class for performing external sorting on key-value records. Both key and value are UnsafeRows.
+ *
+ * Note that this class allows optionally passing in a {@link BytesToBytesMap} directly in order
+ * to perform in-place sorting of records in the map.
+ */
+public final class UnsafeKVExternalSorter {
+
+ private final StructType keySchema;
+ private final StructType valueSchema;
+ private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
+ private final UnsafeExternalSorter sorter;
+
+ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
+ BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes)
+ throws IOException {
+ this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null);
+ }
+
+ public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
+ BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes,
+ @Nullable BytesToBytesMap map) throws IOException {
+ this.keySchema = keySchema;
+ this.valueSchema = valueSchema;
+ final TaskContext taskContext = TaskContext.get();
+
+ prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema);
+ PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
+ BaseOrdering ordering = GenerateOrdering.create(keySchema);
+ KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
+
+ TaskMemoryManager taskMemoryManager = taskContext.taskMemoryManager();
+
+ if (map == null) {
+ sorter = UnsafeExternalSorter.create(
+ taskMemoryManager,
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ recordComparator,
+ prefixComparator,
+ /* initialSize */ 4096,
+ pageSizeBytes);
+ } else {
+ // Insert the records into the in-memory sorter.
+ final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
+ taskMemoryManager, recordComparator, prefixComparator, map.numElements());
+
+ final int numKeyFields = keySchema.size();
+ BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
+ UnsafeRow row = new UnsafeRow();
+ while (iter.hasNext()) {
+ final BytesToBytesMap.Location loc = iter.next();
+ final Object baseObject = loc.getKeyAddress().getBaseObject();
+ final long baseOffset = loc.getKeyAddress().getBaseOffset();
+
+ // Get encoded memory address
+ // baseObject + baseOffset point to the beginning of the key data in the map, but that
+ // the KV-pair's length data is stored in the word immediately before that address
+ MemoryBlock page = loc.getMemoryPage();
+ long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8);
+
+ // Compute prefix
+ row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength());
+ final long prefix = prefixComputer.computePrefix(row);
+
+ inMemSorter.insertRecord(address, prefix);
+ }
+
+ sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
+ taskContext.taskMemoryManager(),
+ shuffleMemoryManager,
+ blockManager,
+ taskContext,
+ new KVComparator(ordering, keySchema.length()),
+ prefixComparator,
+ /* initialSize */ 4096,
+ pageSizeBytes,
+ inMemSorter);
+
+ sorter.spill();
+ map.free();
+ }
+ }
+
+ /**
+ * Inserts a key-value record into the sorter. If the sorter no longer has enough memory to hold
+ * the record, the sorter sorts the existing records in-memory, writes them out as partially
+ * sorted runs, and then reallocates memory to hold the new record.
+ */
+ public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException {
+ final long prefix = prefixComputer.computePrefix(key);
+ sorter.insertKVRecord(
+ key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
+ value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
+ }
+
+ public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() throws IOException {
+ try {
+ final UnsafeSorterIterator underlying = sorter.getSortedIterator();
+ if (!underlying.hasNext()) {
+ // Since we won't ever call next() on an empty iterator, we need to clean up resources
+ // here in order to prevent memory leaks.
+ cleanupResources();
+ }
+
+ return new KVIterator<UnsafeRow, UnsafeRow>() {
+ private UnsafeRow key = new UnsafeRow();
+ private UnsafeRow value = new UnsafeRow();
+ private int numKeyFields = keySchema.size();
+ private int numValueFields = valueSchema.size();
+
+ @Override
+ public boolean next() throws IOException {
+ try {
+ if (underlying.hasNext()) {
+ underlying.loadNext();
+
+ Object baseObj = underlying.getBaseObject();
+ long recordOffset = underlying.getBaseOffset();
+ int recordLen = underlying.getRecordLength();
+
+ // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself)
+ int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset);
+ int valueLen = recordLen - keyLen - 4;
+
+ key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen);
+ value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen);
+
+ return true;
+ } else {
+ key = null;
+ value = null;
+ cleanupResources();
+ return false;
+ }
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+ @Override
+ public UnsafeRow getKey() {
+ return key;
+ }
+
+ @Override
+ public UnsafeRow getValue() {
+ return value;
+ }
+
+ @Override
+ public void close() {
+ cleanupResources();
+ }
+ };
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+ /**
+ * Marks the current page as no-more-space-available, and as a result, either allocate a
+ * new page or spill when we see the next record.
+ */
+ @VisibleForTesting
+ void closeCurrentPage() {
+ sorter.closeCurrentPage();
+ }
+
+ private void cleanupResources() {
+ sorter.freeMemory();
+ }
+
+ private static final class KVComparator extends RecordComparator {
+ private final BaseOrdering ordering;
+ private final UnsafeRow row1 = new UnsafeRow();
+ private final UnsafeRow row2 = new UnsafeRow();
+ private final int numKeyFields;
+
+ public KVComparator(BaseOrdering ordering, int numKeyFields) {
+ this.numKeyFields = numKeyFields;
+ this.ordering = ordering;
+ }
+
+ @Override
+ public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+ // Note that since ordering doesn't need the total length of the record, we just pass -1
+ // into the row.
+ row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1);
+ row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1);
+ return ordering.compare(row1, row2);
+ }
+ }
+}
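As a usage note for the constructors above, here is a hedged sketch of creating a standalone sorter without a backing map. The schemas are made up for illustration, and like the map-backed path it must run inside a task so that TaskContext.get() returns a valid context.

import java.io.IOException;

import org.apache.spark.SparkEnv;
import org.apache.spark.sql.execution.UnsafeKVExternalSorter;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

class StandaloneKVSorterSketch {
  static UnsafeKVExternalSorter create() throws IOException {
    StructType keySchema = new StructType().add("k", DataTypes.StringType);
    StructType valueSchema = new StructType().add("v", DataTypes.IntegerType);
    // With no BytesToBytesMap passed in, the sorter starts empty and grows
    // pages of pageSizeBytes as records arrive, spilling under pressure.
    return new UnsafeKVExternalSorter(
        keySchema, valueSchema,
        SparkEnv.get().blockManager(),
        SparkEnv.get().shuffleMemoryManager(),
        64 * 1024 * 1024 /* pageSizeBytes: the 64m default used elsewhere */);
  }
}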
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 2e870ec8ae..49adf21537 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -50,17 +50,36 @@ object SortPrefixUtils {
}
}
+ /**
+ * Creates the prefix comparator for the first field in the given schema, in ascending order.
+ */
def getPrefixComparator(schema: StructType): PrefixComparator = {
- val field = schema.head
- getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+ if (schema.nonEmpty) {
+ val field = schema.head
+ getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+ } else {
+ new PrefixComparator {
+ override def compare(prefix1: Long, prefix2: Long): Int = 0
+ }
+ }
}
+ /**
+ * Creates the prefix computer for the first field in the given schema, in ascending order.
+ */
def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
- val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
- val prefixProjection = UnsafeProjection.create(SortPrefix(SortOrder(boundReference, Ascending)))
- new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = {
- prefixProjection.apply(row).getLong(0)
+ if (schema.nonEmpty) {
+ val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
+ val prefixProjection = UnsafeProjection.create(
+ SortPrefix(SortOrder(boundReference, Ascending)))
+ new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = {
+ prefixProjection.apply(row).getLong(0)
+ }
+ }
+ } else {
+ new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = 0
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 6d903ab23c..92cf328c76 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -116,6 +116,7 @@ case class TungstenSort(
protected override def doExecute(): RDD[InternalRow] = {
val schema = child.schema
val childOutput = child.output
+ val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m")
child.execute().mapPartitions({ iter =>
val ordering = newOrdering(sortOrder, childOutput)
@@ -131,7 +132,8 @@ case class TungstenSort(
}
}
- val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
+ val sorter = new UnsafeExternalRowSorter(
+ schema, ordering, prefixComparator, prefixComputer, pageSize)
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
new file mode 100644
index 0000000000..53de2d0f07
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.shuffle.ShuffleMemoryManager
+
+/**
+ * A [[ShuffleMemoryManager]] that can be controlled to run out of memory.
+ */
+class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue) {
+ private var oom = false
+
+ override def tryToAcquire(numBytes: Long): Long = {
+ if (oom) {
+ oom = false
+ 0
+ } else {
+ // Uncomment the following to trace memory allocations.
+ // println(s"tryToAcquire $numBytes in " +
+ // Thread.currentThread().getStackTrace.mkString("", "\n -", ""))
+ val acquired = super.tryToAcquire(numBytes)
+ acquired
+ }
+ }
+
+ override def release(numBytes: Long): Unit = {
+ // Uncomment the following to trace memory releases.
+ // println(s"release $numBytes in " +
+ // Thread.currentThread().getStackTrace.mkString("", "\n -", ""))
+ super.release(numBytes)
+ }
+
+ def markAsOutOfMemory(): Unit = {
+ oom = true
+ }
+}
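A brief usage sketch of the manager above, assumed to run standalone outside the suite: after markAsOutOfMemory(), exactly one subsequent tryToAcquire returns 0, which is what lets the aggregation-map test below force a spill at controlled points.

import org.apache.spark.sql.execution.TestShuffleMemoryManager;

public class OomInjectionSketch {
  public static void main(String[] args) {
    TestShuffleMemoryManager smm = new TestShuffleMemoryManager();
    smm.markAsOutOfMemory();
    System.out.println(smm.tryToAcquire(1024)); // 0: simulated memory pressure
    System.out.println(smm.tryToAcquire(1024)); // the oom flag has reset, so
                                                // this acquire is expected to succeed
    smm.release(1024);                          // keep the accounting balanced
  }
}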
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 098bdd0017..4c94b3307d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -17,24 +17,26 @@
package org.apache.spark.sql.execution
-import org.scalatest.{BeforeAndAfterEach, Matchers}
-
-import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
import scala.collection.mutable
-import scala.util.Random
+import scala.util.{Try, Random}
+
+import org.scalatest.Matchers
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.shuffle.ShuffleMemoryManager
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
-
-class UnsafeFixedWidthAggregationMapSuite
- extends SparkFunSuite
- with Matchers
- with BeforeAndAfterEach {
+/**
+ * Test suite for [[UnsafeFixedWidthAggregationMap]].
+ *
+ * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases.
+ */
+class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
import UnsafeFixedWidthAggregationMap._
@@ -44,23 +46,40 @@ class UnsafeFixedWidthAggregationMapSuite
private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
private var taskMemoryManager: TaskMemoryManager = null
- private var shuffleMemoryManager: ShuffleMemoryManager = null
+ private var shuffleMemoryManager: TestShuffleMemoryManager = null
+
+ def testWithMemoryLeakDetection(name: String)(f: => Unit) {
+ def cleanup(): Unit = {
+ if (taskMemoryManager != null) {
+ val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+ assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
+ assert(leakedShuffleMemory === 0)
+ taskMemoryManager = null
+ }
+ }
- override def beforeEach(): Unit = {
- taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
- shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue)
+ test(name) {
+ taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ shuffleMemoryManager = new TestShuffleMemoryManager
+ try {
+ f
+ } catch {
+ case NonFatal(e) =>
+ Try(cleanup())
+ throw e
+ }
+ cleanup()
+ }
}
- override def afterEach(): Unit = {
- if (taskMemoryManager != null) {
- val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
- assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
- assert(leakedShuffleMemory === 0)
- taskMemoryManager = null
- }
+ private def randomStrings(n: Int): Seq[String] = {
+ val rand = new Random(42)
+ Seq.fill(512) {
+ Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+ }.distinct
}
- test("supported schemas") {
+ testWithMemoryLeakDetection("supported schemas") {
assert(supportsAggregationBufferSchema(
StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil)))
assert(!supportsAggregationBufferSchema(
@@ -70,7 +89,7 @@ class UnsafeFixedWidthAggregationMapSuite
!supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
}
- test("empty map") {
+ testWithMemoryLeakDetection("empty map") {
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
aggBufferSchema,
@@ -85,7 +104,7 @@ class UnsafeFixedWidthAggregationMapSuite
map.free()
}
- test("updating values for a single key") {
+ testWithMemoryLeakDetection("updating values for a single key") {
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
aggBufferSchema,
@@ -113,7 +132,7 @@ class UnsafeFixedWidthAggregationMapSuite
map.free()
}
- test("inserting large random keys") {
+ testWithMemoryLeakDetection("inserting large random keys") {
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
aggBufferSchema,
@@ -140,7 +159,21 @@ class UnsafeFixedWidthAggregationMapSuite
map.free()
}
- test("test sorting") {
+ testWithMemoryLeakDetection("test external sorting") {
+    // Calling this makes sure the block manager and everything else are set up.
+ TestSQLContext
+
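+    // The external sorter machinery reads the current TaskContext, so install one for this thread.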
+ TaskContext.setTaskContext(new TaskContextImpl(
+ stageId = 0,
+ partitionId = 0,
+ taskAttemptId = 0,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemoryManager,
+ metricsSystem = null))
+
+    // Memory consumption at the beginning of the task.
+ val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
aggBufferSchema,
@@ -152,26 +185,47 @@ class UnsafeFixedWidthAggregationMapSuite
false // disable perf metrics
)
- val rand = new Random(42)
- val groupKeys: Set[String] = Seq.fill(512) {
- Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
- }.toSet
- groupKeys.foreach { keyString =>
+ val keys = randomStrings(1024).take(512)
+ keys.foreach { keyString =>
       val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
       assert(buf != null)
       buf.setInt(0, keyString.length)
}
+    // Convert the map into a sorter, destroying the map and releasing its memory in the process.
+ val sorter = map.destructAndCreateExternalSorter()
+
+    withClue("destructAndCreateExternalSorter should release memory used by the map") {
+ // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter.
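+      // That is 16 bytes per entry: an 8-byte record pointer plus an 8-byte sort prefix.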
+ assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() ===
+ initialMemoryConsumption + 4096 * 16)
+ }
+
+ // Add more keys to the sorter and make sure the results come out sorted.
+ val additionalKeys = randomStrings(1024)
+ val keyConverter = UnsafeProjection.create(groupKeySchema)
+ val valueConverter = UnsafeProjection.create(aggBufferSchema)
+
+ additionalKeys.zipWithIndex.foreach { case (str, i) =>
+ val k = InternalRow(UTF8String.fromString(str))
+ val v = InternalRow(str.length)
+ sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+
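+      // Every 100 rows, fake an out-of-memory condition and close the current page,
+      // so the next insert has to acquire memory and the sorter is forced to spill.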
+ if ((i % 100) == 0) {
+ shuffleMemoryManager.markAsOutOfMemory()
+ sorter.closeCurrentPage()
+ }
+ }
+
val out = new scala.collection.mutable.ArrayBuffer[String]
- val iter = map.sortedIterator()
+ val iter = sorter.sortedIterator()
while (iter.next()) {
assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
out += iter.getKey.getString(0)
}
- assert(out === groupKeys.toSeq.sorted)
+ assert(out === (keys ++ additionalKeys).sorted)
map.free()
}
-
}
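In outline, the hash-to-sort fallback that the updated suite above exercises looks like the
following sketch. This is illustrative only: buildMap, keyRow, toKV, process, groupKeys and
lateKeys are hypothetical helpers, while the methods called on the map and the sorter are
the ones shown in this diff.

    // Phase 1: hash-based aggregation while memory allows.
    val map: UnsafeFixedWidthAggregationMap = buildMap()
    groupKeys.foreach { key =>
      val buf = map.getAggregationBuffer(keyRow(key))  // per-key aggregation buffer
      buf.setInt(0, key.length)
    }

    // Phase 2: under memory pressure, destruct the map into an external sorter.
    // The map's records move into the sorter and the map's memory is released.
    val sorter: UnsafeKVExternalSorter = map.destructAndCreateExternalSorter()
    lateKeys.foreach { key =>
      val (k, v) = toKV(key)
      sorter.insertKV(k, v)  // spills to disk when memory runs out
    }

    // Phase 3: consume the merged, sorted key/value pairs.
    val iter = sorter.sortedIterator()
    while (iter.next()) {
      process(iter.getKey, iter.getValue)
    }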
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
new file mode 100644
index 0000000000..5d214d7bfc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeProjection}
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark._
+
+class UnsafeKVExternalSorterSuite extends SparkFunSuite {
+
+  test("sorting string key and two int values") {
+
+    // Calling this makes sure the block manager and everything else are set up.
+ TestSQLContext
+
+ val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ val shuffleMemMgr = new TestShuffleMemoryManager
+
+ TaskContext.setTaskContext(new TaskContextImpl(
+ stageId = 0,
+ partitionId = 0,
+ taskAttemptId = 0,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemMgr,
+ metricsSystem = null))
+
+ val keySchema = new StructType().add("a", StringType)
+ val valueSchema = new StructType().add("b", IntegerType).add("c", IntegerType)
+ val sorter = new UnsafeKVExternalSorter(
+ keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr,
+ 16 * 1024)
+
+ val keyConverter = UnsafeProjection.create(keySchema)
+ val valueConverter = UnsafeProjection.create(valueSchema)
+
+ val rand = new Random(42)
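+    // Prepend a null key to exercise the sorter's handling of null keys.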
+ val data = null +: Seq.fill[String](10) {
+ Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+ }
+
+ val inputRows = data.map { str =>
+ keyConverter.apply(InternalRow(UTF8String.fromString(str))).copy()
+ }
+
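+    // Insert the rows, periodically faking memory exhaustion so the sorter must spill to disk.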
+    data.zipWithIndex.foreach { case (str, i) =>
+      val k = InternalRow(UTF8String.fromString(str))
+      val v = if (str != null) InternalRow(str.length, str.length + 1) else InternalRow(-1, -2)
+      sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+
+      if ((i % 100) == 0) {
+        shuffleMemMgr.markAsOutOfMemory()
+        sorter.closeCurrentPage()
+      }
+    }
+
+ val out = new scala.collection.mutable.ArrayBuffer[InternalRow]
+ val iter = sorter.sortedIterator()
+ while (iter.next()) {
+ if (iter.getKey.getUTF8String(0) == null) {
+        withClue("for null key") {
+ assert(-1 === iter.getValue.getInt(0))
+ assert(-2 === iter.getValue.getInt(1))
+ }
+ } else {
+ val key = iter.getKey.getString(0)
+ withClue(s"for key $key") {
+ assert(key.length === iter.getValue.getInt(0))
+ assert(key.length + 1 === iter.getValue.getInt(1))
+ }
+ }
+ out += iter.getKey.copy()
+ }
+
+ assert(out === inputRows.sorted(RowOrdering.forSchema(keySchema.map(_.dataType))))
+ }
+
+ test("sorting arbitrary string data") {
+
+    // Calling this makes sure the block manager and everything else are set up.
+ TestSQLContext
+
+ val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ val shuffleMemMgr = new TestShuffleMemoryManager
+
+ TaskContext.setTaskContext(new TaskContextImpl(
+ stageId = 0,
+ partitionId = 0,
+ taskAttemptId = 0,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemMgr,
+ metricsSystem = null))
+
+ val keySchema = new StructType().add("a", StringType)
+ val valueSchema = new StructType().add("b", IntegerType)
+ val sorter = new UnsafeKVExternalSorter(
+ keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr,
+ 16 * 1024)
+
+ val keyConverter = UnsafeProjection.create(keySchema)
+ val valueConverter = UnsafeProjection.create(valueSchema)
+
+ val rand = new Random(42)
+ val data = Seq.fill(512) {
+ Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
+ }
+
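+    // As above, periodically fake memory exhaustion to force spills.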
+    data.zipWithIndex.foreach { case (str, i) =>
+      val k = InternalRow(UTF8String.fromString(str))
+      val v = InternalRow(str.length)
+      sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
+
+      if ((i % 100) == 0) {
+        shuffleMemMgr.markAsOutOfMemory()
+        sorter.closeCurrentPage()
+      }
+    }
+
+ val out = new scala.collection.mutable.ArrayBuffer[String]
+ val iter = sorter.sortedIterator()
+ while (iter.next()) {
+ assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
+ out += iter.getKey.getString(0)
+ }
+
+ assert(out === data.sorted)
+ }
+}