17 files changed, 823 insertions, 215 deletions
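Before the file-by-file changes, a note on the record layout at the heart of this patch. BytesToBytesMap's records now store the full record length (key length + value length + 4) in their first 4-byte word, which is why the hunks below advance the iterator by 4 + totalLength instead of 8 + totalLength and derive valueLength as totalLength - keyLength - 4. A minimal decoding sketch, using the same PlatformDependent accessor the map itself uses (the standalone method is illustrative only):

    import org.apache.spark.unsafe.PlatformDependent

    // Decodes one record, given the base object and the offset of its first byte.
    def decodeRecord(base: AnyRef, offset: Long): Unit = {
      val totalLength = PlatformDependent.UNSAFE.getInt(base, offset)      // len(k) + len(v) + 4
      val keyLength   = PlatformDependent.UNSAFE.getInt(base, offset + 4)  // len(k)
      val valueLength = totalLength - keyLength - 4                        // len(v)
      val keyOffset   = offset + 8                // key bytes start here
      val valueOffset = offset + 8 + keyLength    // value bytes immediately follow the key
      // The next record begins at offset + 4 + totalLength, matching the new
      // iterator arithmetic in the first hunk below.
    }

The map's javadoc below also mentions quadratic probing with triangular numbers; for reference, that probe sequence looks like the following sketch (the occupancy predicate is hypothetical). Because the table capacity is a power of 2, the sequence visits every slot exactly once:

    var pos  = hash & mask   // mask = capacity - 1, capacity a power of 2
    var step = 1
    while (slotTakenByOtherKey(pos)) {  // hypothetical occupancy check
      pos = (pos + step) & mask         // probe offsets: 0, 1, 3, 6, 10, ... (triangular numbers)
      step += 1
    }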
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index cf222b7272..01a66084e9 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -39,14 +39,22 @@ import org.apache.spark.unsafe.memory.*; /** * An append-only hash map where keys and values are contiguous regions of bytes. - * <p> + * * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers, * which is guaranteed to exhaust the space. - * <p> + * * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should * probably be using sorting instead of hashing for better cache locality. - * <p> - * This class is not thread safe. + * + * Under the hood, the key and value are stored together, in the following format: + * Bytes 0 to 4: len(k) (key length in bytes) + len(v) (value length in bytes) + 4 + * Bytes 4 to 8: len(k) + * Bytes 8 to 8 + len(k): key data + * Bytes 8 + len(k) to 8 + len(k) + len(v): value data + * + * This means that the first four bytes store the entire record (key + value) length. This format + * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter}, + * so we can pass records from this map directly into the sorter to sort records in place. */ public final class BytesToBytesMap { @@ -253,7 +261,7 @@ public final class BytesToBytesMap { totalLength = PlatformDependent.UNSAFE.getInt(pageBaseObject, offsetInPage); } loc.with(currentPage, offsetInPage); - offsetInPage += 8 + totalLength; + offsetInPage += 4 + totalLength; currentRecordNumber++; return loc; } @@ -366,7 +374,7 @@ public final class BytesToBytesMap { position += 4; keyLength = PlatformDependent.UNSAFE.getInt(page, position); position += 4; - valueLength = totalLength - keyLength; + valueLength = totalLength - keyLength - 4; keyMemoryLocation.setObjAndOffset(page, position); @@ -565,7 +573,7 @@ public final class BytesToBytesMap { insertCursor += valueLengthBytes; // word used to store the value size PlatformDependent.UNSAFE.putInt(dataPageBaseObject, recordOffset, - keyLengthBytes + valueLengthBytes); + keyLengthBytes + valueLengthBytes + 4); PlatformDependent.UNSAFE.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); // Copy the key PlatformDependent.copyMemory( @@ -620,7 +628,7 @@ public final class BytesToBytesMap { * Free all allocated memory associated with this map, including the storage for keys and values * as well as the hash map array itself. * - * This method is idempotent. + * This method is idempotent and can be called multiple times. */ public void free() { longArray = null; @@ -639,6 +647,14 @@ public final class BytesToBytesMap { return taskMemoryManager; } + public ShuffleMemoryManager getShuffleMemoryManager() { + return shuffleMemoryManager; + } + + public long getPageSizeBytes() { + return pageSizeBytes; + } + /** Returns the total amount of memory, in bytes, consumed by this map's managed structures.
*/ public long getTotalMemoryConsumption() { long totalDataPagesSize = 0L; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index c05f2c332e..b984301cbb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -17,9 +17,12 @@ package org.apache.spark.util.collection.unsafe.sort; +import java.io.File; import java.io.IOException; import java.util.LinkedList; +import javax.annotation.Nullable; + import scala.runtime.AbstractFunction0; import scala.runtime.BoxedUnit; @@ -27,7 +30,6 @@ import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.shuffle.ShuffleMemoryManager; @@ -48,7 +50,7 @@ public final class UnsafeExternalSorter { private final PrefixComparator prefixComparator; private final RecordComparator recordComparator; private final int initialSize; - private final TaskMemoryManager memoryManager; + private final TaskMemoryManager taskMemoryManager; private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; @@ -63,26 +65,57 @@ public final class UnsafeExternalSorter { * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager * itself). */ - private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>(); + private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<>(); + + private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>(); // These variables are reset after spilling: - private UnsafeInMemorySorter sorter; + private UnsafeInMemorySorter inMemSorter; + // Whether the in-mem sorter is created internally, or passed in from outside. + // If it is passed in from outside, we shouldn't release the in-mem sorter's memory. 
+ private boolean isInMemSorterExternal = false; private MemoryBlock currentPage = null; private long currentPagePosition = -1; private long freeSpaceInCurrentPage = 0; - private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>(); + public static UnsafeExternalSorter createWithExistingInMemorySorter( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes, + UnsafeInMemorySorter inMemorySorter) throws IOException { + return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager, + taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter); + } - public UnsafeExternalSorter( - TaskMemoryManager memoryManager, + public static UnsafeExternalSorter create( + TaskMemoryManager taskMemoryManager, ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, int initialSize, - SparkConf conf) throws IOException { - this.memoryManager = memoryManager; + long pageSizeBytes) throws IOException { + return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager, + taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null); + } + + private UnsafeExternalSorter( + TaskMemoryManager taskMemoryManager, + ShuffleMemoryManager shuffleMemoryManager, + BlockManager blockManager, + TaskContext taskContext, + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int initialSize, + long pageSizeBytes, + @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException { + this.taskMemoryManager = taskMemoryManager; this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; this.taskContext = taskContext; @@ -90,9 +123,18 @@ public final class UnsafeExternalSorter { this.prefixComparator = prefixComparator; this.initialSize = initialSize; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units - this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; - this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m"); - initializeForWriting(); + // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; + this.fileBufferSizeBytes = 32 * 1024; + // this.pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "64m"); + this.pageSizeBytes = pageSizeBytes; + this.writeMetrics = new ShuffleWriteMetrics(); + + if (existingInMemorySorter == null) { + initializeForWriting(); + } else { + this.isInMemSorterExternal = true; + this.inMemSorter = existingInMemorySorter; + } // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. 
This is necessary to avoid memory leaks when the downstream operator @@ -100,6 +142,7 @@ taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() { @Override public BoxedUnit apply() { + deleteSpillFiles(); freeMemory(); return null; } @@ -114,22 +157,31 @@ */ private void initializeForWriting() throws IOException { this.writeMetrics = new ShuffleWriteMetrics(); - // TODO: move this sizing calculation logic into a static method of sorter: - final long memoryRequested = initialSize * 8L * 2; - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); - if (memoryAcquired != memoryRequested) { + final long pointerArrayMemory = + UnsafeInMemorySorter.getMemoryRequirementsForPointerArray(initialSize); + final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pointerArrayMemory); + if (memoryAcquired != pointerArrayMemory) { shuffleMemoryManager.release(memoryAcquired); - throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); + throw new IOException("Could not acquire " + pointerArrayMemory + " bytes of memory"); } - this.sorter = - new UnsafeInMemorySorter(memoryManager, recordComparator, prefixComparator, initialSize); + this.inMemSorter = + new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize); + this.isInMemSorterExternal = false; } /** - * Sort and spill the current records in response to memory pressure. + * Marks the current page as no-more-space-available; as a result, the sorter will either + * allocate a new page or spill when the next record is inserted. */ @VisibleForTesting + public void closeCurrentPage() { + freeSpaceInCurrentPage = 0; + } + + /** + * Sort and spill the current records in response to memory pressure. + */ public void spill() throws IOException { logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), @@ -139,9 +191,9 @@ final UnsafeSorterSpillWriter spillWriter = new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, - sorter.numRecords()); + inMemSorter.numRecords()); spillWriters.add(spillWriter); - final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator(); + final UnsafeSorterIterator sortedRecords = inMemSorter.getSortedIterator(); while (sortedRecords.hasNext()) { sortedRecords.loadNext(); final Object baseObject = sortedRecords.getBaseObject(); @@ -150,20 +202,24 @@ spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); } spillWriter.close(); - final long sorterMemoryUsage = sorter.getMemoryUsage(); - sorter = null; - shuffleMemoryManager.release(sorterMemoryUsage); final long spillSize = freeMemory(); + // Note that this is more-or-less going to be a multiple of the page size, so wasted space in + // pages will currently be counted as memory spilled even though that space isn't actually + // written to disk. This also counts the space needed to store the sorter's pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); initializeForWriting(); } + /** + * Returns the total memory usage of this sorter, including the data pages and the sorter's pointer + * array.
+ */ private long getMemoryUsage() { long totalPageSize = 0; for (MemoryBlock page : allocatedPages) { totalPageSize += page.size(); } - return sorter.getMemoryUsage() + totalPageSize; + return inMemSorter.getMemoryUsage() + totalPageSize; } @VisibleForTesting @@ -171,13 +227,26 @@ return allocatedPages.size(); } + /** + * Frees this sorter's in-memory data structures, including its data pages and pointer array. + * + * @return the number of bytes freed. + */ public long freeMemory() { long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { - memoryManager.freePage(block); + taskMemoryManager.freePage(block); shuffleMemoryManager.release(block.size()); memoryFreed += block.size(); } + if (inMemSorter != null) { + if (!isInMemSorterExternal) { + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + memoryFreed += sorterMemoryUsage; + shuffleMemoryManager.release(sorterMemoryUsage); + } + inMemSorter = null; + } allocatedPages.clear(); currentPage = null; currentPagePosition = -1; @@ -186,6 +255,20 @@ } /** + * Deletes any spill files created by this sorter. + */ + public void deleteSpillFiles() { + for (UnsafeSorterSpillWriter spill : spillWriters) { + File file = spill.getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + } + + /** * Checks whether there is enough space to insert a new record into the sorter. * * @param requiredSpace the required space in the data page, in bytes, including space for storing * the record length. * * @return true if the record can be inserted without requiring more allocations, false otherwise. */ private boolean haveSpaceForRecord(int requiredSpace) { assert (requiredSpace > 0); - return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); + return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage)); } /** @@ -210,16 +293,16 @@ // TODO: merge these steps to first calculate total memory requirements for this insert, // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the // data page.
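The pointer/prefix array reserved in initializeForWriting() above, and doubled in the hunk just below, charges two 8-byte words per record: one for the encoded record address and one for the sort prefix. A sketch of the arithmetic behind getMemoryRequirementsForPointerArray (the helper name here is illustrative):

    // Two 8-byte words per record: encoded page address + sort prefix.
    def pointerArrayBytes(numEntries: Long): Long = numEntries * 2L * 8L

    // The 4096-entry initial size used by the SQL callers in this patch thus
    // reserves 4096 * 16 = 65536 bytes before any record is inserted.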
- if (!sorter.hasSpaceForAnotherRecord()) { + if (!inMemSorter.hasSpaceForAnotherRecord()) { logger.debug("Attempting to expand sort pointer array"); - final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage(); + final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); if (memoryAcquired < memoryToGrowPointerArray) { shuffleMemoryManager.release(memoryAcquired); spill(); } else { - sorter.expandPointerArray(); + inMemSorter.expandPointerArray(); shuffleMemoryManager.release(oldPointerArrayMemoryUsage); } } @@ -236,7 +319,9 @@ public final class UnsafeExternalSorter { } else { final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); if (memoryAcquired < pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquired); + if (memoryAcquired > 0) { + shuffleMemoryManager.release(memoryAcquired); + } spill(); final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); if (memoryAcquiredAfterSpilling != pageSizeBytes) { @@ -244,7 +329,7 @@ public final class UnsafeExternalSorter { throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = memoryManager.allocatePage(pageSizeBytes); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); @@ -267,7 +352,7 @@ public final class UnsafeExternalSorter { } final long recordAddress = - memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); final Object dataPageBaseObject = currentPage.getBaseObject(); PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes); currentPagePosition += 4; @@ -279,26 +364,48 @@ public final class UnsafeExternalSorter { lengthInBytes); currentPagePosition += lengthInBytes; freeSpaceInCurrentPage -= totalSpaceRequired; - sorter.insertRecord(recordAddress, prefix); + inMemSorter.insertRecord(recordAddress, prefix); } /** - * Write a record to the sorter. The record is broken down into two different parts, and + * Write a key-value record to the sorter. 
The key and value will be put together in-memory, + * using the following format: * + * record length (4 bytes), key length (4 bytes), key data, value data + * + * record length = key length + value length + 4 */ - public void insertRecord( - Object recordBaseObject1, - long recordBaseOffset1, - int lengthInBytes1, - Object recordBaseObject2, - long recordBaseOffset2, - int lengthInBytes2, - long prefix) throws IOException { + public void insertKVRecord( + Object keyBaseObj, long keyOffset, int keyLen, + Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException { + final int totalSpaceRequired = keyLen + valueLen + 4 + 4; + if (!haveSpaceForRecord(totalSpaceRequired)) { + allocateSpaceForRecord(totalSpaceRequired); + } + + final long recordAddress = + taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition); + final Object dataPageBaseObject = currentPage.getBaseObject(); + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen + valueLen + 4); + currentPagePosition += 4; + PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen); + currentPagePosition += 4; + + PlatformDependent.copyMemory( + keyBaseObj, keyOffset, dataPageBaseObject, currentPagePosition, keyLen); + currentPagePosition += keyLen; + + PlatformDependent.copyMemory( + valueBaseObj, valueOffset, dataPageBaseObject, currentPagePosition, valueLen); + currentPagePosition += valueLen; + + freeSpaceInCurrentPage -= totalSpaceRequired; + inMemSorter.insertRecord(recordAddress, prefix); } public UnsafeSorterIterator getSortedIterator() throws IOException { - final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator(); + final UnsafeSorterIterator inMemoryIterator = inMemSorter.getSortedIterator(); int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 
1 : 0); if (spillWriters.isEmpty()) { return inMemoryIterator; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index fc34ad9cff..3131465391 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -100,6 +100,10 @@ public final class UnsafeInMemorySorter { return pointerArray.length * 8L; } + static long getMemoryRequirementsForPointerArray(long numEntries) { + return numEntries * 2L * 8L; + } + public boolean hasSpaceForAnotherRecord() { return pointerArrayInsertPosition + 2 < pointerArray.length; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 29e9e0f30f..ca1ccedc93 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -31,6 +31,7 @@ import org.apache.spark.unsafe.PlatformDependent; */ final class UnsafeSorterSpillReader extends UnsafeSorterIterator { + private final File file; private InputStream in; private DataInputStream din; @@ -48,6 +49,7 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator { File file, BlockId blockId) throws IOException { assert (file.length() > 0); + this.file = file; final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file)); this.in = blockManager.wrapForCompression(blockId, bs); this.din = new DataInputStream(this.in); @@ -71,6 +73,7 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator { numRecordsRemaining--; if (numRecordsRemaining == 0) { in.close(); + file.delete(); in = null; din = null; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 71eed29563..44cf6c756d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -140,6 +140,10 @@ final class UnsafeSorterSpillWriter { writeBuffer = null; } + public File getFile() { + return file; + } + public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException { return new UnsafeSorterSpillReader(blockManager, file, blockId); } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 70f8ca4d21..dbb7c662d7 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -67,12 +67,11 @@ public abstract class AbstractBytesToBytesMapSuite { @After public void tearDown() { - if (taskMemoryManager != null) { + Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); + if (shuffleMemoryManager != null) { long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); - Assert.assertEquals(0, taskMemoryManager.cleanUpAllAllocatedMemory()); - Assert.assertEquals(0, 
leakedShuffleMemory); shuffleMemoryManager = null; - taskMemoryManager = null; + Assert.assertEquals(0L, leakedShuffleMemory); } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 0e391b7512..52fa8bcd57 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -20,12 +20,14 @@ package org.apache.spark.util.collection.unsafe.sort; import java.io.File; import java.io.InputStream; import java.io.OutputStream; +import java.util.LinkedList; import java.util.UUID; import scala.Tuple2; import scala.Tuple2$; import scala.runtime.AbstractFunction1; +import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -33,7 +35,6 @@ import org.mockito.MockitoAnnotations; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import static org.junit.Assert.*; -import static org.mockito.AdditionalAnswers.returnsFirstArg; import static org.mockito.AdditionalAnswers.returnsSecondArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; @@ -53,7 +54,8 @@ import org.apache.spark.util.Utils; public class UnsafeExternalSorterSuite { - final TaskMemoryManager memoryManager = + final LinkedList<File> spillFilesCreated = new LinkedList<File>(); + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { @@ -75,13 +77,15 @@ public class UnsafeExternalSorterSuite { } }; - @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; + ShuffleMemoryManager shuffleMemoryManager; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; File tempDir; + private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m"); + private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> { @Override public OutputStream apply(OutputStream stream) { @@ -93,15 +97,17 @@ public class UnsafeExternalSorterSuite { public void setUp() { MockitoAnnotations.initMocks(this); tempDir = new File(Utils.createTempDir$default$1()); + shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE); + spillFilesCreated.clear(); taskContext = mock(TaskContext.class); when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); - when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() { @Override public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable { TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); return Tuple2$.MODULE$.apply(blockId, file); } }); @@ -130,6 +136,24 @@ public class UnsafeExternalSorterSuite { .then(returnsSecondArg()); } + @After + public void 
tearDown() { + long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); + if (shuffleMemoryManager != null) { + long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); + shuffleMemoryManager = null; + assertEquals(0L, leakedShuffleMemory); + } + assertEquals(0, leakedUnsafeMemory); + } + + private void assertSpillFilesWereCleanedUp() { + for (File spillFile : spillFilesCreated) { + assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up", + spillFile.exists()); + } + } + private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception { final int[] arr = new int[] { value }; sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value); @@ -138,15 +162,15 @@ public class UnsafeExternalSorterSuite { @Test public void testSortingOnlyByPrefix() throws Exception { - final UnsafeExternalSorter sorter = new UnsafeExternalSorter( - memoryManager, + final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + taskMemoryManager, shuffleMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, - 1024, - new SparkConf()); + /* initialSize */ 1024, + pageSizeBytes); insertNumber(sorter, 5); insertNumber(sorter, 1); @@ -165,22 +189,22 @@ public class UnsafeExternalSorterSuite { // TODO: read rest of value. } - // TODO: test for cleanup: - // assert(tempDir.isEmpty) + sorter.freeMemory(); + assertSpillFilesWereCleanedUp(); } @Test public void testSortingEmptyArrays() throws Exception { - final UnsafeExternalSorter sorter = new UnsafeExternalSorter( - memoryManager, + final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + taskMemoryManager, shuffleMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, - 1024, - new SparkConf()); + /* initialSize */ 1024, + pageSizeBytes); sorter.insertRecord(null, 0, 0, 0); sorter.insertRecord(null, 0, 0, 0); @@ -197,25 +221,30 @@ public class UnsafeExternalSorterSuite { assertEquals(0, iter.getKeyPrefix()); assertEquals(0, iter.getRecordLength()); } + + sorter.freeMemory(); + assertSpillFilesWereCleanedUp(); } @Test public void testFillingPage() throws Exception { - final UnsafeExternalSorter sorter = new UnsafeExternalSorter( - memoryManager, + + final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + taskMemoryManager, shuffleMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, - 1024, - new SparkConf()); + /* initialSize */ 1024, + pageSizeBytes); byte[] record = new byte[16]; while (sorter.getNumberOfAllocatedPages() < 2) { sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0); } sorter.freeMemory(); + assertSpillFilesWereCleanedUp(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 1b475b2492..b4fc0b7b70 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -507,7 +507,8 @@ public final class UnsafeRow extends MutableRow { public String toString() { StringBuilder build = new StringBuilder("["); for (int i = 0; i < sizeInBytes; i += 8) { - build.append(PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i)); + build.append(java.lang.Long.toHexString( + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + i))); build.append(','); } build.append(']'); diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 68c49feae9..5e4c6232c9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -59,20 +59,21 @@ final class UnsafeExternalRowSorter { StructType schema, Ordering<InternalRow> ordering, PrefixComparator prefixComparator, - PrefixComputer prefixComputer) throws IOException { + PrefixComputer prefixComputer, + long pageSizeBytes) throws IOException { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); - sorter = new UnsafeExternalSorter( + sorter = UnsafeExternalSorter.create( taskContext.taskMemoryManager(), sparkEnv.shuffleMemoryManager(), sparkEnv.blockManager(), taskContext, new RowComparator(ordering, schema.length()), prefixComparator, - 4096, - sparkEnv.conf() + /* initialSize */ 4096, + pageSizeBytes ); } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index be0966641b..349007789f 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -106,6 +106,11 @@ <artifactId>parquet-avro</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> </dependencies> <build> <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index a0a8dd5154..9e2c9334a7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -19,24 +19,18 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import org.apache.spark.SparkEnv; import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; -import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.MemoryLocation; import org.apache.spark.unsafe.memory.TaskMemoryManager; -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; -import org.apache.spark.util.collection.unsafe.sort.RecordComparator; -import org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter; -import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -215,7 +209,7 @@ public final class UnsafeFixedWidthAggregationMap { } /** - * Free the unsafe memory associated with this map. + * Free the memory associated with this map. This is idempotent and can be called multiple times. 
*/ public void free() { map.free(); @@ -233,92 +227,17 @@ } /** - * Sorts the key, value data in this map in place, and return them as an iterator. + * Sorts the map's records in place, spills them to disk, and returns an [[UnsafeKVExternalSorter]] + * that can be used to insert more records for external sorting. * * The only memory that is allocated is the address/prefix array, 16 bytes per record. + * + * Note that this destroys the map, and as a result, the map can no longer be used after this. */ - public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() { - int numElements = map.numElements(); - final int numKeyFields = groupingKeySchema.size(); - TaskMemoryManager memoryManager = map.getTaskMemoryManager(); - - UnsafeExternalRowSorter.PrefixComputer prefixComp = - SortPrefixUtils.createPrefixGenerator(groupingKeySchema); - PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(groupingKeySchema); - - final BaseOrdering ordering = GenerateOrdering.create(groupingKeySchema); - RecordComparator recordComparator = new RecordComparator() { - private final UnsafeRow row1 = new UnsafeRow(); - private final UnsafeRow row2 = new UnsafeRow(); - - @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { - row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1); - row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1); - return ordering.compare(row1, row2); - } - }; - - // Insert the records into the in-memory sorter. - final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( - memoryManager, recordComparator, prefixComparator, numElements); - - BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); - UnsafeRow row = new UnsafeRow(); - while (iter.hasNext()) { - final BytesToBytesMap.Location loc = iter.next(); - final Object baseObject = loc.getKeyAddress().getBaseObject(); - final long baseOffset = loc.getKeyAddress().getBaseOffset(); - - // Get encoded memory address - MemoryBlock page = loc.getMemoryPage(); - long address = memoryManager.encodePageNumberAndOffset(page, baseOffset - 8); - - // Compute prefix - row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength()); - final long prefix = prefixComp.computePrefix(row); - - sorter.insertRecord(address, prefix); - } - - // Return the sorted result as an iterator.
- return new KVIterator<UnsafeRow, UnsafeRow>() { - - private UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); - private final UnsafeRow key = new UnsafeRow(); - private final UnsafeRow value = new UnsafeRow(); - private int numValueFields = aggregationBufferSchema.size(); - - @Override - public boolean next() throws IOException { - if (sortedIterator.hasNext()) { - sortedIterator.loadNext(); - Object baseObj = sortedIterator.getBaseObject(); - long recordOffset = sortedIterator.getBaseOffset(); - int recordLen = sortedIterator.getRecordLength(); - int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); - key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); - value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, recordLen - keyLen); - return true; - } else { - return false; - } - } - - @Override - public UnsafeRow getKey() { - return key; - } - - @Override - public UnsafeRow getValue() { - return value; - } - - @Override - public void close() { - // Do nothing - } - }; + public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { + UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter( + groupingKeySchema, aggregationBufferSchema, + SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map); + return sorter; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java new file mode 100644 index 0000000000..f6b0176863 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; + +import javax.annotation.Nullable; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.TaskContext; +import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.storage.BlockManager; +import org.apache.spark.unsafe.KVIterator; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.map.BytesToBytesMap; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.util.collection.unsafe.sort.*; + +/** + * A class for performing external sorting on key-value records. Both key and value are UnsafeRows. 
+ * + * Note that this class allows optionally passing in a {@link BytesToBytesMap} directly in order + * to perform in-place sorting of records in the map. + */ +public final class UnsafeKVExternalSorter { + + private final StructType keySchema; + private final StructType valueSchema; + private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; + private final UnsafeExternalSorter sorter; + + public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, + BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes) + throws IOException { + this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null); + } + + public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, + BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes, + @Nullable BytesToBytesMap map) throws IOException { + this.keySchema = keySchema; + this.valueSchema = valueSchema; + final TaskContext taskContext = TaskContext.get(); + + prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema); + PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema); + BaseOrdering ordering = GenerateOrdering.create(keySchema); + KVComparator recordComparator = new KVComparator(ordering, keySchema.length()); + + TaskMemoryManager taskMemoryManager = taskContext.taskMemoryManager(); + + if (map == null) { + sorter = UnsafeExternalSorter.create( + taskMemoryManager, + shuffleMemoryManager, + blockManager, + taskContext, + recordComparator, + prefixComparator, + /* initialSize */ 4096, + pageSizeBytes); + } else { + // Insert the records into the in-memory sorter. + final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( + taskMemoryManager, recordComparator, prefixComparator, map.numElements()); + + final int numKeyFields = keySchema.size(); + BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator(); + UnsafeRow row = new UnsafeRow(); + while (iter.hasNext()) { + final BytesToBytesMap.Location loc = iter.next(); + final Object baseObject = loc.getKeyAddress().getBaseObject(); + final long baseOffset = loc.getKeyAddress().getBaseOffset(); + + // Get the encoded memory address. baseObject + baseOffset point to the beginning of the + // key data in the map, but note that the KV-pair's length data is stored in the word + // immediately before that address. + MemoryBlock page = loc.getMemoryPage(); + long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8); + + // Compute prefix + row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength()); + final long prefix = prefixComputer.computePrefix(row); + + inMemSorter.insertRecord(address, prefix); + } + + sorter = UnsafeExternalSorter.createWithExistingInMemorySorter( + taskContext.taskMemoryManager(), + shuffleMemoryManager, + blockManager, + taskContext, + new KVComparator(ordering, keySchema.length()), + prefixComparator, + /* initialSize */ 4096, + pageSizeBytes, + inMemSorter); + + sorter.spill(); + map.free(); + } + } + + /** + * Inserts a key-value record into the sorter. If the sorter no longer has enough memory to hold + * the record, the sorter sorts the existing records in-memory, writes them out as partially + * sorted runs, and then reallocates memory to hold the new record. + */
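For reference, the calling pattern the new tests use for insertKV: rows are first converted to UnsafeRows via UnsafeProjection, and each insert may take the spill path described above when memory runs out. The schemas here are assumptions for illustration, a one-string-field key and a one-int-field value:

    val keyConverter = UnsafeProjection.create(keySchema)      // e.g. new StructType().add("a", StringType)
    val valueConverter = UnsafeProjection.create(valueSchema)  // e.g. new StructType().add("b", IntegerType)

    val k = InternalRow(UTF8String.fromString("some key"))
    val v = InternalRow(8)  // e.g. the key's length
    sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))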
+ public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException { + final long prefix = prefixComputer.computePrefix(key); + sorter.insertKVRecord( + key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(), + value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix); + } + + public KVIterator<UnsafeRow, UnsafeRow> sortedIterator() throws IOException { + try { + final UnsafeSorterIterator underlying = sorter.getSortedIterator(); + if (!underlying.hasNext()) { + // Since we won't ever call next() on an empty iterator, we need to clean up resources + // here in order to prevent memory leaks. + cleanupResources(); + } + + return new KVIterator<UnsafeRow, UnsafeRow>() { + private UnsafeRow key = new UnsafeRow(); + private UnsafeRow value = new UnsafeRow(); + private int numKeyFields = keySchema.size(); + private int numValueFields = valueSchema.size(); + + @Override + public boolean next() throws IOException { + try { + if (underlying.hasNext()) { + underlying.loadNext(); + + Object baseObj = underlying.getBaseObject(); + long recordOffset = underlying.getBaseOffset(); + int recordLen = underlying.getRecordLength(); + + // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) + int keyLen = PlatformDependent.UNSAFE.getInt(baseObj, recordOffset); + int valueLen = recordLen - keyLen - 4; + + key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); + value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); + + return true; + } else { + key = null; + value = null; + cleanupResources(); + return false; + } + } catch (IOException e) { + cleanupResources(); + throw e; + } + } + + @Override + public UnsafeRow getKey() { + return key; + } + + @Override + public UnsafeRow getValue() { + return value; + } + + @Override + public void close() { + cleanupResources(); + } + }; + } catch (IOException e) { + cleanupResources(); + throw e; + } + } + + /** + * Marks the current page as no-more-space-available; as a result, the sorter will either + * allocate a new page or spill when the next record is inserted. + */ + @VisibleForTesting + void closeCurrentPage() { + sorter.closeCurrentPage(); + } + + private void cleanupResources() { + sorter.freeMemory(); + } + + private static final class KVComparator extends RecordComparator { + private final BaseOrdering ordering; + private final UnsafeRow row1 = new UnsafeRow(); + private final UnsafeRow row2 = new UnsafeRow(); + private final int numKeyFields; + + public KVComparator(BaseOrdering ordering, int numKeyFields) { + this.numKeyFields = numKeyFields; + this.ordering = ordering; + } + + @Override + public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { + // Note that since ordering doesn't need the total length of the record, we just pass -1 + // into the row. + row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1); + row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1); + return ordering.compare(row1, row2); + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 2e870ec8ae..49adf21537 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -50,17 +50,36 @@ } } + /** + * Creates the prefix comparator for the first field in the given schema, in ascending order.
+ */ def getPrefixComparator(schema: StructType): PrefixComparator = { - val field = schema.head - getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending)) + if (schema.nonEmpty) { + val field = schema.head + getPrefixComparator(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending)) + } else { + new PrefixComparator { + override def compare(prefix1: Long, prefix2: Long): Int = 0 + } + } } + /** + * Creates the prefix computer for the first field in the given schema, in ascending order. + */ def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = { - val boundReference = BoundReference(0, schema.head.dataType, nullable = true) - val prefixProjection = UnsafeProjection.create(SortPrefix(SortOrder(boundReference, Ascending))) - new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = { - prefixProjection.apply(row).getLong(0) + if (schema.nonEmpty) { + val boundReference = BoundReference(0, schema.head.dataType, nullable = true) + val prefixProjection = UnsafeProjection.create( + SortPrefix(SortOrder(boundReference, Ascending))) + new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) + } + } + } else { + new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = 0 } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 6d903ab23c..92cf328c76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -116,6 +116,7 @@ case class TungstenSort( protected override def doExecute(): RDD[InternalRow] = { val schema = child.schema val childOutput = child.output + val pageSize = sparkContext.conf.getSizeAsBytes("spark.buffer.pageSize", "64m") child.execute().mapPartitions({ iter => val ordering = newOrdering(sortOrder, childOutput) @@ -131,7 +132,8 @@ case class TungstenSort( } } - val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) + val sorter = new UnsafeExternalRowSorter( + schema, ordering, prefixComparator, prefixComputer, pageSize) if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala new file mode 100644 index 0000000000..53de2d0f07 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.shuffle.ShuffleMemoryManager + +/** + * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. + */ +class TestShuffleMemoryManager extends ShuffleMemoryManager(Long.MaxValue) { + private var oom = false + + override def tryToAcquire(numBytes: Long): Long = { + if (oom) { + oom = false + 0 + } else { + // Uncomment the following to trace memory allocations. + // println(s"tryToAcquire $numBytes in " + + // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) + val acquired = super.tryToAcquire(numBytes) + acquired + } + } + + override def release(numBytes: Long): Unit = { + // Uncomment the following to trace memory releases. + // println(s"release $numBytes in " + + // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) + super.release(numBytes) + } + + def markAsOutOfMemory(): Unit = { + oom = true + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 098bdd0017..4c94b3307d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -17,24 +17,26 @@ package org.apache.spark.sql.execution -import org.scalatest.{BeforeAndAfterEach, Matchers} - -import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import scala.collection.mutable -import scala.util.Random +import scala.util.{Try, Random} + +import org.scalatest.Matchers -import org.apache.spark.SparkFunSuite -import org.apache.spark.shuffle.ShuffleMemoryManager +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String - -class UnsafeFixedWidthAggregationMapSuite - extends SparkFunSuite - with Matchers - with BeforeAndAfterEach { +/** + * Test suite for [[UnsafeFixedWidthAggregationMap]]. + * + * Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases. 
+ */ +class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers { import UnsafeFixedWidthAggregationMap._ @@ -44,23 +46,40 @@ class UnsafeFixedWidthAggregationMapSuite private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes private var taskMemoryManager: TaskMemoryManager = null - private var shuffleMemoryManager: ShuffleMemoryManager = null + private var shuffleMemoryManager: TestShuffleMemoryManager = null + + def testWithMemoryLeakDetection(name: String)(f: => Unit) { + def cleanup(): Unit = { + if (taskMemoryManager != null) { + val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() + assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) + assert(leakedShuffleMemory === 0) + taskMemoryManager = null + } + } - override def beforeEach(): Unit = { - taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue) + test(name) { + taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + shuffleMemoryManager = new TestShuffleMemoryManager + try { + f + } catch { + case NonFatal(e) => + Try(cleanup()) + throw e + } + cleanup() + } } - override def afterEach(): Unit = { - if (taskMemoryManager != null) { - val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() - assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) - assert(leakedShuffleMemory === 0) - taskMemoryManager = null - } + private def randomStrings(n: Int): Seq[String] = { + val rand = new Random(42) + Seq.fill(n) { + Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString + }.distinct } - test("supported schemas") { + testWithMemoryLeakDetection("supported schemas") { assert(supportsAggregationBufferSchema( StructType(StructField("x", DecimalType.USER_DEFAULT) :: Nil))) assert(!supportsAggregationBufferSchema( @@ -70,7 +89,7 @@ !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) } - test("empty map") { + testWithMemoryLeakDetection("empty map") { val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -85,7 +104,7 @@ map.free() } - test("updating values for a single key") { + testWithMemoryLeakDetection("updating values for a single key") { val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -113,7 +132,7 @@ map.free() } - test("inserting large random keys") { + testWithMemoryLeakDetection("inserting large random keys") { val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -140,7 +159,21 @@ map.free() } - test("test sorting") { + testWithMemoryLeakDetection("test external sorting") { + // Calling this makes sure we have the block manager and everything else set up. + TestSQLContext + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + metricsSystem = null)) + + // Memory consumption in the beginning of the task.
+ val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, @@ -152,26 +185,47 @@ class UnsafeFixedWidthAggregationMapSuite false // disable perf metrics ) - val rand = new Random(42) - val groupKeys: Set[String] = Seq.fill(512) { - Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString - }.toSet - groupKeys.foreach { keyString => + val keys = randomStrings(1024).take(512) + keys.foreach { keyString => val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) buf.setInt(0, keyString.length) assert(buf != null) } + // Convert the map into a sorter + val sorter = map.destructAndCreateExternalSorter() + + withClue(s"destructAndCreateExternalSorter should release memory used by the map") { + // 4096 * 16 is the initial size allocated for the pointer/prefix array in the in-mem sorter. + assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === + initialMemoryConsumption + 4096 * 16) + } + + // Add more keys to the sorter and make sure the results come out sorted. + val additionalKeys = randomStrings(1024) + val keyConverter = UnsafeProjection.create(groupKeySchema) + val valueConverter = UnsafeProjection.create(aggBufferSchema) + + additionalKeys.zipWithIndex.foreach { case (str, i) => + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + + if ((i % 100) == 0) { + shuffleMemoryManager.markAsOutOfMemory() + sorter.closeCurrentPage() + } + } + val out = new scala.collection.mutable.ArrayBuffer[String] - val iter = map.sortedIterator() + val iter = sorter.sortedIterator() while (iter.next()) { assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) out += iter.getKey.getString(0) } - assert(out === groupKeys.toSeq.sorted) + assert(out === (keys ++ additionalKeys).sorted) map.free() } - } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala new file mode 100644 index 0000000000..5d214d7bfc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + package org.apache.spark.sql.execution + + import scala.util.Random + + import org.apache.spark.sql.catalyst.InternalRow + import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeProjection} + import org.apache.spark.sql.test.TestSQLContext + import org.apache.spark.sql.types._ + import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} + import org.apache.spark.unsafe.types.UTF8String + import org.apache.spark._ + +class UnsafeKVExternalSorterSuite extends SparkFunSuite { + + test("sorting string key and int int value") { + + // Calling this makes sure we have the block manager and everything else set up. + TestSQLContext + + val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + val shuffleMemMgr = new TestShuffleMemoryManager + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemMgr, + metricsSystem = null)) + + val keySchema = new StructType().add("a", StringType) + val valueSchema = new StructType().add("b", IntegerType).add("c", IntegerType) + val sorter = new UnsafeKVExternalSorter( + keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, + 16 * 1024) + + val keyConverter = UnsafeProjection.create(keySchema) + val valueConverter = UnsafeProjection.create(valueSchema) + + val rand = new Random(42) + val data = null +: Seq.fill[String](10) { + Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString + } + + val inputRows = data.map { str => + keyConverter.apply(InternalRow(UTF8String.fromString(str))).copy() + } + + var i = 0 + data.foreach { str => + if (str != null) { + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length, str.length + 1) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + } else { + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(-1, -2) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + } + + if ((i % 100) == 0) { + shuffleMemMgr.markAsOutOfMemory() + sorter.closeCurrentPage() + } + i += 1 + } + + val out = new scala.collection.mutable.ArrayBuffer[InternalRow] + val iter = sorter.sortedIterator() + while (iter.next()) { + if (iter.getKey.getUTF8String(0) == null) { + withClue(s"for null key") { + assert(-1 === iter.getValue.getInt(0)) + assert(-2 === iter.getValue.getInt(1)) + } + } else { + val key = iter.getKey.getString(0) + withClue(s"for key $key") { + assert(key.length === iter.getValue.getInt(0)) + assert(key.length + 1 === iter.getValue.getInt(1)) + } + } + out += iter.getKey.copy() + } + + assert(out === inputRows.sorted(RowOrdering.forSchema(keySchema.map(_.dataType)))) + } + + test("sorting arbitrary string data") { + + // Calling this makes sure we have the block manager and everything else set up.
+ TestSQLContext + + val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + val shuffleMemMgr = new TestShuffleMemoryManager + + TaskContext.setTaskContext(new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemMgr, + metricsSystem = null)) + + val keySchema = new StructType().add("a", StringType) + val valueSchema = new StructType().add("b", IntegerType) + val sorter = new UnsafeKVExternalSorter( + keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, + 16 * 1024) + + val keyConverter = UnsafeProjection.create(keySchema) + val valueConverter = UnsafeProjection.create(valueSchema) + + val rand = new Random(42) + val data = Seq.fill(512) { + Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString + } + + var i = 0 + data.foreach { str => + val k = InternalRow(UTF8String.fromString(str)) + val v = InternalRow(str.length) + sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) + + if ((i % 100) == 0) { + shuffleMemMgr.markAsOutOfMemory() + sorter.closeCurrentPage() + } + i += 1 + } + + val out = new scala.collection.mutable.ArrayBuffer[String] + val iter = sorter.sortedIterator() + while (iter.next()) { + assert(iter.getKey.getString(0).length === iter.getValue.getInt(0)) + out += iter.getKey.getString(0) + } + + assert(out === data.sorted) + } +} |
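Taken together, the hash-to-sort handoff this patch enables looks roughly like the following, distilled from the tests above (the constructor argument list is abbreviated to what this diff shows, and process is a hypothetical consumer of the sorted output):

    // Aggregate hash-style first.
    val map = new UnsafeFixedWidthAggregationMap(
      emptyAggregationBuffer, aggBufferSchema, groupKeySchema,
      taskMemoryManager, shuffleMemoryManager,
      128,            // initial capacity
      pageSizeBytes,
      false)          // disable perf metrics

    // Destructively convert the map into an external sorter: its records are
    // sorted in place and spilled to disk, and the map itself is freed.
    val sorter = map.destructAndCreateExternalSorter()

    // More records can still be inserted, spilling further as needed.
    sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))

    // Iterate in key order; resources are released once the iterator is exhausted.
    val iter = sorter.sortedIterator()
    while (iter.next()) {
      process(iter.getKey, iter.getValue)  // hypothetical consumer
    }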