diff options
Diffstat (limited to 'core/src/main')
12 files changed, 877 insertions, 592 deletions
diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java new file mode 100644 index 0000000000..008799cc77 --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + + +import java.io.IOException; + +import org.apache.spark.unsafe.memory.MemoryBlock; + + +/** + * An memory consumer of TaskMemoryManager, which support spilling. + */ +public abstract class MemoryConsumer { + + private final TaskMemoryManager taskMemoryManager; + private final long pageSize; + private long used; + + protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) { + this.taskMemoryManager = taskMemoryManager; + this.pageSize = pageSize; + this.used = 0; + } + + protected MemoryConsumer(TaskMemoryManager taskMemoryManager) { + this(taskMemoryManager, taskMemoryManager.pageSizeBytes()); + } + + /** + * Returns the size of used memory in bytes. + */ + long getUsed() { + return used; + } + + /** + * Force spill during building. + * + * For testing. + */ + public void spill() throws IOException { + spill(Long.MAX_VALUE, this); + } + + /** + * Spill some data to disk to release memory, which will be called by TaskMemoryManager + * when there is not enough memory for the task. + * + * This should be implemented by subclass. + * + * Note: In order to avoid possible deadlock, should not call acquireMemory() from spill(). + * + * @param size the amount of memory should be released + * @param trigger the MemoryConsumer that trigger this spilling + * @return the amount of released memory in bytes + * @throws IOException + */ + public abstract long spill(long size, MemoryConsumer trigger) throws IOException; + + /** + * Acquire `size` bytes memory. + * + * If there is not enough memory, throws OutOfMemoryError. + */ + protected void acquireMemory(long size) { + long got = taskMemoryManager.acquireExecutionMemory(size, this); + if (got < size) { + taskMemoryManager.releaseExecutionMemory(got, this); + taskMemoryManager.showMemoryUsage(); + throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got); + } + used += got; + } + + /** + * Release `size` bytes memory. + */ + protected void releaseMemory(long size) { + used -= size; + taskMemoryManager.releaseExecutionMemory(size, this); + } + + /** + * Allocate a memory block with at least `required` bytes. + * + * Throws IOException if there is not enough memory. + * + * @throws OutOfMemoryError + */ + protected MemoryBlock allocatePage(long required) { + MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); + if (page == null || page.size() < required) { + long got = 0; + if (page != null) { + got = page.size(); + freePage(page); + } + taskMemoryManager.showMemoryUsage(); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); + } + used += page.size(); + return page; + } + + /** + * Free a memory block. + */ + protected void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } +} diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 7b31c90dac..4230575446 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -17,13 +17,18 @@ package org.apache.spark.memory; -import java.util.*; +import javax.annotation.concurrent.GuardedBy; +import java.io.IOException; +import java.util.Arrays; +import java.util.BitSet; +import java.util.HashSet; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.util.Utils; /** * Manages the memory allocated by an individual task. @@ -101,29 +106,104 @@ public class TaskMemoryManager { private final boolean inHeap; /** + * The size of memory granted to each consumer. + */ + @GuardedBy("this") + private final HashSet<MemoryConsumer> consumers; + + /** * Construct a new TaskMemoryManager. */ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap(); this.memoryManager = memoryManager; this.taskAttemptId = taskAttemptId; + this.consumers = new HashSet<>(); } /** - * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * Acquire N bytes of memory for a consumer. If there is no enough memory, it will call + * spill() of consumers to release more memory. + * * @return number of bytes successfully granted (<= N). */ - public long acquireExecutionMemory(long size) { - return memoryManager.acquireExecutionMemory(size, taskAttemptId); + public long acquireExecutionMemory(long required, MemoryConsumer consumer) { + assert(required >= 0); + synchronized (this) { + long got = memoryManager.acquireExecutionMemory(required, taskAttemptId); + + // try to release memory from other consumers first, then we can reduce the frequency of + // spilling, avoid to have too many spilled files. + if (got < required) { + // Call spill() on other consumers to release memory + for (MemoryConsumer c: consumers) { + if (c != null && c != consumer && c.getUsed() > 0) { + try { + long released = c.spill(required - got, consumer); + if (released > 0) { + logger.info("Task {} released {} from {} for {}", taskAttemptId, + Utils.bytesToString(released), c, consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId); + if (got >= required) { + break; + } + } + } catch (IOException e) { + logger.error("error while calling spill() on " + c, e); + throw new OutOfMemoryError("error while calling spill() on " + c + " : " + + e.getMessage()); + } + } + } + } + + // call spill() on itself + if (got < required && consumer != null) { + try { + long released = consumer.spill(required - got, consumer); + if (released > 0) { + logger.info("Task {} released {} from itself ({})", taskAttemptId, + Utils.bytesToString(released), consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId); + } + } catch (IOException e) { + logger.error("error while calling spill() on " + consumer, e); + throw new OutOfMemoryError("error while calling spill() on " + consumer + " : " + + e.getMessage()); + } + } + + consumers.add(consumer); + logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer); + return got; + } } /** - * Release N bytes of execution memory. + * Release N bytes of execution memory for a MemoryConsumer. */ - public void releaseExecutionMemory(long size) { + public void releaseExecutionMemory(long size, MemoryConsumer consumer) { + logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer); memoryManager.releaseExecutionMemory(size, taskAttemptId); } + /** + * Dump the memory usage of all consumers. + */ + public void showMemoryUsage() { + logger.info("Memory used in task " + taskAttemptId); + synchronized (this) { + for (MemoryConsumer c: consumers) { + if (c.getUsed() > 0) { + logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed())); + } + } + } + } + + /** + * Return the page size in bytes. + */ public long pageSizeBytes() { return memoryManager.pageSizeBytes(); } @@ -134,42 +214,40 @@ public class TaskMemoryManager { * * Returns `null` if there was not enough memory to allocate the page. */ - public MemoryBlock allocatePage(long size) { + public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { if (size > MAXIMUM_PAGE_SIZE_BYTES) { throw new IllegalArgumentException( "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } + long acquired = acquireExecutionMemory(size, consumer); + if (acquired <= 0) { + return null; + } + final int pageNumber; synchronized (this) { pageNumber = allocatedPages.nextClearBit(0); if (pageNumber >= PAGE_TABLE_SIZE) { + releaseExecutionMemory(acquired, consumer); throw new IllegalStateException( "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); } allocatedPages.set(pageNumber); } - final long acquiredExecutionMemory = acquireExecutionMemory(size); - if (acquiredExecutionMemory != size) { - releaseExecutionMemory(acquiredExecutionMemory); - synchronized (this) { - allocatedPages.clear(pageNumber); - } - return null; - } - final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size); + final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired); page.pageNumber = pageNumber; pageTable[pageNumber] = page; if (logger.isTraceEnabled()) { - logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); + logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); } return page; } /** - * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. + * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}. */ - public void freePage(MemoryBlock page) { + public void freePage(MemoryBlock page, MemoryConsumer consumer) { assert (page.pageNumber != -1) : "Called freePage() on memory that wasn't allocated with allocatePage()"; assert(allocatedPages.get(page.pageNumber)); @@ -182,14 +260,14 @@ public class TaskMemoryManager { } long pageSize = page.size(); memoryManager.tungstenMemoryAllocator().free(page); - releaseExecutionMemory(pageSize); + releaseExecutionMemory(pageSize, consumer); } /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. * - * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/ + * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}/ * @param offsetInPage an offset in this page which incorporates the base offset. In other words, * this should be the value that you would pass as the base offset into an * UNSAFE call (e.g. page.baseOffset() + something). @@ -261,17 +339,17 @@ public class TaskMemoryManager { * value can be used to detect memory leaks. */ public long cleanUpAllAllocatedMemory() { - long freedBytes = 0; - for (MemoryBlock page : pageTable) { - if (page != null) { - freedBytes += page.size(); - freePage(page); + synchronized (this) { + Arrays.fill(pageTable, null); + for (MemoryConsumer c: consumers) { + if (c != null && c.getUsed() > 0) { + // In case of failed task, it's normal to see leaked memory + logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c); + } } + consumers.clear(); } - - freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); - - return freedBytes; + return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); } /** diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index f43236f41a..400d852001 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -31,15 +31,15 @@ import org.slf4j.LoggerFactory; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; /** @@ -58,23 +58,18 @@ import org.apache.spark.util.Utils; * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a * specialized merge procedure that avoids extra serialization/deserialization. */ -final class ShuffleExternalSorter { +final class ShuffleExternalSorter extends MemoryConsumer { private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class); @VisibleForTesting static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; - private final int initialSize; private final int numPartitions; - private final int pageSizeBytes; - @VisibleForTesting - final int maxRecordSizeBytes; private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; private final ShuffleWriteMetrics writeMetrics; - private long numRecordsInsertedSinceLastSpill = 0; /** Force this sorter to spill when there are this many elements in memory. For testing only */ private final long numElementsForSpillThreshold; @@ -98,8 +93,7 @@ final class ShuffleExternalSorter { // These variables are reset after spilling: @Nullable private ShuffleInMemorySorter inMemSorter; @Nullable private MemoryBlock currentPage = null; - private long currentPagePosition = -1; - private long freeSpaceInCurrentPage = 0; + private long pageCursor = -1; public ShuffleExternalSorter( TaskMemoryManager memoryManager, @@ -108,42 +102,21 @@ final class ShuffleExternalSorter { int initialSize, int numPartitions, SparkConf conf, - ShuffleWriteMetrics writeMetrics) throws IOException { + ShuffleWriteMetrics writeMetrics) { + super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, + memoryManager.pageSizeBytes())); this.taskMemoryManager = memoryManager; this.blockManager = blockManager; this.taskContext = taskContext; - this.initialSize = initialSize; - this.peakMemoryUsedBytes = initialSize; this.numPartitions = numPartitions; // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); - this.pageSizeBytes = (int) Math.min( - PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes()); - this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; - initializeForWriting(); - - // preserve first page to ensure that we have at least one page to work with. Otherwise, - // other operators in the same task may starve this sorter (SPARK-9709). - acquireNewPageIfNecessary(pageSizeBytes); - } - - /** - * Allocates new sort data structures. Called when creating the sorter and after each spill. - */ - private void initializeForWriting() throws IOException { - // TODO: move this sizing calculation logic into a static method of sorter: - final long memoryRequested = initialSize * 8L; - final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested); - if (memoryAcquired != memoryRequested) { - taskMemoryManager.releaseExecutionMemory(memoryAcquired); - throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); - } - + acquireMemory(initialSize * 8L); this.inMemSorter = new ShuffleInMemorySorter(initialSize); - numRecordsInsertedSinceLastSpill = 0; + this.peakMemoryUsedBytes = getMemoryUsage(); } /** @@ -242,6 +215,8 @@ final class ShuffleExternalSorter { } } + inMemSorter.reset(); + if (!isLastFile) { // i.e. this is a spill file // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter @@ -266,9 +241,12 @@ final class ShuffleExternalSorter { /** * Sort and spill the current records in response to memory pressure. */ - @VisibleForTesting - void spill() throws IOException { - assert(inMemSorter != null); + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) { + return 0L; + } + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), @@ -276,13 +254,9 @@ final class ShuffleExternalSorter { spills.size() > 1 ? " times" : " time"); writeSortedFile(false); - final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage(); - inMemSorter = null; - taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - - initializeForWriting(); + return spillSize; } private long getMemoryUsage() { @@ -312,18 +286,12 @@ final class ShuffleExternalSorter { updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { - taskMemoryManager.freePage(block); memoryFreed += block.size(); - } - if (inMemSorter != null) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); - inMemSorter = null; - taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage); + freePage(block); } allocatedPages.clear(); currentPage = null; - currentPagePosition = -1; - freeSpaceInCurrentPage = 0; + pageCursor = 0; return memoryFreed; } @@ -332,16 +300,16 @@ final class ShuffleExternalSorter { */ public void cleanupResources() { freeMemory(); + if (inMemSorter != null) { + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + releaseMemory(sorterMemoryUsage); + } for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { logger.error("Unable to delete spill file {}", spill.file.getPath()); } } - if (inMemSorter != null) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); - inMemSorter = null; - taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage); - } } /** @@ -352,16 +320,27 @@ final class ShuffleExternalSorter { private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { - logger.debug("Attempting to expand sort pointer array"); - final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); - final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; - final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray); - if (memoryAcquired < memoryToGrowPointerArray) { - taskMemoryManager.releaseExecutionMemory(memoryAcquired); - spill(); + long used = inMemSorter.getMemoryUsage(); + long needed = used + inMemSorter.getMemoryToExpand(); + try { + acquireMemory(needed); // could trigger spilling + } catch (OutOfMemoryError e) { + // should have trigger spilling + assert(inMemSorter.hasSpaceForAnotherRecord()); + return; + } + // check if spilling is triggered or not + if (inMemSorter.hasSpaceForAnotherRecord()) { + releaseMemory(needed); } else { - inMemSorter.expandPointerArray(); - taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage); + try { + inMemSorter.expandPointerArray(); + releaseMemory(used); + } catch (OutOfMemoryError oom) { + // Just in case that JVM had run out of memory + releaseMemory(needed); + spill(); + } } } } @@ -370,96 +349,46 @@ final class ShuffleExternalSorter { * Allocates more memory in order to insert an additional record. This will request additional * memory from the memory manager and spill if the requested memory can not be obtained. * - * @param requiredSpace the required space in the data page, in bytes, including space for storing + * @param required the required space in the data page, in bytes, including space for storing * the record size. This must be less than or equal to the page size (records * that exceed the page size are handled via a different code path which uses * special overflow pages). */ - private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { - growPointerArrayIfNecessary(); - if (requiredSpace > freeSpaceInCurrentPage) { - logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, - freeSpaceInCurrentPage); - // TODO: we should track metrics on the amount of space wasted when we roll over to a new page - // without using the free space at the end of the current page. We should also do this for - // BytesToBytesMap. - if (requiredSpace > pageSizeBytes) { - throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - pageSizeBytes + ")"); - } else { - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); - if (currentPage == null) { - spill(); - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); - if (currentPage == null) { - throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); - } - } - currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = pageSizeBytes; - allocatedPages.add(currentPage); - } + private void acquireNewPageIfNecessary(int required) { + if (currentPage == null || + pageCursor + required > currentPage.getBaseOffset() + currentPage.size() ) { + // TODO: try to find space in previous pages + currentPage = allocatePage(required); + pageCursor = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); } } /** * Write a record to the shuffle sorter. */ - public void insertRecord( - Object recordBaseObject, - long recordBaseOffset, - int lengthInBytes, - int partitionId) throws IOException { + public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId) + throws IOException { - if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) { + // for tests + assert(inMemSorter != null); + if (inMemSorter.numRecords() > numElementsForSpillThreshold) { spill(); } growPointerArrayIfNecessary(); // Need 4 bytes to store the record length. - final int totalSpaceRequired = lengthInBytes + 4; - - // --- Figure out where to insert the new record ---------------------------------------------- - - final MemoryBlock dataPage; - long dataPagePosition; - boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; - if (useOverflowPage) { - long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); - // The record is larger than the page size, so allocate a special overflow page just to hold - // that record. - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); - if (overflowPage == null) { - spill(); - overflowPage = taskMemoryManager.allocatePage(overflowPageSize); - if (overflowPage == null) { - throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); - } - } - allocatedPages.add(overflowPage); - dataPage = overflowPage; - dataPagePosition = overflowPage.getBaseOffset(); - } else { - // The record is small enough to fit in a regular data page, but the current page might not - // have enough space to hold it (or no pages have been allocated yet). - acquireNewPageIfNecessary(totalSpaceRequired); - dataPage = currentPage; - dataPagePosition = currentPagePosition; - // Update bookkeeping information - freeSpaceInCurrentPage -= totalSpaceRequired; - currentPagePosition += totalSpaceRequired; - } - final Object dataPageBaseObject = dataPage.getBaseObject(); - - final long recordAddress = - taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); - dataPagePosition += 4; - Platform.copyMemory( - recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); - assert(inMemSorter != null); + final int required = length + 4; + acquireNewPageIfNecessary(required); + + assert(currentPage != null); + final Object base = currentPage.getBaseObject(); + final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); + Platform.putInt(base, pageCursor, length); + pageCursor += 4; + Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); + pageCursor += length; inMemSorter.insertRecord(recordAddress, partitionId); - numRecordsInsertedSinceLastSpill += 1; } /** @@ -475,6 +404,9 @@ final class ShuffleExternalSorter { // Do not count the final file towards the spill count. writeSortedFile(true); freeMemory(); + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + releaseMemory(sorterMemoryUsage); } return spills.toArray(new SpillInfo[spills.size()]); } catch (IOException e) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index a8dee6c610..e630575d1a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -37,33 +37,51 @@ final class ShuffleInMemorySorter { * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating * records. */ - private long[] pointerArray; + private long[] array; /** * The position in the pointer array where new records can be inserted. */ - private int pointerArrayInsertPosition = 0; + private int pos = 0; public ShuffleInMemorySorter(int initialSize) { assert (initialSize > 0); - this.pointerArray = new long[initialSize]; - this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE); + this.array = new long[initialSize]; + this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } - public void expandPointerArray() { - final long[] oldArray = pointerArray; + public int numRecords() { + return pos; + } + + public void reset() { + pos = 0; + } + + private int newLength() { // Guard against overflow: - final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; - pointerArray = new long[newLength]; - System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE; + } + + /** + * Returns the memory needed to expand + */ + public long getMemoryToExpand() { + return ((long) (newLength() - array.length)) * 8; + } + + public void expandPointerArray() { + final long[] oldArray = array; + array = new long[newLength()]; + System.arraycopy(oldArray, 0, array, 0, oldArray.length); } public boolean hasSpaceForAnotherRecord() { - return pointerArrayInsertPosition + 1 < pointerArray.length; + return pos < array.length; } public long getMemoryUsage() { - return pointerArray.length * 8L; + return array.length * 8L; } /** @@ -78,15 +96,15 @@ final class ShuffleInMemorySorter { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - if (pointerArray.length == Integer.MAX_VALUE) { + if (array.length == Integer.MAX_VALUE) { throw new IllegalStateException("Sort pointer array has reached maximum size"); } else { expandPointerArray(); } } - pointerArray[pointerArrayInsertPosition] = + array[pos] = PackedRecordPointer.packPointer(recordPointer, partitionId); - pointerArrayInsertPosition++; + pos++; } /** @@ -118,7 +136,7 @@ final class ShuffleInMemorySorter { * Return an iterator over record pointers in sorted order. */ public ShuffleSorterIterator getSortedIterator() { - sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR); - return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray); + sorter.sort(array, 0, pos, SORT_COMPARATOR); + return new ShuffleSorterIterator(pos, array); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index f6c5c944bd..e19b378642 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -127,12 +127,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> { open(); } - @VisibleForTesting - public int maxRecordSizeBytes() { - assert(sorter != null); - return sorter.maxRecordSizeBytes; - } - private void updatePeakMemoryUsed() { // sorter can be null if this writer is closed if (sorter != null) { diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index f035bdac81..e36709c6fc 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -18,14 +18,20 @@ package org.apache.spark.unsafe.map; import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; import java.util.Iterator; import java.util.LinkedList; -import java.util.List; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.SparkEnv; +import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; @@ -33,7 +39,8 @@ import org.apache.spark.unsafe.bitset.BitSet; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter; /** * An append-only hash map where keys and values are contiguous regions of bytes. @@ -54,7 +61,7 @@ import org.apache.spark.memory.TaskMemoryManager; * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter}, * so we can pass records from this map directly into the sorter to sort records in place. */ -public final class BytesToBytesMap { +public final class BytesToBytesMap extends MemoryConsumer { private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class); @@ -62,27 +69,22 @@ public final class BytesToBytesMap { private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING; - /** - * Special record length that is placed after the last record in a data page. - */ - private static final int END_OF_PAGE_MARKER = -1; - private final TaskMemoryManager taskMemoryManager; /** * A linked list for tracking all allocated data pages so that we can free all of our memory. */ - private final List<MemoryBlock> dataPages = new LinkedList<MemoryBlock>(); + private final LinkedList<MemoryBlock> dataPages = new LinkedList<>(); /** * The data page that will be used to store keys and values for new hashtable entries. When this * page becomes full, a new page will be allocated and this pointer will change to point to that * new page. */ - private MemoryBlock currentDataPage = null; + private MemoryBlock currentPage = null; /** - * Offset into `currentDataPage` that points to the location where new data can be inserted into + * Offset into `currentPage` that points to the location where new data can be inserted into * the page. This does not incorporate the page's base offset. */ private long pageCursor = 0; @@ -117,6 +119,11 @@ public final class BytesToBytesMap { // absolute memory addresses. /** + * Whether or not the longArray can grow. We will not insert more elements if it's false. + */ + private boolean canGrowArray = true; + + /** * A {@link BitSet} used to track location of the map where the key is set. * Size of the bitset should be half of the size of the long array. */ @@ -164,13 +171,20 @@ public final class BytesToBytesMap { private long peakMemoryUsedBytes = 0L; + private final BlockManager blockManager; + private volatile MapIterator destructiveIterator = null; + private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>(); + public BytesToBytesMap( TaskMemoryManager taskMemoryManager, + BlockManager blockManager, int initialCapacity, double loadFactor, long pageSizeBytes, boolean enablePerfMetrics) { + super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; + this.blockManager = blockManager; this.loadFactor = loadFactor; this.loc = new Location(); this.pageSizeBytes = pageSizeBytes; @@ -187,18 +201,13 @@ public final class BytesToBytesMap { TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } allocate(initialCapacity); - - // Acquire a new page as soon as we construct the map to ensure that we have at least - // one page to work with. Otherwise, other operators in the same task may starve this - // map (SPARK-9747). - acquireNewPage(); } public BytesToBytesMap( TaskMemoryManager taskMemoryManager, int initialCapacity, long pageSizeBytes) { - this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false); + this(taskMemoryManager, initialCapacity, pageSizeBytes, false); } public BytesToBytesMap( @@ -208,6 +217,7 @@ public final class BytesToBytesMap { boolean enablePerfMetrics) { this( taskMemoryManager, + SparkEnv.get() != null ? SparkEnv.get().blockManager() : null, initialCapacity, 0.70, pageSizeBytes, @@ -219,61 +229,153 @@ public final class BytesToBytesMap { */ public int numElements() { return numElements; } - public static final class BytesToBytesMapIterator implements Iterator<Location> { + public final class MapIterator implements Iterator<Location> { - private final int numRecords; - private final Iterator<MemoryBlock> dataPagesIterator; + private int numRecords; private final Location loc; private MemoryBlock currentPage = null; - private int currentRecordNumber = 0; + private int recordsInPage = 0; private Object pageBaseObject; private long offsetInPage; // If this iterator destructive or not. When it is true, it frees each page as it moves onto // next one. private boolean destructive = false; - private BytesToBytesMap bmap; + private UnsafeSorterSpillReader reader = null; - private BytesToBytesMapIterator( - int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc, - boolean destructive, BytesToBytesMap bmap) { + private MapIterator(int numRecords, Location loc, boolean destructive) { this.numRecords = numRecords; - this.dataPagesIterator = dataPagesIterator; this.loc = loc; this.destructive = destructive; - this.bmap = bmap; - if (dataPagesIterator.hasNext()) { - advanceToNextPage(); + if (destructive) { + destructiveIterator = this; } } private void advanceToNextPage() { - if (destructive && currentPage != null) { - dataPagesIterator.remove(); - this.bmap.taskMemoryManager.freePage(currentPage); + synchronized (this) { + int nextIdx = dataPages.indexOf(currentPage) + 1; + if (destructive && currentPage != null) { + dataPages.remove(currentPage); + freePage(currentPage); + nextIdx --; + } + if (dataPages.size() > nextIdx) { + currentPage = dataPages.get(nextIdx); + pageBaseObject = currentPage.getBaseObject(); + offsetInPage = currentPage.getBaseOffset(); + recordsInPage = Platform.getInt(pageBaseObject, offsetInPage); + offsetInPage += 4; + } else { + currentPage = null; + if (reader != null) { + // remove the spill file from disk + File file = spillWriters.removeFirst().getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + try { + reader = spillWriters.getFirst().getReader(blockManager); + recordsInPage = -1; + } catch (IOException e) { + // Scala iterator does not handle exception + Platform.throwException(e); + } + } } - currentPage = dataPagesIterator.next(); - pageBaseObject = currentPage.getBaseObject(); - offsetInPage = currentPage.getBaseOffset(); } @Override public boolean hasNext() { - return currentRecordNumber != numRecords; + if (numRecords == 0) { + if (reader != null) { + // remove the spill file from disk + File file = spillWriters.removeFirst().getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } + } + return numRecords > 0; } @Override public Location next() { - int totalLength = Platform.getInt(pageBaseObject, offsetInPage); - if (totalLength == END_OF_PAGE_MARKER) { + if (recordsInPage == 0) { advanceToNextPage(); - totalLength = Platform.getInt(pageBaseObject, offsetInPage); } - loc.with(currentPage, offsetInPage); - offsetInPage += 4 + totalLength; - currentRecordNumber++; - return loc; + numRecords--; + if (currentPage != null) { + int totalLength = Platform.getInt(pageBaseObject, offsetInPage); + loc.with(currentPage, offsetInPage); + offsetInPage += 4 + totalLength; + recordsInPage --; + return loc; + } else { + assert(reader != null); + if (!reader.hasNext()) { + advanceToNextPage(); + } + try { + reader.loadNext(); + } catch (IOException e) { + // Scala iterator does not handle exception + Platform.throwException(e); + } + loc.with(reader.getBaseObject(), reader.getBaseOffset(), reader.getRecordLength()); + return loc; + } + } + + public long spill(long numBytes) throws IOException { + synchronized (this) { + if (!destructive || dataPages.size() == 1) { + return 0L; + } + + // TODO: use existing ShuffleWriteMetrics + ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics(); + + long released = 0L; + while (dataPages.size() > 0) { + MemoryBlock block = dataPages.getLast(); + // The currentPage is used, cannot be released + if (block == currentPage) { + break; + } + + Object base = block.getBaseObject(); + long offset = block.getBaseOffset(); + int numRecords = Platform.getInt(base, offset); + offset += 4; + final UnsafeSorterSpillWriter writer = + new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords); + while (numRecords > 0) { + int length = Platform.getInt(base, offset); + writer.write(base, offset + 4, length, 0); + offset += 4 + length; + numRecords--; + } + writer.close(); + spillWriters.add(writer); + + dataPages.removeLast(); + released += block.size(); + freePage(block); + + if (released >= numBytes) { + break; + } + } + + return released; + } } @Override @@ -290,8 +392,8 @@ public final class BytesToBytesMap { * If any other lookups or operations are performed on this map while iterating over it, including * `lookup()`, the behavior of the returned iterator is undefined. */ - public BytesToBytesMapIterator iterator() { - return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this); + public MapIterator iterator() { + return new MapIterator(numElements, loc, false); } /** @@ -304,8 +406,8 @@ public final class BytesToBytesMap { * If any other lookups or operations are performed on this map while iterating over it, including * `lookup()`, the behavior of the returned iterator is undefined. */ - public BytesToBytesMapIterator destructiveIterator() { - return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this); + public MapIterator destructiveIterator() { + return new MapIterator(numElements, loc, true); } /** @@ -314,11 +416,8 @@ public final class BytesToBytesMap { * * This function always return the same {@link Location} instance to avoid object allocation. */ - public Location lookup( - Object keyBaseObject, - long keyBaseOffset, - int keyRowLengthBytes) { - safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc); + public Location lookup(Object keyBase, long keyOffset, int keyLength) { + safeLookup(keyBase, keyOffset, keyLength, loc); return loc; } @@ -327,18 +426,14 @@ public final class BytesToBytesMap { * * This is a thread-safe version of `lookup`, could be used by multiple threads. */ - public void safeLookup( - Object keyBaseObject, - long keyBaseOffset, - int keyRowLengthBytes, - Location loc) { + public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) { assert(bitset != null); assert(longArray != null); if (enablePerfMetrics) { numKeyLookups++; } - final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes); + final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength); int pos = hashcode & mask; int step = 1; while (true) { @@ -354,16 +449,16 @@ public final class BytesToBytesMap { if ((int) (stored) == hashcode) { // Full hash code matches. Let's compare the keys for equality. loc.with(pos, hashcode, true); - if (loc.getKeyLength() == keyRowLengthBytes) { + if (loc.getKeyLength() == keyLength) { final MemoryLocation keyAddress = loc.getKeyAddress(); - final Object storedKeyBaseObject = keyAddress.getBaseObject(); - final long storedKeyBaseOffset = keyAddress.getBaseOffset(); + final Object storedkeyBase = keyAddress.getBaseObject(); + final long storedkeyOffset = keyAddress.getBaseOffset(); final boolean areEqual = ByteArrayMethods.arrayEquals( - keyBaseObject, - keyBaseOffset, - storedKeyBaseObject, - storedKeyBaseOffset, - keyRowLengthBytes + keyBase, + keyOffset, + storedkeyBase, + storedkeyOffset, + keyLength ); if (areEqual) { return; @@ -410,18 +505,18 @@ public final class BytesToBytesMap { taskMemoryManager.getOffsetInPage(fullKeyAddress)); } - private void updateAddressesAndSizes(final Object page, final long offsetInPage) { - long position = offsetInPage; - final int totalLength = Platform.getInt(page, position); + private void updateAddressesAndSizes(final Object base, final long offset) { + long position = offset; + final int totalLength = Platform.getInt(base, position); position += 4; - keyLength = Platform.getInt(page, position); + keyLength = Platform.getInt(base, position); position += 4; valueLength = totalLength - keyLength - 4; - keyMemoryLocation.setObjAndOffset(page, position); + keyMemoryLocation.setObjAndOffset(base, position); position += keyLength; - valueMemoryLocation.setObjAndOffset(page, position); + valueMemoryLocation.setObjAndOffset(base, position); } private Location with(int pos, int keyHashcode, boolean isDefined) { @@ -444,6 +539,19 @@ public final class BytesToBytesMap { } /** + * This is only used for spilling + */ + private Location with(Object base, long offset, int length) { + this.isDefined = true; + this.memoryPage = null; + keyLength = Platform.getInt(base, offset); + valueLength = length - 4 - keyLength; + keyMemoryLocation.setObjAndOffset(base, offset + 4); + valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength); + return this; + } + + /** * Returns the memory page that contains the current record. * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}. */ @@ -517,9 +625,9 @@ public final class BytesToBytesMap { * As an example usage, here's the proper way to store a new key: * </p> * <pre> - * Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes); + * Location loc = map.lookup(keyBase, keyOffset, keyLength); * if (!loc.isDefined()) { - * if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) { + * if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) { * // handle failure to grow map (by spilling, for example) * } * } @@ -531,113 +639,59 @@ public final class BytesToBytesMap { * @return true if the put() was successful and false if the put() failed because memory could * not be acquired. */ - public boolean putNewKey( - Object keyBaseObject, - long keyBaseOffset, - int keyLengthBytes, - Object valueBaseObject, - long valueBaseOffset, - int valueLengthBytes) { + public boolean putNewKey(Object keyBase, long keyOffset, int keyLength, + Object valueBase, long valueOffset, int valueLength) { assert (!isDefined) : "Can only set value once for a key"; - assert (keyLengthBytes % 8 == 0); - assert (valueLengthBytes % 8 == 0); + assert (keyLength % 8 == 0); + assert (valueLength % 8 == 0); assert(bitset != null); assert(longArray != null); - if (numElements == MAX_CAPACITY) { - throw new IllegalStateException("BytesToBytesMap has reached maximum capacity"); + if (numElements == MAX_CAPACITY || !canGrowArray) { + return false; } // Here, we'll copy the data into our data pages. Because we only store a relative offset from // the key address instead of storing the absolute address of the value, the key and value // must be stored in the same memory page. // (8 byte key length) (key) (value) - final long requiredSize = 8 + keyLengthBytes + valueLengthBytes; - - // --- Figure out where to insert the new record --------------------------------------------- - - final MemoryBlock dataPage; - final Object dataPageBaseObject; - final long dataPageInsertOffset; - boolean useOverflowPage = requiredSize > pageSizeBytes - 8; - if (useOverflowPage) { - // The record is larger than the page size, so allocate a special overflow page just to hold - // that record. - final long overflowPageSize = requiredSize + 8; - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); - if (overflowPage == null) { - logger.debug("Failed to acquire {} bytes of memory", overflowPageSize); + final long recordLength = 8 + keyLength + valueLength; + if (currentPage == null || currentPage.size() - pageCursor < recordLength) { + if (!acquireNewPage(recordLength + 4L)) { return false; } - dataPages.add(overflowPage); - dataPage = overflowPage; - dataPageBaseObject = overflowPage.getBaseObject(); - dataPageInsertOffset = overflowPage.getBaseOffset(); - } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) { - // The record can fit in a data page, but either we have not allocated any pages yet or - // the current page does not have enough space. - if (currentDataPage != null) { - // There wasn't enough space in the current page, so write an end-of-page marker: - final Object pageBaseObject = currentDataPage.getBaseObject(); - final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor; - Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER); - } - if (!acquireNewPage()) { - return false; - } - dataPage = currentDataPage; - dataPageBaseObject = currentDataPage.getBaseObject(); - dataPageInsertOffset = currentDataPage.getBaseOffset(); - } else { - // There is enough space in the current data page. - dataPage = currentDataPage; - dataPageBaseObject = currentDataPage.getBaseObject(); - dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor; } // --- Append the key and value data to the current data page -------------------------------- - - long insertCursor = dataPageInsertOffset; - - // Compute all of our offsets up-front: - final long recordOffset = insertCursor; - insertCursor += 4; - final long keyLengthOffset = insertCursor; - insertCursor += 4; - final long keyDataOffsetInPage = insertCursor; - insertCursor += keyLengthBytes; - final long valueDataOffsetInPage = insertCursor; - insertCursor += valueLengthBytes; // word used to store the value size - - Platform.putInt(dataPageBaseObject, recordOffset, - keyLengthBytes + valueLengthBytes + 4); - Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes); - // Copy the key - Platform.copyMemory( - keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes); - // Copy the value - Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject, - valueDataOffsetInPage, valueLengthBytes); - - // --- Update bookeeping data structures ----------------------------------------------------- - - if (useOverflowPage) { - // Store the end-of-page marker at the end of the data page - Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER); - } else { - pageCursor += requiredSize; - } - + final Object base = currentPage.getBaseObject(); + long offset = currentPage.getBaseOffset() + pageCursor; + final long recordOffset = offset; + Platform.putInt(base, offset, keyLength + valueLength + 4); + Platform.putInt(base, offset + 4, keyLength); + offset += 8; + Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength); + offset += keyLength; + Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength); + + // --- Update bookkeeping data structures ----------------------------------------------------- + offset = currentPage.getBaseOffset(); + Platform.putInt(base, offset, Platform.getInt(base, offset) + 1); + pageCursor += recordLength; numElements++; bitset.set(pos); final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( - dataPage, recordOffset); + currentPage, recordOffset); longArray.set(pos * 2, storedKeyAddress); longArray.set(pos * 2 + 1, keyHashcode); updateAddressesAndSizes(storedKeyAddress); isDefined = true; + if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) { - growAndRehash(); + try { + growAndRehash(); + } catch (OutOfMemoryError oom) { + canGrowArray = false; + } } return true; } @@ -647,18 +701,26 @@ public final class BytesToBytesMap { * Acquire a new page from the memory manager. * @return whether there is enough space to allocate the new page. */ - private boolean acquireNewPage() { - MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); - if (newPage == null) { - logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); + private boolean acquireNewPage(long required) { + try { + currentPage = allocatePage(required); + } catch (OutOfMemoryError e) { return false; } - dataPages.add(newPage); - pageCursor = 0; - currentDataPage = newPage; + dataPages.add(currentPage); + Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0); + pageCursor = 4; return true; } + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + if (trigger != this && destructiveIterator != null) { + return destructiveIterator.spill(size); + } + return 0L; + } + /** * Allocate new data structures for this map. When calling this outside of the constructor, * make sure to keep references to the old data structures so that you can free them. @@ -670,6 +732,7 @@ public final class BytesToBytesMap { // The capacity needs to be divisible by 64 so that our bit set can be sized properly capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64); assert (capacity <= MAX_CAPACITY); + acquireMemory(capacity * 16); longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2])); bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64])); @@ -678,22 +741,42 @@ public final class BytesToBytesMap { } /** + * Free the memory used by longArray. + */ + public void freeArray() { + updatePeakMemoryUsed(); + if (longArray != null) { + long used = longArray.memoryBlock().size(); + longArray = null; + releaseMemory(used); + bitset = null; + } + } + + /** * Free all allocated memory associated with this map, including the storage for keys and values * as well as the hash map array itself. * * This method is idempotent and can be called multiple times. */ public void free() { - updatePeakMemoryUsed(); - longArray = null; - bitset = null; + freeArray(); Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator(); while (dataPagesIterator.hasNext()) { MemoryBlock dataPage = dataPagesIterator.next(); dataPagesIterator.remove(); - taskMemoryManager.freePage(dataPage); + freePage(dataPage); } assert(dataPages.isEmpty()); + + while (!spillWriters.isEmpty()) { + File file = spillWriters.removeFirst().getFile(); + if (file != null && file.exists()) { + if (!file.delete()) { + logger.error("Was unable to delete spill file {}", file.getAbsolutePath()); + } + } + } } public TaskMemoryManager getTaskMemoryManager() { @@ -782,7 +865,13 @@ public final class BytesToBytesMap { final int oldCapacity = (int) oldBitSet.capacity(); // Allocate the new data structures - allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY)); + try { + allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY)); + } catch (OutOfMemoryError oom) { + longArray = oldLongArray; + bitset = oldBitSet; + throw oom; + } // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it) for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) { @@ -806,6 +895,7 @@ public final class BytesToBytesMap { } } } + releaseMemory(oldLongArray.memoryBlock().size()); if (enablePerfMetrics) { timeSpentResizingNs += System.nanoTime() - resizeStartTime; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index e317ea391c..49a5a4b13b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -17,39 +17,34 @@ package org.apache.spark.util.collection.unsafe.sort; +import javax.annotation.Nullable; import java.io.File; import java.io.IOException; import java.util.LinkedList; -import javax.annotation.Nullable; - -import scala.runtime.AbstractFunction0; -import scala.runtime.BoxedUnit; - import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.storage.BlockManager; -import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.util.TaskCompletionListener; import org.apache.spark.util.Utils; /** * External sorter based on {@link UnsafeInMemorySorter}. */ -public final class UnsafeExternalSorter { +public final class UnsafeExternalSorter extends MemoryConsumer { private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); - private final long pageSizeBytes; private final PrefixComparator prefixComparator; private final RecordComparator recordComparator; - private final int initialSize; private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; @@ -69,14 +64,12 @@ public final class UnsafeExternalSorter { private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>(); // These variables are reset after spilling: - @Nullable private UnsafeInMemorySorter inMemSorter; - // Whether the in-mem sorter is created internally, or passed in from outside. - // If it is passed in from outside, we shouldn't release the in-mem sorter's memory. - private boolean isInMemSorterExternal = false; + @Nullable private volatile UnsafeInMemorySorter inMemSorter; + private MemoryBlock currentPage = null; - private long currentPagePosition = -1; - private long freeSpaceInCurrentPage = 0; + private long pageCursor = -1; private long peakMemoryUsedBytes = 0; + private volatile SpillableIterator readingIterator = null; public static UnsafeExternalSorter createWithExistingInMemorySorter( TaskMemoryManager taskMemoryManager, @@ -86,7 +79,7 @@ public final class UnsafeExternalSorter { PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - UnsafeInMemorySorter inMemorySorter) throws IOException { + UnsafeInMemorySorter inMemorySorter) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter); } @@ -98,7 +91,7 @@ public final class UnsafeExternalSorter { RecordComparator recordComparator, PrefixComparator prefixComparator, int initialSize, - long pageSizeBytes) throws IOException { + long pageSizeBytes) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null); } @@ -111,60 +104,41 @@ public final class UnsafeExternalSorter { PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException { + @Nullable UnsafeInMemorySorter existingInMemorySorter) { + super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; this.taskContext = taskContext; this.recordComparator = recordComparator; this.prefixComparator = prefixComparator; - this.initialSize = initialSize; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.fileBufferSizeBytes = 32 * 1024; - this.pageSizeBytes = pageSizeBytes; + // TODO: metrics tracking + integration with shuffle write metrics + // need to connect the write metrics to task metrics so we count the spill IO somewhere. this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { - initializeForWriting(); - // Acquire a new page as soon as we construct the sorter to ensure that we have at - // least one page to work with. Otherwise, other operators in the same task may starve - // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter. - acquireNewPage(); + this.inMemSorter = + new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize); + acquireMemory(inMemSorter.getMemoryUsage()); } else { - this.isInMemSorterExternal = true; this.inMemSorter = existingInMemorySorter; + // will acquire after free the map } + this.peakMemoryUsedBytes = getMemoryUsage(); // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). - taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() { - @Override - public BoxedUnit apply() { - cleanupResources(); - return null; + taskContext.addTaskCompletionListener( + new TaskCompletionListener() { + @Override + public void onTaskCompletion(TaskContext context) { + cleanupResources(); + } } - }); - } - - // TODO: metrics tracking + integration with shuffle write metrics - // need to connect the write metrics to task metrics so we count the spill IO somewhere. - - /** - * Allocates new sort data structures. Called when creating the sorter and after each spill. - */ - private void initializeForWriting() throws IOException { - // Note: Do not track memory for the pointer array for now because of SPARK-10474. - // In more detail, in TungstenAggregate we only reserve a page, but when we fall back to - // sort-based aggregation we try to acquire a page AND a pointer array, which inevitably - // fails if all other memory is already occupied. It should be safe to not track the array - // because its memory footprint is frequently much smaller than that of a page. This is a - // temporary hack that we should address in 1.6.0. - // TODO: track the pointer array memory! - this.writeMetrics = new ShuffleWriteMetrics(); - this.inMemSorter = - new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize); - this.isInMemSorterExternal = false; + ); } /** @@ -173,14 +147,27 @@ public final class UnsafeExternalSorter { */ @VisibleForTesting public void closeCurrentPage() { - freeSpaceInCurrentPage = 0; + if (currentPage != null) { + pageCursor = currentPage.getBaseOffset() + currentPage.size(); + } } /** * Sort and spill the current records in response to memory pressure. */ - public void spill() throws IOException { - assert(inMemSorter != null); + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + if (trigger != this) { + if (readingIterator != null) { + return readingIterator.spill(); + } + return 0L; + } + + if (inMemSorter == null || inMemSorter.numRecords() <= 0) { + return 0L; + } + logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), @@ -202,6 +189,8 @@ public final class UnsafeExternalSorter { spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); } spillWriter.close(); + + inMemSorter.reset(); } final long spillSize = freeMemory(); @@ -210,7 +199,7 @@ public final class UnsafeExternalSorter { // written to disk. This also counts the space needed to store the sorter's pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); - initializeForWriting(); + return spillSize; } /** @@ -246,7 +235,7 @@ public final class UnsafeExternalSorter { } /** - * Free this sorter's in-memory data structures, including its data pages and pointer array. + * Free this sorter's data pages. * * @return the number of bytes freed. */ @@ -254,14 +243,12 @@ public final class UnsafeExternalSorter { updatePeakMemoryUsed(); long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { - taskMemoryManager.freePage(block); memoryFreed += block.size(); + freePage(block); } - // TODO: track in-memory sorter memory usage (SPARK-10474) allocatedPages.clear(); currentPage = null; - currentPagePosition = -1; - freeSpaceInCurrentPage = 0; + pageCursor = 0; return memoryFreed; } @@ -283,8 +270,15 @@ public final class UnsafeExternalSorter { * Frees this sorter's in-memory data structures and cleans up its spill files. */ public void cleanupResources() { - deleteSpillFiles(); - freeMemory(); + synchronized (this) { + deleteSpillFiles(); + freeMemory(); + if (inMemSorter != null) { + long used = inMemSorter.getMemoryUsage(); + inMemSorter = null; + releaseMemory(used); + } + } } /** @@ -295,8 +289,28 @@ public final class UnsafeExternalSorter { private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { - // TODO: track the pointer array memory! (SPARK-10474) - inMemSorter.expandPointerArray(); + long used = inMemSorter.getMemoryUsage(); + long needed = used + inMemSorter.getMemoryToExpand(); + try { + acquireMemory(needed); // could trigger spilling + } catch (OutOfMemoryError e) { + // should have trigger spilling + assert(inMemSorter.hasSpaceForAnotherRecord()); + return; + } + // check if spilling is triggered or not + if (inMemSorter.hasSpaceForAnotherRecord()) { + releaseMemory(needed); + } else { + try { + inMemSorter.expandPointerArray(); + releaseMemory(used); + } catch (OutOfMemoryError oom) { + // Just in case that JVM had run out of memory + releaseMemory(needed); + spill(); + } + } } } @@ -304,101 +318,38 @@ public final class UnsafeExternalSorter { * Allocates more memory in order to insert an additional record. This will request additional * memory from the memory manager and spill if the requested memory can not be obtained. * - * @param requiredSpace the required space in the data page, in bytes, including space for storing + * @param required the required space in the data page, in bytes, including space for storing * the record size. This must be less than or equal to the page size (records * that exceed the page size are handled via a different code path which uses * special overflow pages). */ - private void acquireNewPageIfNecessary(int requiredSpace) throws IOException { - assert (requiredSpace <= pageSizeBytes); - if (requiredSpace > freeSpaceInCurrentPage) { - logger.trace("Required space {} is less than free space in current page ({})", requiredSpace, - freeSpaceInCurrentPage); - // TODO: we should track metrics on the amount of space wasted when we roll over to a new page - // without using the free space at the end of the current page. We should also do this for - // BytesToBytesMap. - if (requiredSpace > pageSizeBytes) { - throw new IOException("Required space " + requiredSpace + " is greater than page size (" + - pageSizeBytes + ")"); - } else { - acquireNewPage(); - } + private void acquireNewPageIfNecessary(int required) { + if (currentPage == null || + pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) { + // TODO: try to find space on previous pages + currentPage = allocatePage(required); + pageCursor = currentPage.getBaseOffset(); + allocatedPages.add(currentPage); } } /** - * Acquire a new page from the memory manager. - * - * If there is not enough space to allocate the new page, spill all existing ones - * and try again. If there is still not enough space, report error to the caller. - */ - private void acquireNewPage() throws IOException { - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); - if (currentPage == null) { - spill(); - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); - if (currentPage == null) { - throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); - } - } - currentPagePosition = currentPage.getBaseOffset(); - freeSpaceInCurrentPage = pageSizeBytes; - allocatedPages.add(currentPage); - } - - /** * Write a record to the sorter. */ - public void insertRecord( - Object recordBaseObject, - long recordBaseOffset, - int lengthInBytes, - long prefix) throws IOException { + public void insertRecord(Object recordBase, long recordOffset, int length, long prefix) + throws IOException { growPointerArrayIfNecessary(); // Need 4 bytes to store the record length. - final int totalSpaceRequired = lengthInBytes + 4; - - // --- Figure out where to insert the new record ---------------------------------------------- - - final MemoryBlock dataPage; - long dataPagePosition; - boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; - if (useOverflowPage) { - long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); - // The record is larger than the page size, so allocate a special overflow page just to hold - // that record. - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); - if (overflowPage == null) { - spill(); - overflowPage = taskMemoryManager.allocatePage(overflowPageSize); - if (overflowPage == null) { - throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); - } - } - allocatedPages.add(overflowPage); - dataPage = overflowPage; - dataPagePosition = overflowPage.getBaseOffset(); - } else { - // The record is small enough to fit in a regular data page, but the current page might not - // have enough space to hold it (or no pages have been allocated yet). - acquireNewPageIfNecessary(totalSpaceRequired); - dataPage = currentPage; - dataPagePosition = currentPagePosition; - // Update bookkeeping information - freeSpaceInCurrentPage -= totalSpaceRequired; - currentPagePosition += totalSpaceRequired; - } - final Object dataPageBaseObject = dataPage.getBaseObject(); - - // --- Insert the record ---------------------------------------------------------------------- - - final long recordAddress = - taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes); - dataPagePosition += 4; - Platform.copyMemory( - recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes); + final int required = length + 4; + acquireNewPageIfNecessary(required); + + final Object base = currentPage.getBaseObject(); + final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); + Platform.putInt(base, pageCursor, length); + pageCursor += 4; + Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); + pageCursor += length; assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); } @@ -411,59 +362,24 @@ public final class UnsafeExternalSorter { * * record length = key length + value length + 4 */ - public void insertKVRecord( - Object keyBaseObj, long keyOffset, int keyLen, - Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException { + public void insertKVRecord(Object keyBase, long keyOffset, int keyLen, + Object valueBase, long valueOffset, int valueLen, long prefix) + throws IOException { growPointerArrayIfNecessary(); - final int totalSpaceRequired = keyLen + valueLen + 4 + 4; - - // --- Figure out where to insert the new record ---------------------------------------------- - - final MemoryBlock dataPage; - long dataPagePosition; - boolean useOverflowPage = totalSpaceRequired > pageSizeBytes; - if (useOverflowPage) { - long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); - // The record is larger than the page size, so allocate a special overflow page just to hold - // that record. - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); - if (overflowPage == null) { - spill(); - overflowPage = taskMemoryManager.allocatePage(overflowPageSize); - if (overflowPage == null) { - throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); - } - } - allocatedPages.add(overflowPage); - dataPage = overflowPage; - dataPagePosition = overflowPage.getBaseOffset(); - } else { - // The record is small enough to fit in a regular data page, but the current page might not - // have enough space to hold it (or no pages have been allocated yet). - acquireNewPageIfNecessary(totalSpaceRequired); - dataPage = currentPage; - dataPagePosition = currentPagePosition; - // Update bookkeeping information - freeSpaceInCurrentPage -= totalSpaceRequired; - currentPagePosition += totalSpaceRequired; - } - final Object dataPageBaseObject = dataPage.getBaseObject(); - - // --- Insert the record ---------------------------------------------------------------------- - - final long recordAddress = - taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition); - Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4); - dataPagePosition += 4; - - Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen); - dataPagePosition += 4; - - Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen); - dataPagePosition += keyLen; - - Platform.copyMemory(valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen); + final int required = keyLen + valueLen + 4 + 4; + acquireNewPageIfNecessary(required); + + final Object base = currentPage.getBaseObject(); + final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor); + Platform.putInt(base, pageCursor, keyLen + valueLen + 4); + pageCursor += 4; + Platform.putInt(base, pageCursor, keyLen); + pageCursor += 4; + Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen); + pageCursor += keyLen; + Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen); + pageCursor += valueLen; assert(inMemSorter != null); inMemSorter.insertRecord(recordAddress, prefix); @@ -475,10 +391,10 @@ public final class UnsafeExternalSorter { */ public UnsafeSorterIterator getSortedIterator() throws IOException { assert(inMemSorter != null); - final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator(); - int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0); + readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); + int numIteratorsToMerge = spillWriters.size() + (readingIterator.hasNext() ? 1 : 0); if (spillWriters.isEmpty()) { - return inMemoryIterator; + return readingIterator; } else { final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge); @@ -486,9 +402,113 @@ public final class UnsafeExternalSorter { spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager)); } spillWriters.clear(); - spillMerger.addSpillIfNotEmpty(inMemoryIterator); + spillMerger.addSpillIfNotEmpty(readingIterator); return spillMerger.getSortedIterator(); } } + + /** + * An UnsafeSorterIterator that support spilling. + */ + class SpillableIterator extends UnsafeSorterIterator { + private UnsafeSorterIterator upstream; + private UnsafeSorterIterator nextUpstream = null; + private MemoryBlock lastPage = null; + private boolean loaded = false; + private int numRecords = 0; + + public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) { + this.upstream = inMemIterator; + this.numRecords = inMemIterator.numRecordsLeft(); + } + + public long spill() throws IOException { + synchronized (this) { + if (!(upstream instanceof UnsafeInMemorySorter.SortedIterator && nextUpstream == null + && numRecords > 0)) { + return 0L; + } + + UnsafeInMemorySorter.SortedIterator inMemIterator = + ((UnsafeInMemorySorter.SortedIterator) upstream).clone(); + + final UnsafeSorterSpillWriter spillWriter = + new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords); + while (inMemIterator.hasNext()) { + inMemIterator.loadNext(); + final Object baseObject = inMemIterator.getBaseObject(); + final long baseOffset = inMemIterator.getBaseOffset(); + final int recordLength = inMemIterator.getRecordLength(); + spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix()); + } + spillWriter.close(); + spillWriters.add(spillWriter); + nextUpstream = spillWriter.getReader(blockManager); + + long released = 0L; + synchronized (UnsafeExternalSorter.this) { + // release the pages except the one that is used + for (MemoryBlock page : allocatedPages) { + if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) { + released += page.size(); + freePage(page); + } else { + lastPage = page; + } + } + allocatedPages.clear(); + } + return released; + } + } + + @Override + public boolean hasNext() { + return numRecords > 0; + } + + @Override + public void loadNext() throws IOException { + synchronized (this) { + loaded = true; + if (nextUpstream != null) { + // Just consumed the last record from in memory iterator + if (lastPage != null) { + freePage(lastPage); + lastPage = null; + } + upstream = nextUpstream; + nextUpstream = null; + + assert(inMemSorter != null); + long used = inMemSorter.getMemoryUsage(); + inMemSorter = null; + releaseMemory(used); + } + numRecords--; + upstream.loadNext(); + } + } + + @Override + public Object getBaseObject() { + return upstream.getBaseObject(); + } + + @Override + public long getBaseOffset() { + return upstream.getBaseOffset(); + } + + @Override + public int getRecordLength() { + return upstream.getRecordLength(); + } + + @Override + public long getKeyPrefix() { + return upstream.getKeyPrefix(); + } + } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 5aad72c374..1480f0681e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -70,12 +70,12 @@ public final class UnsafeInMemorySorter { * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ - private long[] pointerArray; + private long[] array; /** * The position in the sort buffer where new records can be inserted. */ - private int pointerArrayInsertPosition = 0; + private int pos = 0; public UnsafeInMemorySorter( final TaskMemoryManager memoryManager, @@ -83,37 +83,43 @@ public final class UnsafeInMemorySorter { final PrefixComparator prefixComparator, int initialSize) { assert (initialSize > 0); - this.pointerArray = new long[initialSize * 2]; + this.array = new long[initialSize * 2]; this.memoryManager = memoryManager; this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); } + public void reset() { + pos = 0; + } + /** * @return the number of records that have been inserted into this sorter. */ public int numRecords() { - return pointerArrayInsertPosition / 2; + return pos / 2; } - public long getMemoryUsage() { - return pointerArray.length * 8L; + private int newLength() { + return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE; + } + + public long getMemoryToExpand() { + return (long) (newLength() - array.length) * 8L; } - static long getMemoryRequirementsForPointerArray(long numEntries) { - return numEntries * 2L * 8L; + public long getMemoryUsage() { + return array.length * 8L; } public boolean hasSpaceForAnotherRecord() { - return pointerArrayInsertPosition + 2 < pointerArray.length; + return pos + 2 <= array.length; } public void expandPointerArray() { - final long[] oldArray = pointerArray; - // Guard against overflow: - final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE; - pointerArray = new long[newLength]; - System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length); + final long[] oldArray = array; + array = new long[newLength()]; + System.arraycopy(oldArray, 0, array, 0, oldArray.length); } /** @@ -127,10 +133,10 @@ public final class UnsafeInMemorySorter { if (!hasSpaceForAnotherRecord()) { expandPointerArray(); } - pointerArray[pointerArrayInsertPosition] = recordPointer; - pointerArrayInsertPosition++; - pointerArray[pointerArrayInsertPosition] = keyPrefix; - pointerArrayInsertPosition++; + array[pos] = recordPointer; + pos++; + array[pos] = keyPrefix; + pos++; } public static final class SortedIterator extends UnsafeSorterIterator { @@ -153,11 +159,25 @@ public final class UnsafeInMemorySorter { this.sortBuffer = sortBuffer; } + public SortedIterator clone () { + SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer); + iter.position = position; + iter.baseObject = baseObject; + iter.baseOffset = baseOffset; + iter.keyPrefix = keyPrefix; + iter.recordLength = recordLength; + return iter; + } + @Override public boolean hasNext() { return position < sortBufferInsertPosition; } + public int numRecordsLeft() { + return (sortBufferInsertPosition - position) / 2; + } + @Override public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes @@ -187,7 +207,7 @@ public final class UnsafeInMemorySorter { * {@code next()} will return the same mutable object. */ public SortedIterator getSortedIterator() { - sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator); - return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray); + sorter.sort(array, 0, pos / 2, sortComparator); + return new SortedIterator(memoryManager, pos, array); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index 501dfe77d1..039e940a35 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -20,18 +20,18 @@ package org.apache.spark.util.collection.unsafe.sort; import java.io.*; import com.google.common.io.ByteStreams; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.apache.spark.storage.BlockId; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description * of the file format). */ -final class UnsafeSorterSpillReader extends UnsafeSorterIterator { +public final class UnsafeSorterSpillReader extends UnsafeSorterIterator { private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class); private final File file; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index e59a84ff8d..234e21140a 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -35,7 +35,7 @@ import org.apache.spark.unsafe.Platform; * * [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...] */ -final class UnsafeSorterSpillWriter { +public final class UnsafeSorterSpillWriter { static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024; diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 6c9a71c385..b0cf2696a3 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.annotations.VisibleForTesting +import org.apache.spark.util.Utils import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging} import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} import org.apache.spark.unsafe.array.ByteArrayMethods @@ -215,8 +216,12 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized { val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { - throw new SparkException( - s"Internal error: release called on $numBytes bytes but task only has $curMem") + if (Utils.isTesting) { + throw new SparkException( + s"Internal error: release called on $numBytes bytes but task only has $curMem") + } else { + logWarning(s"Internal error: release called on $numBytes bytes but task only has $curMem") + } } if (executionMemoryForTask.contains(taskAttemptId)) { executionMemoryForTask(taskAttemptId) -= numBytes diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index a76891acf0..9e002621a6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging { if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest) + val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest, null) myMemoryThreshold += granted // If we were granted too little memory to grow further (either tryToAcquire returned 0, // or we already had more memory than myMemoryThreshold), spill the current collection @@ -107,7 +107,7 @@ private[spark] trait Spillable[C] extends Logging { */ def releaseMemory(): Unit = { // The amount we requested does not include the initial memory tracking threshold - taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold) + taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold, null) myMemoryThreshold = initialMemoryThreshold } |