From ab8ee1a3b93286a62949569615086ef5030e9fae Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Tue, 4 Aug 2015 14:42:11 -0700
Subject: [SPARK-9452] [SQL] Support records larger than page size in
 UnsafeExternalSorter

This patch extends UnsafeExternalSorter to support records larger than the
page size. The basic strategy is the same as in #7762: store large records
in their own overflow pages.

Author: Josh Rosen

Closes #7891 from JoshRosen/large-records-in-sql-sorter and squashes the
following commits:

967580b [Josh Rosen] Merge remote-tracking branch 'origin/master' into large-records-in-sql-sorter
948c344 [Josh Rosen] Add large records tests for KV sorter.
3c17288 [Josh Rosen] Combine memory and disk cleanup into general cleanupResources() method
380f217 [Josh Rosen] Merge remote-tracking branch 'origin/master' into large-records-in-sql-sorter
27eafa0 [Josh Rosen] Fix page size in PackedRecordPointerSuite
a49baef [Josh Rosen] Address initial round of review comments
3edb931 [Josh Rosen] Remove accidentally-committed debug statements.
2b164e2 [Josh Rosen] Support large records in UnsafeExternalSorter.
---
 .../unsafe/sort/UnsafeExternalSorter.java          | 173 ++++++++++++-----
 .../shuffle/unsafe/PackedRecordPointerSuite.java   |   8 +-
 .../unsafe/sort/UnsafeExternalSorterSuite.java     | 129 ++++++++++---
 .../sql/execution/UnsafeExternalRowSorter.java     |   2 +-
 .../sql/execution/UnsafeKVExternalSorter.java      |  11 +-
 .../execution/UnsafeKVExternalSorterSuite.scala    | 210 +++++++++++++--------
 .../spark/unsafe/memory/HeapMemoryAllocator.java   |   3 +
 .../spark/unsafe/memory/UnsafeMemoryAllocator.java |   3 +
 8 files changed, 372 insertions(+), 167 deletions(-)
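A note on the core idea before the diffs: a record too large for a standard page is given its own dedicated, record-sized "overflow" page rather than being rejected. The sketch below is a simplified, self-contained illustration of that allocation decision, not Spark's implementation; the OverflowPageSketch and Page names are hypothetical, and real pages are raw memory blocks rather than byte arrays.

import java.util.ArrayList;
import java.util.List;

class OverflowPageSketch {
  static final int PAGE_SIZE_BYTES = 4096;  // stand-in for the configurable page size

  static class Page {
    final byte[] data;
    int position = 0;
    Page(int size) { data = new byte[size]; }
    int freeSpace() { return data.length - position; }
  }

  private final List<Page> allocatedPages = new ArrayList<>();
  private Page currentPage = null;

  // Decide which page should hold a record of the given size.
  Page pageFor(int requiredSpace) {
    if (requiredSpace > PAGE_SIZE_BYTES) {
      // Too big for a regular page: allocate an overflow page sized to this
      // one record, and leave the current page untouched.
      Page overflowPage = new Page(requiredSpace);
      allocatedPages.add(overflowPage);
      return overflowPage;
    }
    if (currentPage == null || currentPage.freeSpace() < requiredSpace) {
      currentPage = new Page(PAGE_SIZE_BYTES);
      allocatedPages.add(currentPage);
    }
    return currentPage;
  }

  void insert(byte[] record) {
    Page page = pageFor(record.length);
    System.arraycopy(record, 0, page.data, page.position, record.length);
    page.position += record.length;
  }
}

Both insertRecord() and insertKVRecord() in the first diff follow this shape, with the added wrinkle that every page must first be granted by the ShuffleMemoryManager.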
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index dec7fcfa0d..e6ddd08e5f 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -34,6 +34,7 @@ import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.storage.BlockManager;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
@@ -143,8 +144,7 @@ public final class UnsafeExternalSorter {
     taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() {
       @Override
       public BoxedUnit apply() {
-        deleteSpillFiles();
-        freeMemory();
+        cleanupResources();
         return null;
       }
     });
@@ -249,7 +249,7 @@ public final class UnsafeExternalSorter {
    *
    * @return the number of bytes freed.
    */
-  public long freeMemory() {
+  private long freeMemory() {
     updatePeakMemoryUsed();
     long memoryFreed = 0;
     for (MemoryBlock block : allocatedPages) {
@@ -275,44 +275,32 @@ public final class UnsafeExternalSorter {
   /**
    * Deletes any spill files created by this sorter.
    */
-  public void deleteSpillFiles() {
+  private void deleteSpillFiles() {
     for (UnsafeSorterSpillWriter spill : spillWriters) {
       File file = spill.getFile();
       if (file != null && file.exists()) {
         if (!file.delete()) {
           logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
-        };
+        }
       }
     }
   }
 
   /**
-   * Checks whether there is enough space to insert a new record into the sorter.
-   *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
-   *                      the record size.
-
-   * @return true if the record can be inserted without requiring more allocations, false otherwise.
+   * Frees this sorter's in-memory data structures and cleans up its spill files.
    */
-  private boolean haveSpaceForRecord(int requiredSpace) {
-    assert(requiredSpace > 0);
-    assert(inMemSorter != null);
-    return (inMemSorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+  public void cleanupResources() {
+    deleteSpillFiles();
+    freeMemory();
   }
 
   /**
-   * Allocates more memory in order to insert an additional record. This will request additional
-   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
-   * obtained.
-   *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
-   *                      the record size.
+   * Checks whether there is enough space to insert an additional record into the sort pointer
+   * array and grows the array if additional space is required. If the required space cannot be
+   * obtained, then the in-memory data will be spilled to disk.
    */
-  private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+  private void growPointerArrayIfNecessary() throws IOException {
     assert(inMemSorter != null);
-    // TODO: merge these steps to first calculate total memory requirements for this insert,
-    // then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
-    // data page.
     if (!inMemSorter.hasSpaceForAnotherRecord()) {
       logger.debug("Attempting to expand sort pointer array");
       final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
@@ -326,7 +314,20 @@ public final class UnsafeExternalSorter {
         shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
       }
     }
+  }
 
+  /**
+   * Allocates more memory in order to insert an additional record. This will request additional
+   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory cannot be
+   * obtained.
+   *
+   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   *                      the record size. This must be less than or equal to the page size (records
+   *                      that exceed the page size are handled via a different code path which uses
+   *                      special overflow pages).
+   */
+  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
+    assert (requiredSpace <= pageSizeBytes);
     if (requiredSpace > freeSpaceInCurrentPage) {
       logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
         freeSpaceInCurrentPage);
@@ -339,9 +340,7 @@ public final class UnsafeExternalSorter {
     } else {
       final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
       if (memoryAcquired < pageSizeBytes) {
-        if (memoryAcquired > 0) {
-          shuffleMemoryManager.release(memoryAcquired);
-        }
+        shuffleMemoryManager.release(memoryAcquired);
         spill();
         final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
         if (memoryAcquiredAfterSpilling != pageSizeBytes) {
@@ -365,26 +364,59 @@ public final class UnsafeExternalSorter {
       long recordBaseOffset,
       int lengthInBytes,
       long prefix) throws IOException {
+
+    growPointerArrayIfNecessary();
     // Need 4 bytes to store the record length.
     final int totalSpaceRequired = lengthInBytes + 4;
-    if (!haveSpaceForRecord(totalSpaceRequired)) {
-      allocateSpaceForRecord(totalSpaceRequired);
+
+    // --- Figure out where to insert the new record ----------------------------------------------
+
+    final MemoryBlock dataPage;
+    long dataPagePosition;
+    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
+    if (useOverflowPage) {
+      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
+      // The record is larger than the page size, so allocate a special overflow page just to hold
+      // that record.
+      final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+      if (memoryGranted != overflowPageSize) {
+        shuffleMemoryManager.release(memoryGranted);
+        spill();
+        final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+        if (memoryGrantedAfterSpill != overflowPageSize) {
+          shuffleMemoryManager.release(memoryGrantedAfterSpill);
+          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
+        }
+      }
+      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+      allocatedPages.add(overflowPage);
+      dataPage = overflowPage;
+      dataPagePosition = overflowPage.getBaseOffset();
+    } else {
+      // The record is small enough to fit in a regular data page, but the current page might not
+      // have enough space to hold it (or no pages have been allocated yet).
+      acquireNewPageIfNecessary(totalSpaceRequired);
+      dataPage = currentPage;
+      dataPagePosition = currentPagePosition;
+      // Update bookkeeping information
+      freeSpaceInCurrentPage -= totalSpaceRequired;
+      currentPagePosition += totalSpaceRequired;
     }
-    assert(inMemSorter != null);
+    final Object dataPageBaseObject = dataPage.getBaseObject();
+
+    // --- Insert the record ----------------------------------------------------------------------
 
     final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
-    final Object dataPageBaseObject = currentPage.getBaseObject();
-    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
-    currentPagePosition += 4;
+      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
+    dataPagePosition += 4;
     PlatformDependent.copyMemory(
       recordBaseObject,
       recordBaseOffset,
       dataPageBaseObject,
-      currentPagePosition,
+      dataPagePosition,
       lengthInBytes);
-    currentPagePosition += lengthInBytes;
-    freeSpaceInCurrentPage -= totalSpaceRequired;
+    assert(inMemSorter != null);
     inMemSorter.insertRecord(recordAddress, prefix);
   }
 
@@ -399,33 +431,70 @@ public final class UnsafeExternalSorter {
   public void insertKVRecord(
       Object keyBaseObj, long keyOffset, int keyLen,
       Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException {
+
+    growPointerArrayIfNecessary();
     final int totalSpaceRequired = keyLen + valueLen + 4 + 4;
-    if (!haveSpaceForRecord(totalSpaceRequired)) {
-      allocateSpaceForRecord(totalSpaceRequired);
+
+    // --- Figure out where to insert the new record ----------------------------------------------
+
+    final MemoryBlock dataPage;
+    long dataPagePosition;
+    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
+    if (useOverflowPage) {
+      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
+      // The record is larger than the page size, so allocate a special overflow page just to hold
+      // that record.
+      final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+      if (memoryGranted != overflowPageSize) {
+        shuffleMemoryManager.release(memoryGranted);
+        spill();
+        final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+        if (memoryGrantedAfterSpill != overflowPageSize) {
+          shuffleMemoryManager.release(memoryGrantedAfterSpill);
+          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
+        }
+      }
+      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+      allocatedPages.add(overflowPage);
+      dataPage = overflowPage;
+      dataPagePosition = overflowPage.getBaseOffset();
+    } else {
+      // The record is small enough to fit in a regular data page, but the current page might not
+      // have enough space to hold it (or no pages have been allocated yet).
+      acquireNewPageIfNecessary(totalSpaceRequired);
+      dataPage = currentPage;
+      dataPagePosition = currentPagePosition;
+      // Update bookkeeping information
+      freeSpaceInCurrentPage -= totalSpaceRequired;
+      currentPagePosition += totalSpaceRequired;
     }
-    assert(inMemSorter != null);
+    final Object dataPageBaseObject = dataPage.getBaseObject();
+
+    // --- Insert the record ----------------------------------------------------------------------
 
     final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
-    final Object dataPageBaseObject = currentPage.getBaseObject();
-    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen + valueLen + 4);
-    currentPagePosition += 4;
+      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4);
+    dataPagePosition += 4;
 
-    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, keyLen);
-    currentPagePosition += 4;
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, dataPagePosition, keyLen);
+    dataPagePosition += 4;
 
     PlatformDependent.copyMemory(
-      keyBaseObj, keyOffset, dataPageBaseObject, currentPagePosition, keyLen);
-    currentPagePosition += keyLen;
+      keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen);
+    dataPagePosition += keyLen;
 
     PlatformDependent.copyMemory(
-      valueBaseObj, valueOffset, dataPageBaseObject, currentPagePosition, valueLen);
-    currentPagePosition += valueLen;
+      valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen);
 
-    freeSpaceInCurrentPage -= totalSpaceRequired;
+    assert(inMemSorter != null);
    inMemSorter.insertRecord(recordAddress, prefix);
   }
 
+  /**
+   * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
+   * after consuming this iterator.
+   */
   public UnsafeSorterIterator getSortedIterator() throws IOException {
     assert(inMemSorter != null);
     final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
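Both insert paths above and acquireNewPageIfNecessary() negotiate with the ShuffleMemoryManager through the same acquire/spill/retry sequence before allocating a page. A distilled sketch of that protocol, using a hypothetical trimmed-down manager interface in place of the real class:

import java.io.IOException;

interface MemoryManagerSketch {
  long tryToAcquire(long bytes);  // may grant less than requested
  void release(long bytes);
}

abstract class AcquireOrSpillSketch {
  private final MemoryManagerSketch memoryManager;

  AcquireOrSpillSketch(MemoryManagerSketch memoryManager) {
    this.memoryManager = memoryManager;
  }

  // Frees memory by writing the in-memory records out to disk.
  abstract void spill() throws IOException;

  // Acquire exactly `size` bytes, spilling once if the first attempt falls short.
  void acquireExactly(long size) throws IOException {
    long granted = memoryManager.tryToAcquire(size);
    if (granted != size) {
      // Return the partial grant before spilling so the accounting stays exact.
      memoryManager.release(granted);
      spill();
      long grantedAfterSpill = memoryManager.tryToAcquire(size);
      if (grantedAfterSpill != size) {
        memoryManager.release(grantedAfterSpill);
        throw new IOException("Unable to acquire " + size + " bytes of memory");
      }
    }
  }
}

Note that the diff above also drops the old `if (memoryAcquired > 0)` guard around release(), which suggests release() is expected to tolerate a zero-byte argument.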
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
index db9e827590..934b7e0305 100644
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
@@ -32,8 +32,8 @@ public class PackedRecordPointerSuite {
   public void heap() {
     final TaskMemoryManager memoryManager =
       new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-    final MemoryBlock page0 = memoryManager.allocatePage(100);
-    final MemoryBlock page1 = memoryManager.allocatePage(100);
+    final MemoryBlock page0 = memoryManager.allocatePage(128);
+    final MemoryBlock page1 = memoryManager.allocatePage(128);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();
@@ -50,8 +50,8 @@ public class PackedRecordPointerSuite {
   public void offHeap() {
     final TaskMemoryManager memoryManager =
       new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
-    final MemoryBlock page0 = memoryManager.allocatePage(100);
-    final MemoryBlock page1 = memoryManager.allocatePage(100);
+    final MemoryBlock page0 = memoryManager.allocatePage(128);
+    final MemoryBlock page1 = memoryManager.allocatePage(128);
     final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
       page1.getBaseOffset() + 42);
     PackedRecordPointer packedPointer = new PackedRecordPointer();
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index c11949d57a..968185bde7 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -18,8 +18,10 @@
 package org.apache.spark.util.collection.unsafe.sort;
 
 import java.io.File;
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.UUID;
 
@@ -34,6 +36,7 @@ import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.junit.Assert.*;
 import static org.mockito.AdditionalAnswers.returnsSecondArg;
 import static org.mockito.Answers.RETURNS_SMART_NULLS;
@@ -77,12 +80,13 @@ public class UnsafeExternalSorterSuite {
     }
   };
 
+  SparkConf sparkConf;
+  File tempDir;
   ShuffleMemoryManager shuffleMemoryManager;
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
-  File tempDir;
 
   private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m");
 
@@ -96,6 +100,7 @@ public class UnsafeExternalSorterSuite {
   @Before
   public void setUp() {
     MockitoAnnotations.initMocks(this);
+    sparkConf = new SparkConf();
     tempDir = new File(Utils.createTempDir$default$1());
     shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE);
     spillFilesCreated.clear();
@@ -155,14 +160,19 @@ public class UnsafeExternalSorterSuite {
   }
 
   private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
-    final int[] arr = new int[] { value };
+    final int[] arr = new int[]{ value };
     sorter.insertRecord(arr, PlatformDependent.INT_ARRAY_OFFSET, 4, value);
   }
 
-  @Test
-  public void testSortingOnlyByPrefix() throws Exception {
+  private static void insertRecord(
+      UnsafeExternalSorter sorter,
+      int[] record,
+      long prefix) throws IOException {
+    sorter.insertRecord(record, PlatformDependent.INT_ARRAY_OFFSET, record.length * 4, prefix);
+  }
 
-    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
+  private UnsafeExternalSorter newSorter() throws IOException {
+    return UnsafeExternalSorter.create(
       taskMemoryManager,
       shuffleMemoryManager,
       blockManager,
@@ -171,7 +181,11 @@ public class UnsafeExternalSorterSuite {
       prefixComparator,
       /* initialSize */ 1024,
       pageSizeBytes);
+  }
 
+  @Test
+  public void testSortingOnlyByPrefix() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
     insertNumber(sorter, 5);
     insertNumber(sorter, 1);
     insertNumber(sorter, 3);
@@ -186,26 +200,16 @@ public class UnsafeExternalSorterSuite {
       iter.loadNext();
       assertEquals(i, iter.getKeyPrefix());
       assertEquals(4, iter.getRecordLength());
-      // TODO: read rest of value.
+      assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
     }
 
-    sorter.freeMemory();
+    sorter.cleanupResources();
    assertSpillFilesWereCleanedUp();
   }
 
   @Test
   public void testSortingEmptyArrays() throws Exception {
-
-    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
-      taskMemoryManager,
-      shuffleMemoryManager,
-      blockManager,
-      taskContext,
-      recordComparator,
-      prefixComparator,
-      /* initialSize */ 1024,
-      pageSizeBytes);
-
+    final UnsafeExternalSorter sorter = newSorter();
     sorter.insertRecord(null, 0, 0, 0);
     sorter.insertRecord(null, 0, 0, 0);
     sorter.spill();
@@ -222,28 +226,89 @@ public class UnsafeExternalSorterSuite {
       assertEquals(0, iter.getRecordLength());
     }
 
-    sorter.freeMemory();
+    sorter.cleanupResources();
     assertSpillFilesWereCleanedUp();
   }
 
   @Test
-  public void testFillingPage() throws Exception {
+  public void spillingOccursInResponseToMemoryPressure() throws Exception {
+    shuffleMemoryManager = new ShuffleMemoryManager(pageSizeBytes * 2);
+    final UnsafeExternalSorter sorter = newSorter();
+    final int numRecords = 100000;
+    for (int i = 0; i <= numRecords; i++) {
+      insertNumber(sorter, numRecords - i);
+    }
+    // Ensure that spill files were created
+    assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1));
+    // Read back the sorted data:
+    UnsafeSorterIterator iter = sorter.getSortedIterator();
 
-    final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
-      taskMemoryManager,
-      shuffleMemoryManager,
-      blockManager,
-      taskContext,
-      recordComparator,
-      prefixComparator,
-      /* initialSize */ 1024,
-      pageSizeBytes);
+    int i = 0;
+    while (iter.hasNext()) {
+      iter.loadNext();
+      assertEquals(i, iter.getKeyPrefix());
+      assertEquals(4, iter.getRecordLength());
+      assertEquals(i, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
+      i++;
+    }
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
 
+  @Test
+  public void testFillingPage() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
     byte[] record = new byte[16];
     while (sorter.getNumberOfAllocatedPages() < 2) {
       sorter.insertRecord(record, PlatformDependent.BYTE_ARRAY_OFFSET, record.length, 0);
     }
-    sorter.freeMemory();
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void sortingRecordsThatExceedPageSize() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
+    final int[] largeRecord = new int[(int) pageSizeBytes + 16];
+    Arrays.fill(largeRecord, 456);
+    final int[] smallRecord = new int[100];
+    Arrays.fill(smallRecord, 123);
+
+    insertRecord(sorter, largeRecord, 456);
+    sorter.spill();
+    insertRecord(sorter, smallRecord, 123);
+    sorter.spill();
+    insertRecord(sorter, smallRecord, 123);
+    insertRecord(sorter, largeRecord, 456);
+
+    UnsafeSorterIterator iter = sorter.getSortedIterator();
+    // Small record
+    assertTrue(iter.hasNext());
+    iter.loadNext();
+    assertEquals(123, iter.getKeyPrefix());
+    assertEquals(smallRecord.length * 4, iter.getRecordLength());
+    assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
+    // Small record
+    assertTrue(iter.hasNext());
+    iter.loadNext();
+    assertEquals(123, iter.getKeyPrefix());
+    assertEquals(smallRecord.length * 4, iter.getRecordLength());
+    assertEquals(123, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
+    // Large record
+    assertTrue(iter.hasNext());
+    iter.loadNext();
+    assertEquals(456, iter.getKeyPrefix());
+    assertEquals(largeRecord.length * 4, iter.getRecordLength());
+    assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
+    // Large record
+    assertTrue(iter.hasNext());
+    iter.loadNext();
+    assertEquals(456, iter.getKeyPrefix());
+    assertEquals(largeRecord.length * 4, iter.getRecordLength());
+    assertEquals(456, PlatformDependent.UNSAFE.getInt(iter.getBaseObject(), iter.getBaseOffset()));
+
+    assertFalse(iter.hasNext());
+    sorter.cleanupResources();
     assertSpillFilesWereCleanedUp();
   }
 
@@ -289,8 +354,10 @@ public class UnsafeExternalSorterSuite {
       newPeakMemory = sorter.getPeakMemoryUsedBytes();
       assertEquals(previousPeakMemory, newPeakMemory);
     } finally {
-      sorter.freeMemory();
+      sorter.cleanupResources();
+      assertSpillFilesWereCleanedUp();
     }
   }
 }
+
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 193906d247..a5ae2b9736 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -114,7 +114,7 @@ final class UnsafeExternalRowSorter {
   }
 
   private void cleanupResources() {
-    sorter.freeMemory();
+    sorter.cleanupResources();
   }
 
   @VisibleForTesting
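The Javadoc added to getSortedIterator() above (and to sortedIterator() in the next diff) makes cleanup the caller's responsibility. A minimal sketch of the intended calling convention, assuming an already-populated sorter:

import java.io.IOException;

import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;

class SortedIteratorUsage {
  // Consume the sorted output, then release pages and spill files exactly once.
  static long countRecords(UnsafeExternalSorter sorter) throws IOException {
    long numRecords = 0;
    try {
      UnsafeSorterIterator iter = sorter.getSortedIterator();
      while (iter.hasNext()) {
        iter.loadNext();
        // The record's bytes start at (iter.getBaseObject(), iter.getBaseOffset())
        // and run for iter.getRecordLength() bytes.
        numRecords++;
      }
    } finally {
      // Caller-side cleanup per the new Javadoc: frees pages and deletes spill files.
      sorter.cleanupResources();
    }
    return numRecords;
  }
}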
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 312ec8ea0d..86a563df99 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -134,6 +134,10 @@ public final class UnsafeKVExternalSorter {
       value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
   }
 
+  /**
+   * Returns a sorted iterator. It is the caller's responsibility to call `cleanupResources()`
+   * after consuming this iterator.
+   */
   public KVSorterIterator sortedIterator() throws IOException {
     try {
       final UnsafeSorterIterator underlying = sorter.getSortedIterator();
@@ -158,8 +162,11 @@ public final class UnsafeKVExternalSorter {
     sorter.closeCurrentPage();
   }
 
-  private void cleanupResources() {
-    sorter.freeMemory();
+  /**
+   * Frees this sorter's in-memory data structures and cleans up its spill files.
+   */
+  public void cleanupResources() {
+    sorter.cleanupResources();
   }
 
   private static final class KVComparator extends RecordComparator {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 601a5a07ad..08156f0e39 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import scala.util.Random
 
 import org.apache.spark._
-import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection}
 import org.apache.spark.sql.test.TestSQLContext
@@ -46,6 +46,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
     testKVSorter(keySchema, valueSchema, spill = i > 3)
   }
 
+
   /**
    * Create a test case using randomly generated data for the given key and value schema.
    *
@@ -60,96 +61,151 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
    * If spill is set to true, the sorter will spill probabilistically roughly every 100 records.
    */
   private def testKVSorter(keySchema: StructType, valueSchema: StructType, spill: Boolean): Unit = {
+    // Create the data converters
+    val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
+    val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema)
+    val kConverter = UnsafeProjection.create(keySchema)
+    val vConverter = UnsafeProjection.create(valueSchema)
+
+    val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get
+    val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get
+
+    val inputData = Seq.fill(1024) {
+      val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow])
+      val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow])
+      (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
+    }
     val keySchemaStr = keySchema.map(_.dataType.simpleString).mkString("[", ",", "]")
     val valueSchemaStr = valueSchema.map(_.dataType.simpleString).mkString("[", ",", "]")
 
     test(s"kv sorting key schema $keySchemaStr and value schema $valueSchemaStr") {
-      // Calling this make sure we have block manager and everything else setup.
-      TestSQLContext
-
-      val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
-      val shuffleMemMgr = new TestShuffleMemoryManager
-      TaskContext.setTaskContext(new TaskContextImpl(
-        stageId = 0,
-        partitionId = 0,
-        taskAttemptId = 98456,
-        attemptNumber = 0,
-        taskMemoryManager = taskMemMgr,
-        metricsSystem = null,
-        internalAccumulators = Seq.empty))
-
-      // Create the data converters
-      val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
-      val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema)
-      val kConverter = UnsafeProjection.create(keySchema)
-      val vConverter = UnsafeProjection.create(valueSchema)
-
-      val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get
-      val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get
-
-      val input = Seq.fill(1024) {
-        val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow])
-        val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow])
-        (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
-      }
-
-      val sorter = new UnsafeKVExternalSorter(
-        keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, 16 * 1024 * 1024)
-
-      // Insert generated keys and values into the sorter
-      input.foreach { case (k, v) =>
-        sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
-        // 1% chance we will spill
-        if (rand.nextDouble() < 0.01 && spill) {
-          shuffleMemMgr.markAsOutOfMemory()
-          sorter.closeCurrentPage()
-        }
-      }
+      testKVSorter(
+        keySchema,
+        valueSchema,
+        inputData,
+        pageSize = 16 * 1024 * 1024,
+        spill
+      )
+    }
+  }
 
-      // Collect the sorted output
-      val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)]
-      val iter = sorter.sortedIterator()
-      while (iter.next()) {
-        out += Tuple2(iter.getKey.copy(), iter.getValue.copy())
+  /**
+   * Create a test case using the given input data for the given key and value schema.
+   *
+   * The approach works as follows:
+   *
+   * - Create input by randomly generating data based on the given schema
+   * - Run [[UnsafeKVExternalSorter]] on the input data
+   * - Collect the output from the sorter, and make sure the keys are sorted in ascending order
+   * - Sort the input by both key and value, and sort the sorter output also by both key and value.
+   *   Compare the sorted input and sorted output together to make sure all the key/values match.
+   *
+   * If spill is set to true, the sorter will spill probabilistically roughly every 100 records.
+   */
+  private def testKVSorter(
+      keySchema: StructType,
+      valueSchema: StructType,
+      inputData: Seq[(InternalRow, InternalRow)],
+      pageSize: Long,
+      spill: Boolean): Unit = {
+    // Calling this makes sure we have block manager and everything else setup.
+    TestSQLContext
+
+    val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+    val shuffleMemMgr = new TestShuffleMemoryManager
+    TaskContext.setTaskContext(new TaskContextImpl(
+      stageId = 0,
+      partitionId = 0,
+      taskAttemptId = 98456,
+      attemptNumber = 0,
+      taskMemoryManager = taskMemMgr,
+      metricsSystem = null,
+      internalAccumulators = Seq.empty))
+
+    val sorter = new UnsafeKVExternalSorter(
+      keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, pageSize)
+
+    // Insert the keys and values into the sorter
+    inputData.foreach { case (k, v) =>
+      sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
+      // 1% chance we will spill
+      if (rand.nextDouble() < 0.01 && spill) {
+        shuffleMemMgr.markAsOutOfMemory()
+        sorter.closeCurrentPage()
       }
     }
 
-      val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType))
-      val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType))
-      val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
-        override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
-          keyOrdering.compare(x._1, y._1) match {
-            case 0 => valueOrdering.compare(x._2, y._2)
-            case cmp => cmp
+    // Collect the sorted output
+    val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)]
+    val iter = sorter.sortedIterator()
+    while (iter.next()) {
+      out += Tuple2(iter.getKey.copy(), iter.getValue.copy())
+    }
+    sorter.cleanupResources()
+
+    val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType))
+    val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType))
+    val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
+      override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
+        keyOrdering.compare(x._1, y._1) match {
+          case 0 => valueOrdering.compare(x._2, y._2)
+          case cmp => cmp
         }
       }
     }
 
-      // Testing to make sure output from the sorter is sorted by key
-      var prevK: InternalRow = null
-      out.zipWithIndex.foreach { case ((k, v), i) =>
-        if (prevK != null) {
-          assert(keyOrdering.compare(prevK, k) <= 0,
-            s"""
-              |key is not in sorted order:
-              |previous key: $prevK
-              |current key : $k
-            """.stripMargin)
-        }
-        prevK = k
+    // Testing to make sure output from the sorter is sorted by key
+    var prevK: InternalRow = null
+    out.zipWithIndex.foreach { case ((k, v), i) =>
+      if (prevK != null) {
+        assert(keyOrdering.compare(prevK, k) <= 0,
+          s"""
+            |key is not in sorted order:
+            |previous key: $prevK
+            |current key : $k
+          """.stripMargin)
       }
+      prevK = k
+    }
 
-      // Testing to make sure the key/value in output matches input
-      assert(out.sorted(kvOrdering) === input.sorted(kvOrdering))
+    // Testing to make sure the key/value in output matches input
+    assert(out.sorted(kvOrdering) === inputData.sorted(kvOrdering))
 
-      // Make sure there is no memory leak
-      val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory
-      if (shuffleMemMgr != null) {
-        val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask()
-        assert(0L === leakedShuffleMemory)
-      }
-      assert(0 === leakedUnsafeMemory)
-      TaskContext.unset()
+    // Make sure there is no memory leak
+    val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory
+    if (shuffleMemMgr != null) {
+      val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask()
+      assert(0L === leakedShuffleMemory)
     }
+    assert(0 === leakedUnsafeMemory)
+    TaskContext.unset()
+  }
+
+  test("kv sorting with records that exceed page size") {
+    val pageSize = 128
+
+    val schema = StructType(StructField("b", BinaryType) :: Nil)
+    val externalConverter = CatalystTypeConverters.createToCatalystConverter(schema)
+    val converter = UnsafeProjection.create(schema)
+
+    val rand = new Random()
+    val inputData = Seq.fill(1024) {
+      val kBytes = new Array[Byte](rand.nextInt(pageSize))
+      val vBytes = new Array[Byte](rand.nextInt(pageSize))
+      rand.nextBytes(kBytes)
+      rand.nextBytes(vBytes)
+      val k = converter(externalConverter.apply(Row(kBytes)).asInstanceOf[InternalRow])
+      val v = converter(externalConverter.apply(Row(vBytes)).asInstanceOf[InternalRow])
+      (k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
+    }
+
+    testKVSorter(
+      schema,
+      schema,
+      inputData,
+      pageSize,
+      spill = true
+    )
+  }
 }
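The two allocator changes below tie together the alignment details seen earlier: overflow-page sizes are rounded up with ByteArrayMethods.roundNumberOfBytesToNearestWord() before allocation, so they always satisfy the new multiple-of-8 check, and PackedRecordPointerSuite's test pages grew from 100 to 128 bytes, apparently because 100 would now fail that check. A sketch of equivalent rounding logic (the class and method names here are hypothetical):

class WordAlignmentSketch {
  // Round a byte count up to the next multiple of 8 (the word size).
  static long roundToWord(long numBytes) {
    long remainder = numBytes & 0x07;  // same as numBytes % 8 for non-negative sizes
    return numBytes + ((8 - remainder) & 0x07);
  }

  public static void main(String[] args) {
    System.out.println(roundToWord(100));  // 104; 100 itself is not a multiple of 8
    System.out.println(roundToWord(104));  // 104; already-aligned sizes are unchanged
  }
}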
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index bbe83d36cf..6722301df1 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -24,6 +24,9 @@ public class HeapMemoryAllocator implements MemoryAllocator {
 
   @Override
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
+    if (size % 8 != 0) {
+      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
+    }
     long[] array = new long[(int) (size / 8)];
     return MemoryBlock.fromLongArray(array);
   }
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
index 15898771fe..62f4459696 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -26,6 +26,9 @@ public class UnsafeMemoryAllocator implements MemoryAllocator {
 
   @Override
   public MemoryBlock allocate(long size) throws OutOfMemoryError {
+    if (size % 8 != 0) {
+      throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
+    }
     long address = PlatformDependent.UNSAFE.allocateMemory(size);
     return new MemoryBlock(null, address, size);
   }
-- 
cgit v1.2.3