From 85e654c5ec87e666a8845bfd77185c1ea57b268a Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 25 Oct 2015 21:19:52 -0700
Subject: [SPARK-10984] Simplify *MemoryManager class structure

This patch refactors the MemoryManager class structure. After #9000, Spark had the following classes:

- MemoryManager
- StaticMemoryManager
- ExecutorMemoryManager
- TaskMemoryManager
- ShuffleMemoryManager

This is fairly confusing. To simplify things, this patch consolidates several of these classes:

- ShuffleMemoryManager and ExecutorMemoryManager were merged into MemoryManager.
- TaskMemoryManager is moved into Spark Core.

**Key changes and tasks**:

- [x] Merge ExecutorMemoryManager into MemoryManager.
- [x] Move pooling logic into Allocator.
- [x] Move TaskMemoryManager from `spark-unsafe` to `spark-core`.
- [x] Refactor the existing Tungsten TaskMemoryManager interactions so that Tungsten code uses only this class, rather than both it and ShuffleMemoryManager.
- [x] Refactor non-Tungsten code to use the TaskMemoryManager instead of ShuffleMemoryManager.
- [x] Merge ShuffleMemoryManager into MemoryManager.
- [x] Move code
- [x] ~~Simplify 1/n calculation.~~ **Will defer to followup, since this needs more work.**
- [x] Port ShuffleMemoryManagerSuite tests.
- [x] Move classes from `unsafe` package to `memory` package.
- [ ] Figure out how to handle the hacky use of the memory managers in HashedRelation's broadcast variable construction.
- [x] Test porting and cleanup: several tests relied on mock functionality (such as `TestShuffleMemoryManager.markAsOutOfMemory`) which has been changed or broken during the memory manager consolidation:
  - [x] AbstractBytesToBytesMapSuite
  - [x] UnsafeExternalSorterSuite
  - [x] UnsafeFixedWidthAggregationMapSuite
  - [x] UnsafeKVExternalSorterSuite

**Compatibility notes**:

- This patch introduces breaking changes in `ExternalAppendOnlyMap`, which is marked as `DeveloperApi` (likely for legacy reasons): this class can no longer be used outside of a task.

Author: Josh Rosen

Closes #9127 from JoshRosen/SPARK-10984.
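Editor's illustration (not part of the patch): after this consolidation, Tungsten consumers no longer pair every page allocation with a separate ShuffleMemoryManager `tryToAcquire`/`release` dance; they request pages from the task-level TaskMemoryManager and treat a `null` return as the signal to spill and retry. The sketch below distills that allocate-or-spill idiom from the changes this patch makes to ShuffleExternalSorter, UnsafeExternalSorter and BytesToBytesMap; the wrapper class name and the `Runnable` spill callback are hypothetical, only the TaskMemoryManager calls come from the patch itself.

```java
import java.io.IOException;

import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;

// Hypothetical helper illustrating the allocate-or-spill pattern introduced by this patch.
class AllocateOrSpillExample {
  private final TaskMemoryManager taskMemoryManager;

  AllocateOrSpillExample(TaskMemoryManager taskMemoryManager) {
    this.taskMemoryManager = taskMemoryManager;
  }

  /**
   * Allocate a page of the manager's default size. If the first attempt fails (allocatePage
   * returns null), run the caller-supplied spill callback to free memory and try once more.
   */
  MemoryBlock acquirePageOrSpill(Runnable spill) throws IOException {
    final long pageSizeBytes = taskMemoryManager.pageSizeBytes();
    MemoryBlock page = taskMemoryManager.allocatePage(pageSizeBytes);
    if (page == null) {
      // Not enough execution memory was granted; spill this consumer's data and retry.
      spill.run();
      page = taskMemoryManager.allocatePage(pageSizeBytes);
      if (page == null) {
        throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
      }
    }
    return page;
  }
}
```

This mirrors the `acquireNewPage()` logic the patch installs in both external sorters: page bookkeeping, execution-memory accounting, and freeing now all flow through TaskMemoryManager and the single MemoryManager behind it.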
--- .../org/apache/spark/memory/TaskMemoryManager.java | 283 ++++++++++++++++++ .../spark/shuffle/sort/PackedRecordPointer.java | 4 +- .../spark/shuffle/sort/ShuffleExternalSorter.java | 57 ++-- .../spark/shuffle/sort/UnsafeShuffleWriter.java | 7 +- .../apache/spark/unsafe/map/BytesToBytesMap.java | 36 +-- .../unsafe/sort/RecordPointerAndKeyPrefix.java | 4 +- .../unsafe/sort/UnsafeExternalSorter.java | 51 ++-- .../unsafe/sort/UnsafeInMemorySorter.java | 2 +- .../src/main/scala/org/apache/spark/SparkEnv.scala | 23 +- .../main/scala/org/apache/spark/TaskContext.scala | 2 +- .../scala/org/apache/spark/TaskContextImpl.scala | 2 +- .../scala/org/apache/spark/executor/Executor.scala | 4 +- .../org/apache/spark/memory/MemoryManager.scala | 197 +++++++++++-- .../apache/spark/memory/StaticMemoryManager.scala | 12 +- .../apache/spark/memory/UnifiedMemoryManager.scala | 12 +- .../scala/org/apache/spark/scheduler/Task.scala | 6 +- .../spark/shuffle/BlockStoreShuffleReader.scala | 5 +- .../spark/shuffle/ShuffleMemoryManager.scala | 209 ------------- .../spark/shuffle/sort/SortShuffleManager.scala | 1 - .../spark/shuffle/sort/SortShuffleWriter.scala | 6 +- .../util/collection/ExternalAppendOnlyMap.scala | 49 +++- .../spark/util/collection/ExternalSorter.scala | 8 +- .../apache/spark/util/collection/Spillable.scala | 16 +- .../spark/memory/TaskMemoryManagerSuite.java | 59 ++++ .../shuffle/sort/PackedRecordPointerSuite.java | 12 +- .../shuffle/sort/ShuffleInMemorySorterSuite.java | 9 +- .../shuffle/sort/UnsafeShuffleWriterSuite.java | 53 ++-- .../unsafe/map/AbstractBytesToBytesMapSuite.java | 108 ++----- .../unsafe/map/BytesToBytesMapOffHeapSuite.java | 7 +- .../unsafe/map/BytesToBytesMapOnHeapSuite.java | 7 +- .../unsafe/sort/UnsafeExternalSorterSuite.java | 34 +-- .../unsafe/sort/UnsafeInMemorySorterSuite.java | 13 +- .../test/scala/org/apache/spark/FailureSuite.scala | 4 +- .../memory/GrantEverythingMemoryManager.scala | 54 ++++ .../apache/spark/memory/MemoryManagerSuite.scala | 134 +++++++++ .../apache/spark/memory/MemoryTestingUtils.scala | 37 +++ .../spark/memory/StaticMemoryManagerSuite.scala | 24 +- .../spark/memory/UnifiedMemoryManagerSuite.scala | 26 +- .../spark/shuffle/ShuffleMemoryManagerSuite.scala | 326 --------------------- .../storage/BlockManagerReplicationSuite.scala | 4 +- .../apache/spark/storage/BlockManagerSuite.scala | 8 +- .../collection/ExternalAppendOnlyMapSuite.scala | 60 ++-- .../util/collection/ExternalSorterSuite.scala | 48 +-- .../sql/execution/UnsafeExternalRowSorter.java | 1 - .../execution/UnsafeFixedWidthAggregationMap.java | 12 +- .../sql/execution/UnsafeKVExternalSorter.java | 22 +- .../aggregate/TungstenAggregationIterator.scala | 9 +- .../execution/datasources/WriterContainer.scala | 3 +- .../spark/sql/execution/joins/HashedRelation.scala | 21 +- .../org/apache/spark/sql/execution/sort.scala | 5 +- .../sql/execution/TestShuffleMemoryManager.scala | 75 ----- .../UnsafeFixedWidthAggregationMapSuite.scala | 54 ++-- .../execution/UnsafeKVExternalSorterSuite.scala | 19 +- .../sql/execution/UnsafeRowSerializerSuite.scala | 10 +- .../TungstenAggregationIteratorSuite.scala | 4 +- .../streaming/ReceivedBlockHandlerSuite.scala | 2 +- .../spark/unsafe/memory/ExecutorMemoryManager.java | 111 ------- .../spark/unsafe/memory/HeapMemoryAllocator.java | 51 +++- .../apache/spark/unsafe/memory/MemoryBlock.java | 5 +- .../spark/unsafe/memory/TaskMemoryManager.java | 286 ------------------ .../unsafe/memory/TaskMemoryManagerSuite.java | 64 ---- 61 files changed, 1205 insertions(+), 
1572 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala create mode 100644 core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java create mode 100644 core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala create mode 100644 core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala delete mode 100644 core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala delete mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java delete mode 100644 unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java delete mode 100644 unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java new file mode 100644 index 0000000000..7b31c90dac --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import java.util.*; + +import com.google.common.annotations.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.unsafe.memory.MemoryBlock; + +/** + * Manages the memory allocated by an individual task. + *

+ * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs. + * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is + * addressed by the combination of a base Object reference and a 64-bit offset within that object. + * This is a problem when we want to store pointers to data structures inside of other structures, + * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits + * to address memory, we can't just store the address of the base object since it's not guaranteed + * to remain stable as the heap gets reorganized due to GC. + *

+ * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap + * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to + * store a "page number" and the lower 51 bits to store an offset within this page. These page + * numbers are used to index into a "page table" array inside of the MemoryManager in order to + * retrieve the base object. + *

+ * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the + * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is + * approximately 35 terabytes of memory. + */ +public class TaskMemoryManager { + + private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); + + /** The number of bits used to address the page table. */ + private static final int PAGE_NUMBER_BITS = 13; + + /** The number of bits used to encode offsets in data pages. */ + @VisibleForTesting + static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51 + + /** The number of entries in the page table. */ + private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; + + /** + * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is + * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page + * size is limited by the maximum amount of data that can be stored in a long[] array, which is + * (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes. + */ + public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L; + + /** Bit mask for the lower 51 bits of a long. */ + private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; + + /** Bit mask for the upper 13 bits of a long */ + private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; + + /** + * Similar to an operating system's page table, this array maps page numbers into base object + * pointers, allowing us to translate between the hashtable's internal 64-bit address + * representation and the baseObject+offset representation which we use to support both in- and + * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`. + * When using an in-heap allocator, the entries in this map will point to pages' base objects. + * Entries are added to this map as new data pages are allocated. + */ + private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE]; + + /** + * Bitmap for tracking free pages. + */ + private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); + + private final MemoryManager memoryManager; + + private final long taskAttemptId; + + /** + * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods + * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, + * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. + */ + private final boolean inHeap; + + /** + * Construct a new TaskMemoryManager. + */ + public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { + this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap(); + this.memoryManager = memoryManager; + this.taskAttemptId = taskAttemptId; + } + + /** + * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * @return number of bytes successfully granted (<= N). + */ + public long acquireExecutionMemory(long size) { + return memoryManager.acquireExecutionMemory(size, taskAttemptId); + } + + /** + * Release N bytes of execution memory. 
+ */ + public void releaseExecutionMemory(long size) { + memoryManager.releaseExecutionMemory(size, taskAttemptId); + } + + public long pageSizeBytes() { + return memoryManager.pageSizeBytes(); + } + + /** + * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is + * intended for allocating large blocks of Tungsten memory that will be shared between operators. + * + * Returns `null` if there was not enough memory to allocate the page. + */ + public MemoryBlock allocatePage(long size) { + if (size > MAXIMUM_PAGE_SIZE_BYTES) { + throw new IllegalArgumentException( + "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); + } + + final int pageNumber; + synchronized (this) { + pageNumber = allocatedPages.nextClearBit(0); + if (pageNumber >= PAGE_TABLE_SIZE) { + throw new IllegalStateException( + "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); + } + allocatedPages.set(pageNumber); + } + final long acquiredExecutionMemory = acquireExecutionMemory(size); + if (acquiredExecutionMemory != size) { + releaseExecutionMemory(acquiredExecutionMemory); + synchronized (this) { + allocatedPages.clear(pageNumber); + } + return null; + } + final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size); + page.pageNumber = pageNumber; + pageTable[pageNumber] = page; + if (logger.isTraceEnabled()) { + logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); + } + return page; + } + + /** + * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. + */ + public void freePage(MemoryBlock page) { + assert (page.pageNumber != -1) : + "Called freePage() on memory that wasn't allocated with allocatePage()"; + assert(allocatedPages.get(page.pageNumber)); + pageTable[page.pageNumber] = null; + synchronized (this) { + allocatedPages.clear(page.pageNumber); + } + if (logger.isTraceEnabled()) { + logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); + } + long pageSize = page.size(); + memoryManager.tungstenMemoryAllocator().free(page); + releaseExecutionMemory(pageSize); + } + + /** + * Given a memory page and offset within that page, encode this address into a 64-bit long. + * This address will remain valid as long as the corresponding page has not been freed. + * + * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/ + * @param offsetInPage an offset in this page which incorporates the base offset. In other words, + * this should be the value that you would pass as the base offset into an + * UNSAFE call (e.g. page.baseOffset() + something). + * @return an encoded page address. + */ + public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { + if (!inHeap) { + // In off-heap mode, an offset is an absolute address that may require a full 64 bits to + // encode. Due to our page size limitation, though, we can convert this into an offset that's + // relative to the page's base offset; this relative offset will fit in 51 bits. 
+ offsetInPage -= page.getBaseOffset(); + } + return encodePageNumberAndOffset(page.pageNumber, offsetInPage); + } + + @VisibleForTesting + public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { + assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; + return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); + } + + @VisibleForTesting + public static int decodePageNumber(long pagePlusOffsetAddress) { + return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); + } + + private static long decodeOffset(long pagePlusOffsetAddress) { + return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); + } + + /** + * Get the page associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public Object getPage(long pagePlusOffsetAddress) { + if (inHeap) { + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + assert (page.getBaseObject() != null); + return page.getBaseObject(); + } else { + return null; + } + } + + /** + * Get the offset associated with an address encoded by + * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} + */ + public long getOffsetInPage(long pagePlusOffsetAddress) { + final long offsetInPage = decodeOffset(pagePlusOffsetAddress); + if (inHeap) { + return offsetInPage; + } else { + // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we + // converted the absolute address into a relative address. Here, we invert that operation: + final int pageNumber = decodePageNumber(pagePlusOffsetAddress); + assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); + final MemoryBlock page = pageTable[pageNumber]; + assert (page != null); + return page.getBaseOffset() + offsetInPage; + } + } + + /** + * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return + * value can be used to detect memory leaks. + */ + public long cleanUpAllAllocatedMemory() { + long freedBytes = 0; + for (MemoryBlock page : pageTable) { + if (page != null) { + freedBytes += page.size(); + freePage(page); + } + } + + freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); + + return freedBytes; + } + + /** + * Returns the memory consumption, in bytes, for the current task + */ + public long getMemoryConsumptionForThisTask() { + return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java index c11711966f..f8f2b220e1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java @@ -17,6 +17,8 @@ package org.apache.spark.shuffle.sort; +import org.apache.spark.memory.TaskMemoryManager; + /** * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer. *

@@ -26,7 +28,7 @@ package org.apache.spark.shuffle.sort; * * This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the - * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this + * 13-bit page numbers assigned by {@link TaskMemoryManager}), this * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task. *

* Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 85fdaa8115..f43236f41a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -33,14 +33,13 @@ import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; /** @@ -72,7 +71,6 @@ final class ShuffleExternalSorter { @VisibleForTesting final int maxRecordSizeBytes; private final TaskMemoryManager taskMemoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; private final ShuffleWriteMetrics writeMetrics; @@ -105,7 +103,6 @@ final class ShuffleExternalSorter { public ShuffleExternalSorter( TaskMemoryManager memoryManager, - ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, TaskContext taskContext, int initialSize, @@ -113,7 +110,6 @@ final class ShuffleExternalSorter { SparkConf conf, ShuffleWriteMetrics writeMetrics) throws IOException { this.taskMemoryManager = memoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; this.taskContext = taskContext; this.initialSize = initialSize; @@ -124,7 +120,7 @@ final class ShuffleExternalSorter { this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.pageSizeBytes = (int) Math.min( - PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes()); + PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes()); this.maxRecordSizeBytes = pageSizeBytes - 4; this.writeMetrics = writeMetrics; initializeForWriting(); @@ -140,9 +136,9 @@ final class ShuffleExternalSorter { private void initializeForWriting() throws IOException { // TODO: move this sizing calculation logic into a static method of sorter: final long memoryRequested = initialSize * 8L; - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested); + final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested); if (memoryAcquired != memoryRequested) { - shuffleMemoryManager.release(memoryAcquired); + taskMemoryManager.releaseExecutionMemory(memoryAcquired); throw new IOException("Could not acquire " + memoryRequested + " bytes of memory"); } @@ -272,6 +268,7 @@ final class ShuffleExternalSorter { */ @VisibleForTesting void spill() throws IOException { + assert(inMemSorter != null); logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)", Thread.currentThread().getId(), Utils.bytesToString(getMemoryUsage()), @@ -281,7 +278,7 @@ final class ShuffleExternalSorter { 
writeSortedFile(false); final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage(); inMemSorter = null; - shuffleMemoryManager.release(inMemSorterMemoryUsage); + taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage); final long spillSize = freeMemory(); taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); @@ -316,9 +313,13 @@ final class ShuffleExternalSorter { long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { taskMemoryManager.freePage(block); - shuffleMemoryManager.release(block.size()); memoryFreed += block.size(); } + if (inMemSorter != null) { + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter = null; + taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage); + } allocatedPages.clear(); currentPage = null; currentPagePosition = -1; @@ -337,8 +338,9 @@ final class ShuffleExternalSorter { } } if (inMemSorter != null) { - shuffleMemoryManager.release(inMemSorter.getMemoryUsage()); + long sorterMemoryUsage = inMemSorter.getMemoryUsage(); inMemSorter = null; + taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage); } } @@ -353,21 +355,20 @@ final class ShuffleExternalSorter { logger.debug("Attempting to expand sort pointer array"); final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage(); final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2; - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray); + final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray); if (memoryAcquired < memoryToGrowPointerArray) { - shuffleMemoryManager.release(memoryAcquired); + taskMemoryManager.releaseExecutionMemory(memoryAcquired); spill(); } else { inMemSorter.expandPointerArray(); - shuffleMemoryManager.release(oldPointerArrayMemoryUsage); + taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage); } } } - + /** * Allocates more memory in order to insert an additional record. This will request additional - * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be - * obtained. + * memory from the memory manager and spill if the requested memory can not be obtained. * * @param requiredSpace the required space in the data page, in bytes, including space for storing * the record size. 
This must be less than or equal to the page size (records @@ -386,17 +387,14 @@ final class ShuffleExternalSorter { throw new IOException("Required space " + requiredSpace + " is greater than page size (" + pageSizeBytes + ")"); } else { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquired < pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquired); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + if (currentPage == null) { spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquiredAfterSpilling != pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + if (currentPage == null) { throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); @@ -430,17 +428,14 @@ final class ShuffleExternalSorter { long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); // The record is larger than the page size, so allocate a special overflow page just to hold // that record. - final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGranted != overflowPageSize) { - shuffleMemoryManager.release(memoryGranted); + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { spill(); - final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGrantedAfterSpill != overflowPageSize) { - shuffleMemoryManager.release(memoryGrantedAfterSpill); + overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); } } - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); allocatedPages.add(overflowPage); dataPage = overflowPage; dataPagePosition = overflowPage.getBaseOffset(); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index e8f050cb2d..f6c5c944bd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -49,12 +49,11 @@ import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -69,7 +68,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final BlockManager blockManager; private final IndexShuffleBlockResolver shuffleBlockResolver; private final TaskMemoryManager memoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; private final SerializerInstance serializer; 
private final Partitioner partitioner; private final ShuffleWriteMetrics writeMetrics; @@ -103,7 +101,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { BlockManager blockManager, IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, - ShuffleMemoryManager shuffleMemoryManager, SerializedShuffleHandle handle, int mapId, TaskContext taskContext, @@ -117,7 +114,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { this.blockManager = blockManager; this.shuffleBlockResolver = shuffleBlockResolver; this.memoryManager = memoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; this.mapId = mapId; final ShuffleDependency dep = handle.dependency(); this.shuffleId = dep.shuffleId(); @@ -197,7 +193,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { assert (sorter == null); sorter = new ShuffleExternalSorter( memoryManager, - shuffleMemoryManager, blockManager, taskContext, INITIAL_SORT_BUFFER_SIZE, diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index b24eed3952..f035bdac81 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -26,7 +26,6 @@ import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.array.LongArray; @@ -34,7 +33,7 @@ import org.apache.spark.unsafe.bitset.BitSet; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; /** * An append-only hash map where keys and values are contiguous regions of bytes. @@ -70,8 +69,6 @@ public final class BytesToBytesMap { private final TaskMemoryManager taskMemoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; - /** * A linked list for tracking all allocated data pages so that we can free all of our memory. 
*/ @@ -169,13 +166,11 @@ public final class BytesToBytesMap { public BytesToBytesMap( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, double loadFactor, long pageSizeBytes, boolean enablePerfMetrics) { this.taskMemoryManager = taskMemoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; this.loadFactor = loadFactor; this.loc = new Location(); this.pageSizeBytes = pageSizeBytes; @@ -201,21 +196,18 @@ public final class BytesToBytesMap { public BytesToBytesMap( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes) { - this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false); + this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false); } public BytesToBytesMap( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes, boolean enablePerfMetrics) { this( taskMemoryManager, - shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, @@ -260,7 +252,6 @@ public final class BytesToBytesMap { if (destructive && currentPage != null) { dataPagesIterator.remove(); this.bmap.taskMemoryManager.freePage(currentPage); - this.bmap.shuffleMemoryManager.release(currentPage.size()); } currentPage = dataPagesIterator.next(); pageBaseObject = currentPage.getBaseObject(); @@ -572,14 +563,12 @@ public final class BytesToBytesMap { if (useOverflowPage) { // The record is larger than the page size, so allocate a special overflow page just to hold // that record. - final long memoryRequested = requiredSize + 8; - final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested); - if (memoryGranted != memoryRequested) { - shuffleMemoryManager.release(memoryGranted); - logger.debug("Failed to acquire {} bytes of memory", memoryRequested); + final long overflowPageSize = requiredSize + 8; + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { + logger.debug("Failed to acquire {} bytes of memory", overflowPageSize); return false; } - MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested); dataPages.add(overflowPage); dataPage = overflowPage; dataPageBaseObject = overflowPage.getBaseObject(); @@ -655,17 +644,15 @@ public final class BytesToBytesMap { } /** - * Acquire a new page from the {@link ShuffleMemoryManager}. + * Acquire a new page from the memory manager. * @return whether there is enough space to allocate the new page. 
*/ private boolean acquireNewPage() { - final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryGranted != pageSizeBytes) { - shuffleMemoryManager.release(memoryGranted); + MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); + if (newPage == null) { logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes); return false; } - MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes); dataPages.add(newPage); pageCursor = 0; currentDataPage = newPage; @@ -705,7 +692,6 @@ public final class BytesToBytesMap { MemoryBlock dataPage = dataPagesIterator.next(); dataPagesIterator.remove(); taskMemoryManager.freePage(dataPage); - shuffleMemoryManager.release(dataPage.size()); } assert(dataPages.isEmpty()); } @@ -714,10 +700,6 @@ public final class BytesToBytesMap { return taskMemoryManager; } - public ShuffleMemoryManager getShuffleMemoryManager() { - return shuffleMemoryManager; - } - public long getPageSizeBytes() { return pageSizeBytes; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java index 0c4ebde407..dbf6770e07 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java @@ -17,9 +17,11 @@ package org.apache.spark.util.collection.unsafe.sort; +import org.apache.spark.memory.TaskMemoryManager; + final class RecordPointerAndKeyPrefix { /** - * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a + * A pointer to a record; see {@link TaskMemoryManager} for a * description of how these addresses are encoded. 
*/ public long recordPointer; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 0a311d2d93..e317ea391c 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -32,12 +32,11 @@ import org.slf4j.LoggerFactory; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; /** @@ -52,7 +51,6 @@ public final class UnsafeExternalSorter { private final RecordComparator recordComparator; private final int initialSize; private final TaskMemoryManager taskMemoryManager; - private final ShuffleMemoryManager shuffleMemoryManager; private final BlockManager blockManager; private final TaskContext taskContext; private ShuffleWriteMetrics writeMetrics; @@ -82,7 +80,6 @@ public final class UnsafeExternalSorter { public static UnsafeExternalSorter createWithExistingInMemorySorter( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, TaskContext taskContext, RecordComparator recordComparator, @@ -90,26 +87,24 @@ public final class UnsafeExternalSorter { int initialSize, long pageSizeBytes, UnsafeInMemorySorter inMemorySorter) throws IOException { - return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager, + return new UnsafeExternalSorter(taskMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter); } public static UnsafeExternalSorter create( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, TaskContext taskContext, RecordComparator recordComparator, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes) throws IOException { - return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager, + return new UnsafeExternalSorter(taskMemoryManager, blockManager, taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null); } private UnsafeExternalSorter( TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, BlockManager blockManager, TaskContext taskContext, RecordComparator recordComparator, @@ -118,7 +113,6 @@ public final class UnsafeExternalSorter { long pageSizeBytes, @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException { this.taskMemoryManager = taskMemoryManager; - this.shuffleMemoryManager = shuffleMemoryManager; this.blockManager = blockManager; this.taskContext = taskContext; this.recordComparator = recordComparator; @@ -261,7 +255,6 @@ public final class UnsafeExternalSorter { long memoryFreed = 0; for (MemoryBlock block : allocatedPages) { taskMemoryManager.freePage(block); - shuffleMemoryManager.release(block.size()); memoryFreed += block.size(); } // TODO: track in-memory sorter memory usage (SPARK-10474) @@ -309,8 +302,7 @@ public final class UnsafeExternalSorter { /** * 
Allocates more memory in order to insert an additional record. This will request additional - * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be - * obtained. + * memory from the memory manager and spill if the requested memory can not be obtained. * * @param requiredSpace the required space in the data page, in bytes, including space for storing * the record size. This must be less than or equal to the page size (records @@ -335,23 +327,20 @@ public final class UnsafeExternalSorter { } /** - * Acquire a new page from the {@link ShuffleMemoryManager}. + * Acquire a new page from the memory manager. * * If there is not enough space to allocate the new page, spill all existing ones * and try again. If there is still not enough space, report error to the caller. */ private void acquireNewPage() throws IOException { - final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquired < pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquired); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + if (currentPage == null) { spill(); - final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes); - if (memoryAcquiredAfterSpilling != pageSizeBytes) { - shuffleMemoryManager.release(memoryAcquiredAfterSpilling); + currentPage = taskMemoryManager.allocatePage(pageSizeBytes); + if (currentPage == null) { throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory"); } } - currentPage = taskMemoryManager.allocatePage(pageSizeBytes); currentPagePosition = currentPage.getBaseOffset(); freeSpaceInCurrentPage = pageSizeBytes; allocatedPages.add(currentPage); @@ -379,17 +368,14 @@ public final class UnsafeExternalSorter { long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); // The record is larger than the page size, so allocate a special overflow page just to hold // that record. - final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGranted != overflowPageSize) { - shuffleMemoryManager.release(memoryGranted); + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { spill(); - final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGrantedAfterSpill != overflowPageSize) { - shuffleMemoryManager.release(memoryGrantedAfterSpill); + overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); } } - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); allocatedPages.add(overflowPage); dataPage = overflowPage; dataPagePosition = overflowPage.getBaseOffset(); @@ -441,17 +427,14 @@ public final class UnsafeExternalSorter { long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired); // The record is larger than the page size, so allocate a special overflow page just to hold // that record. 
- final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGranted != overflowPageSize) { - shuffleMemoryManager.release(memoryGranted); + MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { spill(); - final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize); - if (memoryGrantedAfterSpill != overflowPageSize) { - shuffleMemoryManager.release(memoryGrantedAfterSpill); + overflowPage = taskMemoryManager.allocatePage(overflowPageSize); + if (overflowPage == null) { throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory"); } } - MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize); allocatedPages.add(overflowPage); dataPage = overflowPage; dataPagePosition = overflowPage.getBaseOffset(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index f7787e1019..5aad72c374 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -21,7 +21,7 @@ import java.util.Comparator; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.Sorter; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; /** * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index b5c35c569e..398e093690 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -38,9 +38,8 @@ import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} /** @@ -70,10 +69,7 @@ class SparkEnv ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, - // TODO: unify these *MemoryManager classes (SPARK-10984) val memoryManager: MemoryManager, - val shuffleMemoryManager: ShuffleMemoryManager, - val executorMemoryManager: ExecutorMemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { @@ -340,13 +336,11 @@ object SparkEnv extends Logging { val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) val memoryManager: MemoryManager = if (useLegacyMemoryManager) { - new StaticMemoryManager(conf) + new StaticMemoryManager(conf, numUsableCores) } else { - new UnifiedMemoryManager(conf) + new UnifiedMemoryManager(conf, numUsableCores) } - val shuffleMemoryManager = ShuffleMemoryManager.create(conf, memoryManager, numUsableCores) - val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( @@ -405,15 
+399,6 @@ object SparkEnv extends Logging { new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) - val executorMemoryManager: ExecutorMemoryManager = { - val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) { - MemoryAllocator.UNSAFE - } else { - MemoryAllocator.HEAP - } - new ExecutorMemoryManager(allocator) - } - val envInstance = new SparkEnv( executorId, rpcEnv, @@ -431,8 +416,6 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, memoryManager, - shuffleMemoryManager, - executorMemoryManager, outputCommitCoordinator, conf) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 63cca80b2d..af558d6e5b 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,8 +21,8 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 5df94c6d3a..f0ae83a934 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -20,9 +20,9 @@ package org.apache.spark import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} private[spark] class TaskContextImpl( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c3491bb8b1..9e88d488c0 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -29,10 +29,10 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util._ /** @@ -179,7 +179,7 @@ private[spark] class Executor( } override def run(): Unit = { - val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager) + val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = env.closureSerializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 7168ac5491..6c9a71c385 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -17,20 +17,38 @@ 
package org.apache.spark.memory +import javax.annotation.concurrent.GuardedBy + import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer -import org.apache.spark.Logging -import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} +import com.google.common.annotations.VisibleForTesting +import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging} +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.unsafe.memory.MemoryAllocator /** * An abstract memory manager that enforces how memory is shared between execution and storage. * * In this context, execution memory refers to that used for computation in shuffles, joins, * sorts and aggregations, while storage memory refers to that used for caching and propagating - * internal data across the cluster. There exists one of these per JVM. + * internal data across the cluster. There exists one MemoryManager per JVM. + * + * The MemoryManager abstract base class itself implements policies for sharing execution memory + * between tasks; it tries to ensure that each task gets a reasonable share of memory, instead of + * some task ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever + * this set changes. This is all done by synchronizing access to mutable state and using wait() and + * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across + * tasks was performed by the ShuffleMemoryManager. */ -private[spark] abstract class MemoryManager extends Logging { +private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) extends Logging { + + // -- Methods related to memory allocation policies and bookkeeping ------------------------------ // The memory store used to evict cached blocks private var _memoryStore: MemoryStore = _ @@ -42,8 +60,10 @@ private[spark] abstract class MemoryManager extends Logging { } // Amount of execution/storage memory in use, accesses must be synchronized on `this` - protected var _executionMemoryUsed: Long = 0 - protected var _storageMemoryUsed: Long = 0 + @GuardedBy("this") protected var _executionMemoryUsed: Long = 0 + @GuardedBy("this") protected var _storageMemoryUsed: Long = 0 + // Map from taskAttemptId -> memory consumption in bytes + @GuardedBy("this") private val executionMemoryForTask = new mutable.HashMap[Long, Long]() /** * Set the [[MemoryStore]] used by this manager to evict cached blocks. @@ -65,15 +85,6 @@ private[spark] abstract class MemoryManager extends Logging { // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985) - /** - * Acquire N bytes of memory for execution, evicting cached blocks if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return number of bytes successfully granted (<= N). - */ - def acquireExecutionMemory( - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long - /** * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. * Blocks evicted in the process, if any, are added to `evictedBlocks`. 
@@ -102,9 +113,92 @@ private[spark] abstract class MemoryManager extends Logging { } /** - * Release N bytes of execution memory. + * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return number of bytes successfully granted (<= N). + */ + @VisibleForTesting + private[memory] def doAcquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long + + /** + * Try to acquire up to `numBytes` of execution memory for the current task and return the number + * of bytes obtained, or 0 if none can be allocated. + * + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. + * + * Subclasses should override `doAcquireExecutionMemory` in order to customize the policies + * that control global sharing of memory between execution and storage. */ - def releaseExecutionMemory(numBytes: Long): Unit = synchronized { + private[memory] + final def acquireExecutionMemory(numBytes: Long, taskAttemptId: Long): Long = synchronized { + assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) + + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire + if (!executionMemoryForTask.contains(taskAttemptId)) { + executionMemoryForTask(taskAttemptId) = 0L + // This will later cause waiting tasks to wake up and check numTasks again + notifyAll() + } + + // Once the cross-task memory allocation policy has decided to grant more memory to a task, + // this method is called in order to actually obtain that execution memory, potentially + // triggering eviction of storage memory: + def acquire(toGrant: Long): Long = synchronized { + val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + val acquired = doAcquireExecutionMemory(toGrant, evictedBlocks) + // Register evicted blocks, if any, with the active task metrics + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + } + executionMemoryForTask(taskAttemptId) += acquired + acquired + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). 
+ // TODO: simplify this to limit each task to its own slot + while (true) { + val numActiveTasks = executionMemoryForTask.keys.size + val curMem = executionMemoryForTask(taskAttemptId) + val freeMemory = maxExecutionMemory - executionMemoryForTask.values.sum + + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; + // don't let it be negative + val maxToGrant = + math.min(numBytes, math.max(0, (maxExecutionMemory / numActiveTasks) - curMem)) + // Only give it as much memory as is free, which might be none if it reached 1 / numTasks + val toGrant = math.min(maxToGrant, freeMemory) + + if (curMem < maxExecutionMemory / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if ( + freeMemory >= math.min(maxToGrant, maxExecutionMemory / (2 * numActiveTasks) - curMem)) { + return acquire(toGrant) + } else { + logInfo( + s"TID $taskAttemptId waiting for at least 1/2N of execution memory pool to be free") + wait() + } + } else { + return acquire(toGrant) + } + } + 0L // Never reached + } + + @VisibleForTesting + private[memory] def releaseExecutionMemory(numBytes: Long): Unit = synchronized { if (numBytes > _executionMemoryUsed) { logWarning(s"Attempted to release $numBytes bytes of execution " + s"memory when we only have ${_executionMemoryUsed} bytes") @@ -114,6 +208,36 @@ private[spark] abstract class MemoryManager extends Logging { } } + /** + * Release numBytes of execution memory belonging to the given task. + */ + private[memory] + final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized { + val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L) + if (curMem < numBytes) { + throw new SparkException( + s"Internal error: release called on $numBytes bytes but task only has $curMem") + } + if (executionMemoryForTask.contains(taskAttemptId)) { + executionMemoryForTask(taskAttemptId) -= numBytes + if (executionMemoryForTask(taskAttemptId) <= 0) { + executionMemoryForTask.remove(taskAttemptId) + } + releaseExecutionMemory(numBytes) + } + notifyAll() // Notify waiters in acquireExecutionMemory() that memory has been freed + } + + /** + * Release all memory for the given task and mark it as inactive (e.g. when a task ends). + * @return the number of bytes freed. + */ + private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = synchronized { + val numBytesToFree = getExecutionMemoryUsageForTask(taskAttemptId) + releaseExecutionMemory(numBytesToFree, taskAttemptId) + numBytesToFree + } + /** * Release N bytes of storage memory. */ @@ -155,4 +279,43 @@ private[spark] abstract class MemoryManager extends Logging { _storageMemoryUsed } + /** + * Returns the execution memory consumption, in bytes, for the given task. + */ + private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = synchronized { + executionMemoryForTask.getOrElse(taskAttemptId, 0L) + } + + // -- Fields related to Tungsten managed memory ------------------------------------------------- + + /** + * The default page size, in bytes. + * + * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value + * by looking at the number of cores available to the process, and the total amount of memory, + * and then divide it by a factor of safety. 
+ */ + val pageSizeBytes: Long = { + val minPageSize = 1L * 1024 * 1024 // 1MB + val maxPageSize = 64L * minPageSize // 64MB + val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() + // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case + val safetyFactor = 16 + val size = ByteArrayMethods.nextPowerOf2(maxExecutionMemory / cores / safetyFactor) + val default = math.min(maxPageSize, math.max(minPageSize, size)) + conf.getSizeAsBytes("spark.buffer.pageSize", default) + } + + /** + * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using + * sun.misc.Unsafe. + */ + final val tungstenMemoryIsAllocatedInHeap: Boolean = + !conf.getBoolean("spark.unsafe.offHeap", false) + + /** + * Allocates memory for use by Unsafe/Tungsten code. + */ + private[memory] final val tungstenMemoryAllocator: MemoryAllocator = + if (tungstenMemoryIsAllocatedInHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE } diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index fa44f37234..9c2c2e90a2 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -33,14 +33,16 @@ import org.apache.spark.storage.{BlockId, BlockStatus} private[spark] class StaticMemoryManager( conf: SparkConf, override val maxExecutionMemory: Long, - override val maxStorageMemory: Long) - extends MemoryManager { + override val maxStorageMemory: Long, + numCores: Int) + extends MemoryManager(conf, numCores) { - def this(conf: SparkConf) { + def this(conf: SparkConf, numCores: Int) { this( conf, StaticMemoryManager.getMaxExecutionMemory(conf), - StaticMemoryManager.getMaxStorageMemory(conf)) + StaticMemoryManager.getMaxStorageMemory(conf), + numCores) } // Max number of bytes worth of blocks to evict when unrolling @@ -52,7 +54,7 @@ private[spark] class StaticMemoryManager( * Acquire N bytes of memory for execution. * @return number of bytes successfully granted (<= N). */ - override def acquireExecutionMemory( + override def doAcquireExecutionMemory( numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { assert(numBytes >= 0) diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 5bf78d5b67..a3093030a0 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -42,10 +42,14 @@ import org.apache.spark.storage.{BlockStatus, BlockId} * up most of the storage space, in which case the new blocks will be evicted immediately * according to their respective storage levels. */ -private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) extends MemoryManager { +private[spark] class UnifiedMemoryManager( + conf: SparkConf, + maxMemory: Long, + numCores: Int) + extends MemoryManager(conf, numCores) { - def this(conf: SparkConf) { - this(conf, UnifiedMemoryManager.getMaxMemory(conf)) + def this(conf: SparkConf, numCores: Int) { + this(conf, UnifiedMemoryManager.getMaxMemory(conf), numCores) } /** @@ -91,7 +95,7 @@ private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) exte * Blocks evicted in the process, if any, are added to `evictedBlocks`. 
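To make the default-page-size heuristic above concrete, here is a sketch under the assumption that "spark.buffer.pageSize" is unset (PageSizeSketch and defaultPageSize are illustrative names; ByteArrayMethods.nextPowerOf2 is the same helper the patch uses):

```scala
import org.apache.spark.unsafe.array.ByteArrayMethods

object PageSizeSketch {
  // Mirrors the default-page-size computation shown above.
  def defaultPageSize(maxExecutionMemory: Long, cores: Int): Long = {
    val minPageSize = 1L * 1024 * 1024   // 1MB
    val maxPageSize = 64L * minPageSize  // 64MB
    val safetyFactor = 16
    val size = ByteArrayMethods.nextPowerOf2(maxExecutionMemory / cores / safetyFactor)
    math.min(maxPageSize, math.max(minPageSize, size))
  }

  def main(args: Array[String]): Unit = {
    // 4GB of execution memory on 8 cores: 4GB / 8 / 16 = 32MB, a power of two within [1MB, 64MB].
    println(defaultPageSize(4L * 1024 * 1024 * 1024, 8)) // 33554432 (32MB)
    // 128MB on 8 cores works out to exactly 1MB, so the lower clamp is the answer.
    println(defaultPageSize(128L * 1024 * 1024, 8))      // 1048576 (1MB)
  }
}
```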
* @return number of bytes successfully granted (<= N). */ - override def acquireExecutionMemory( + private[memory] override def doAcquireExecutionMemory( numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { assert(numBytes >= 0) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 9edf9f048f..4fb32ba8cb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -25,8 +25,8 @@ import scala.collection.mutable.HashMap import org.apache.spark.metrics.MetricsSystem import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils @@ -89,10 +89,6 @@ private[spark] abstract class Task[T]( } finally { context.markTaskCompleted() try { - Utils.tryLogNonFatalError { - // Release memory used by this thread for shuffles - SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask() - } Utils.tryLogNonFatalError { // Release memory used by this thread for unrolling blocks SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7c3e2b5a37..b0abda4a81 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -98,13 +98,14 @@ private[spark] class BlockStoreShuffleReader[K, C]( case Some(keyOrd: Ordering[K]) => // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled, // the ExternalSorter won't spill to disk. - val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser)) + val sorter = + new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser)) sorter.insertAll(aggregatedIter) context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) context.internalMetricsToAccumulators( InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes) - sorter.iterator + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) case None => aggregatedIter } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala deleted file mode 100644 index 9bd18da47f..0000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ /dev/null @@ -1,209 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import com.google.common.annotations.VisibleForTesting - -import org.apache.spark._ -import org.apache.spark.memory.{StaticMemoryManager, MemoryManager} -import org.apache.spark.storage.{BlockId, BlockStatus} -import org.apache.spark.unsafe.array.ByteArrayMethods - -/** - * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling - * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory - * from this pool and release it as it spills data out. When a task ends, all its memory will be - * released by the Executor. - * - * This class tries to ensure that each task gets a reasonable share of memory, instead of some - * task ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory - * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever - * this set changes. This is all done by synchronizing access to `memoryManager` to mutate state - * and using wait() and notifyAll() to signal changes. - * - * Use `ShuffleMemoryManager.create()` factory method to create a new instance. - * - * @param memoryManager the interface through which this manager acquires execution memory - * @param pageSizeBytes number of bytes for each page, by default. - */ -private[spark] -class ShuffleMemoryManager protected ( - memoryManager: MemoryManager, - val pageSizeBytes: Long) - extends Logging { - - private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes - - private def currentTaskAttemptId(): Long = { - // In case this is called on the driver, return an invalid task attempt id. - Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) - } - - /** - * Try to acquire up to numBytes memory for the current task, and return the number of bytes - * obtained, or 0 if none can be allocated. This call may block until there is enough free memory - * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the - * total memory pool (where N is the # of active tasks) before it is forced to spill. This can - * happen if the number of tasks increases but an older task had a lot of memory already. 
- */ - def tryToAcquire(numBytes: Long): Long = memoryManager.synchronized { - val taskAttemptId = currentTaskAttemptId() - assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - - // Add this task to the taskMemory map just so we can keep an accurate count of the number - // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire - if (!taskMemory.contains(taskAttemptId)) { - taskMemory(taskAttemptId) = 0L - // This will later cause waiting tasks to wake up and check numTasks again - memoryManager.notifyAll() - } - - // Keep looping until we're either sure that we don't want to grant this request (because this - // task would have more than 1 / numActiveTasks of the memory) or we have enough free - // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). - // TODO: simplify this to limit each task to its own slot - while (true) { - val numActiveTasks = taskMemory.keys.size - val curMem = taskMemory(taskAttemptId) - val maxMemory = memoryManager.maxExecutionMemory - val freeMemory = maxMemory - taskMemory.values.sum - - // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; - // don't let it be negative - val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) - // Only give it as much memory as is free, which might be none if it reached 1 / numTasks - val toGrant = math.min(maxToGrant, freeMemory) - - if (curMem < maxMemory / (2 * numActiveTasks)) { - // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; - // if we can't give it this much now, wait for other tasks to free up memory - // (this happens if older tasks allocated lots of memory before N grew) - if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { - return acquire(toGrant) - } else { - logInfo( - s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") - memoryManager.wait() - } - } else { - return acquire(toGrant) - } - } - 0L // Never reached - } - - /** - * Acquire N bytes of execution memory from the memory manager for the current task. - * @return number of bytes actually acquired (<= N). - */ - private def acquire(numBytes: Long): Long = memoryManager.synchronized { - val taskAttemptId = currentTaskAttemptId() - val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - val acquired = memoryManager.acquireExecutionMemory(numBytes, evictedBlocks) - // Register evicted blocks, if any, with the active task metrics - // TODO: just do this in `acquireExecutionMemory` (SPARK-10985) - Option(TaskContext.get()).foreach { tc => - val metrics = tc.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) - } - taskMemory(taskAttemptId) += acquired - acquired - } - - /** Release numBytes bytes for the current task. 
*/ - def release(numBytes: Long): Unit = memoryManager.synchronized { - val taskAttemptId = currentTaskAttemptId() - val curMem = taskMemory.getOrElse(taskAttemptId, 0L) - if (curMem < numBytes) { - throw new SparkException( - s"Internal error: release called on $numBytes bytes but task only has $curMem") - } - if (taskMemory.contains(taskAttemptId)) { - taskMemory(taskAttemptId) -= numBytes - memoryManager.releaseExecutionMemory(numBytes) - } - memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed - } - - /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisTask(): Unit = memoryManager.synchronized { - val taskAttemptId = currentTaskAttemptId() - taskMemory.remove(taskAttemptId).foreach { numBytes => - memoryManager.releaseExecutionMemory(numBytes) - } - memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed - } - - /** Returns the memory consumption, in bytes, for the current task */ - def getMemoryConsumptionForThisTask(): Long = memoryManager.synchronized { - val taskAttemptId = currentTaskAttemptId() - taskMemory.getOrElse(taskAttemptId, 0L) - } -} - - -private[spark] object ShuffleMemoryManager { - - def create( - conf: SparkConf, - memoryManager: MemoryManager, - numCores: Int): ShuffleMemoryManager = { - val maxMemory = memoryManager.maxExecutionMemory - val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores) - new ShuffleMemoryManager(memoryManager, pageSize) - } - - /** - * Create a dummy [[ShuffleMemoryManager]] with the specified capacity and page size. - */ - def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = { - val conf = new SparkConf - val memoryManager = new StaticMemoryManager( - conf, maxExecutionMemory = maxMemory, maxStorageMemory = Long.MaxValue) - new ShuffleMemoryManager(memoryManager, pageSizeBytes) - } - - @VisibleForTesting - def createForTesting(maxMemory: Long): ShuffleMemoryManager = { - create(maxMemory, 4 * 1024 * 1024) - } - - /** - * Sets the page size, in bytes. - * - * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value - * by looking at the number of cores available to the process, and the total amount of memory, - * and then divide it by a factor of safety. 
- */ - private def getPageSize(conf: SparkConf, maxMemory: Long, numCores: Int): Long = { - val minPageSize = 1L * 1024 * 1024 // 1MB - val maxPageSize = 64L * minPageSize // 64MB - val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() - // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case - val safetyFactor = 16 - val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor) - val default = math.min(maxPageSize, math.max(minPageSize, size)) - conf.getSizeAsBytes("spark.buffer.pageSize", default) - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 1105167d39..66b6bbc61f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -133,7 +133,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager env.blockManager, shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], context.taskMemoryManager(), - env.shuffleMemoryManager, unsafeShuffleHandle, mapId, context, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index bbd9c1ab53..808317b017 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -52,13 +52,13 @@ private[spark] class SortShuffleWriter[K, V, C]( sorter = if (dep.mapSideCombine) { require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!") new ExternalSorter[K, V, C]( - dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) + context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer) } else { // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't // care whether the keys get sorted in each partition; that will be done on the reduce side // if the operation being run is sortByKey. new ExternalSorter[K, V, V]( - aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) + context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer) } sorter.insertAll(records) @@ -67,7 +67,7 @@ private[spark] class SortShuffleWriter[K, V, C]( // (see SPARK-3570). 
val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) + val partitionLengths = sorter.writePartitionedFile(blockId, outputFile) shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index cfa58f5ef4..f6d81ee5bf 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -28,8 +28,10 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator import org.apache.spark.executor.ShuffleWriteMetrics @@ -55,12 +57,30 @@ class ExternalAppendOnlyMap[K, V, C]( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, serializer: Serializer = SparkEnv.get.serializer, - blockManager: BlockManager = SparkEnv.get.blockManager) + blockManager: BlockManager = SparkEnv.get.blockManager, + context: TaskContext = TaskContext.get()) extends Iterable[(K, C)] with Serializable with Logging with Spillable[SizeTracker] { + if (context == null) { + throw new IllegalStateException( + "Spillable collections should not be instantiated outside of tasks") + } + + // Backwards-compatibility constructor for binary compatibility + def this( + createCombiner: V => C, + mergeValue: (C, V) => C, + mergeCombiners: (C, C) => C, + serializer: Serializer, + blockManager: BlockManager) { + this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get()) + } + + override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager() + private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf @@ -118,6 +138,10 @@ class ExternalAppendOnlyMap[K, V, C]( * The shuffle memory usage of the first trackMemoryThreshold entries is not tracked. */ def insertAll(entries: Iterator[Product2[K, V]]): Unit = { + if (currentMap == null) { + throw new IllegalStateException( + "Cannot insert new elements into a map after calling iterator") + } // An update function for the map that we reuse across entries to avoid allocating // a new closure each time var curEntry: Product2[K, V] = null @@ -215,17 +239,26 @@ class ExternalAppendOnlyMap[K, V, C]( } /** - * Return an iterator that merges the in-memory map with the spilled maps. + * Return a destructive iterator that merges the in-memory map with the spilled maps. * If no spill has occurred, simply return the in-memory map's iterator. 
*/ override def iterator: Iterator[(K, C)] = { + if (currentMap == null) { + throw new IllegalStateException( + "ExternalAppendOnlyMap.iterator is destructive and should only be called once.") + } if (spilledMaps.isEmpty) { - currentMap.iterator + CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap()) } else { new ExternalIterator() } } + private def freeCurrentMap(): Unit = { + currentMap = null // So that the memory can be garbage-collected + releaseMemory() + } + /** * An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps */ @@ -237,7 +270,8 @@ class ExternalAppendOnlyMap[K, V, C]( // Input streams are derived both from the in-memory map and spilled maps on disk // The in-memory map is sorted in place, while the spilled maps are already in sorted order - private val sortedMap = currentMap.destructiveSortedIterator(keyComparator) + private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]]( + currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap()) private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered) inputStreams.foreach { it => @@ -493,12 +527,7 @@ class ExternalAppendOnlyMap[K, V, C]( } } - val context = TaskContext.get() - // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in - // a TaskContext. - if (context != null) { - context.addTaskCompletionListener(context => cleanup()) - } + context.addTaskCompletionListener(context => cleanup()) } /** Convenience function to hash the given (K, C) pair by the key. */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index c48c453a90..a44e72b7c1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -27,6 +27,7 @@ import com.google.common.annotations.VisibleForTesting import com.google.common.io.ByteStreams import org.apache.spark._ +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} @@ -87,6 +88,7 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} * - Users are expected to call stop() at the end to delete all the intermediate files. 
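ExternalAppendOnlyMap.iterator is now destructive: the first call returns an iterator that frees the in-memory map once it is fully consumed, and any later call fails fast. A minimal sketch of that pattern with stand-in types (not the patch's CompletionIterator-based code):

```scala
import scala.collection.mutable.ArrayBuffer

// Stand-in for a spillable collection; free() plays the role of freeCurrentMap() above.
class DestructiveCollection[T](initial: Seq[T]) {
  private var buffer: ArrayBuffer[T] = ArrayBuffer(initial: _*)

  def iterator: Iterator[T] = {
    if (buffer == null) {
      throw new IllegalStateException("iterator is destructive and should only be called once")
    }
    val underlying = buffer.iterator
    new Iterator[T] {
      def hasNext: Boolean = {
        val more = underlying.hasNext
        if (!more) free() // completion callback: run cleanup once the data is fully consumed
        more
      }
      def next(): T = underlying.next()
    }
  }

  private def free(): Unit = {
    buffer = null // let the backing memory be garbage-collected
    // a real implementation would also release its execution memory here
  }
}
```

As with the CompletionIterator wrapping above, the cleanup only runs if the iterator is actually exhausted.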
*/ private[spark] class ExternalSorter[K, V, C]( + context: TaskContext, aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, @@ -94,6 +96,8 @@ private[spark] class ExternalSorter[K, V, C]( extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] { + override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager() + private val conf = SparkEnv.get.conf private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) @@ -640,7 +644,6 @@ private[spark] class ExternalSorter[K, V, C]( */ def writePartitionedFile( blockId: BlockId, - context: TaskContext, outputFile: File): Array[Long] = { // Track location of each range in the output file @@ -686,8 +689,11 @@ private[spark] class ExternalSorter[K, V, C]( } def stop(): Unit = { + map = null // So that the memory can be garbage-collected + buffer = null // So that the memory can be garbage-collected spills.foreach(s => s.file.delete()) spills.clear() + releaseMemory() } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index d2a68ca7a3..a76891acf0 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -17,8 +17,8 @@ package org.apache.spark.util.collection -import org.apache.spark.Logging -import org.apache.spark.SparkEnv +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.{Logging, SparkEnv} /** * Spills contents of an in-memory collection to disk when the memory threshold @@ -40,7 +40,7 @@ private[spark] trait Spillable[C] extends Logging { protected def addElementsRead(): Unit = { _elementsRead += 1 } // Memory manager that can be used to acquire/release memory - private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + protected[this] def taskMemoryManager: TaskMemoryManager // Initial threshold for the size of a collection before we start tracking its memory usage // For testing only @@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging { if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest) myMemoryThreshold += granted // If we were granted too little memory to grow further (either tryToAcquire returned 0, // or we already had more memory than myMemoryThreshold), spill the current collection @@ -92,7 +92,7 @@ private[spark] trait Spillable[C] extends Logging { spill(collection) _elementsRead = 0 _memoryBytesSpilled += currentMemory - releaseMemoryForThisThread() + releaseMemory() } shouldSpill } @@ -103,11 +103,11 @@ private[spark] trait Spillable[C] extends Logging { def memoryBytesSpilled: Long = _memoryBytesSpilled /** - * Release our memory back to the shuffle pool so that other threads can grab it. + * Release our memory back to the execution pool so that other tasks can grab it. 
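The spill trigger shown above asks the execution pool for enough memory to double the collection's footprint and spills when the grant falls short. A simplified standalone sketch (hypothetical names; the real trait additionally samples only every 32 elements read):

```scala
// tryToAcquire stands in for taskMemoryManager.acquireExecutionMemory; it returns how
// many of the requested bytes were actually granted.
class SpillTracker(tryToAcquire: Long => Long, initialThreshold: Long) {
  private var myMemoryThreshold = initialThreshold
  var spills = 0

  /** Returns true if the caller should spill its in-memory collection now. */
  def maybeSpill(currentMemory: Long): Boolean = {
    var shouldSpill = false
    if (currentMemory >= myMemoryThreshold) {
      // Claim up to double our current footprint from the execution pool.
      val amountToRequest = 2 * currentMemory - myMemoryThreshold
      val granted = tryToAcquire(amountToRequest)
      myMemoryThreshold += granted
      // If we were granted too little to keep growing, spill the collection.
      shouldSpill = currentMemory >= myMemoryThreshold
    }
    if (shouldSpill) {
      spills += 1
      myMemoryThreshold = initialThreshold // the extra memory goes back to the pool
    }
    shouldSpill
  }
}
```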
*/ - private def releaseMemoryForThisThread(): Unit = { + def releaseMemory(): Unit = { // The amount we requested does not include the initial memory tracking threshold - shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold) + taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold) myMemoryThreshold = initialMemoryThreshold } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java new file mode 100644 index 0000000000..f381db0c62 --- /dev/null +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.spark.SparkConf; +import org.apache.spark.unsafe.memory.MemoryBlock; + +public class TaskMemoryManagerSuite { + + @Test + public void leakedPageMemoryIsDetected() { + final TaskMemoryManager manager = new TaskMemoryManager( + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + manager.allocatePage(4096); // leak memory + Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void encodePageNumberAndOffsetOffHeap() { + final TaskMemoryManager manager = new TaskMemoryManager( + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0); + final MemoryBlock dataPage = manager.allocatePage(256); + // In off-heap mode, an offset is an absolute address that may require more than 51 bits to + // encode. 
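The off-heap test below exercises how TaskMemoryManager packs a page number and an in-page offset into a single 64-bit address, with 51 bits (OFFSET_BITS) reserved for the offset. A standalone sketch of such a packing scheme (the 13-bit page-number width is an assumption inferred from the 51-bit offset, not copied from the implementation):

```scala
object PageAddressSketch {
  // Assumed layout: upper 13 bits = page number, lower 51 bits = offset within the page.
  val OffsetBits = 51
  val OffsetMask: Long = (1L << OffsetBits) - 1

  def encode(pageNumber: Int, offsetInPage: Long): Long =
    (pageNumber.toLong << OffsetBits) | (offsetInPage & OffsetMask)

  def pageNumber(address: Long): Int = (address >>> OffsetBits).toInt
  def offsetInPage(address: Long): Long = address & OffsetMask

  def main(args: Array[String]): Unit = {
    // An off-heap offset larger than 2^51 cannot be stored directly in the low bits;
    // only its low 51 bits survive a naive round trip, which is why off-heap mode needs
    // the extra address translation the test below checks.
    val tooBig = (1L << OffsetBits) + 10
    val addr = encode(3, tooBig)
    println(pageNumber(addr))   // 3
    println(offsetInPage(addr)) // 10, not 2^51 + 10
  }
}
```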
This test exercises that corner-case: + final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); + Assert.assertEquals(null, manager.getPage(encodedAddress)); + Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); + } + + @Test + public void encodePageNumberAndOffsetOnHeap() { + final TaskMemoryManager manager = new TaskMemoryManager( + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + final MemoryBlock dataPage = manager.allocatePage(256); + final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); + Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); + Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); + } + +} diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java index 232ae4d926..7fb2f92ca8 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java @@ -21,18 +21,19 @@ import org.apache.spark.shuffle.sort.PackedRecordPointer; import org.junit.Test; import static org.junit.Assert.*; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.SparkConf; +import org.apache.spark.memory.GrantEverythingMemoryManager; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import static org.apache.spark.shuffle.sort.PackedRecordPointer.*; public class PackedRecordPointerSuite { @Test public void heap() { + final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128); final MemoryBlock page1 = memoryManager.allocatePage(128); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, @@ -49,8 +50,9 @@ public class PackedRecordPointerSuite { @Test public void offHeap() { + final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true"); final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); + new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0); final MemoryBlock page0 = memoryManager.allocatePage(128); final MemoryBlock page1 = memoryManager.allocatePage(128); final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1, diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 1ef3c5ff64..5049a5306f 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -24,11 +24,11 @@ import org.junit.Assert; import org.junit.Test; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; 
+import org.apache.spark.memory.GrantEverythingMemoryManager; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; public class ShuffleInMemorySorterSuite { @@ -58,8 +58,9 @@ public class ShuffleInMemorySorterSuite { "Lychee", "Mango" }; + final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false"); final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048); final Object baseObject = dataPage.getBaseObject(); final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 29d9823b1f..d65926949c 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -39,7 +39,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; import static org.junit.Assert.*; -import static org.mockito.AdditionalAnswers.returnsFirstArg; import static org.mockito.Answers.RETURNS_SMART_NULLS; import static org.mockito.Mockito.*; @@ -54,19 +53,15 @@ import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.serializer.*; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.shuffle.IndexShuffleBlockResolver; -import org.apache.spark.shuffle.ShuffleMemoryManager; -import org.apache.spark.shuffle.sort.SerializedShuffleHandle; import org.apache.spark.storage.*; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.GrantEverythingMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; public class UnsafeShuffleWriterSuite { static final int NUM_PARTITITONS = 4; - final TaskMemoryManager taskMemoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + TaskMemoryManager taskMemoryManager; final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS); File mergedOutputFile; File tempDir; @@ -76,7 +71,6 @@ public class UnsafeShuffleWriterSuite { final Serializer serializer = new KryoSerializer(new SparkConf()); TaskMetrics taskMetrics; - @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @@ -111,11 +105,11 @@ public class UnsafeShuffleWriterSuite { mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir); partitionSizesInMergedFile = null; spillFilesCreated.clear(); - conf = new SparkConf().set("spark.buffer.pageSize", "128m"); + conf = new SparkConf() + .set("spark.buffer.pageSize", "128m") + .set("spark.unsafe.offHeap", "false"); taskMetrics = new TaskMetrics(); - - when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg()); - when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L 
* 1024 * 1024); + taskMemoryManager = new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( @@ -203,7 +197,6 @@ public class UnsafeShuffleWriterSuite { blockManager, shuffleBlockResolver, taskMemoryManager, - shuffleMemoryManager, new SerializedShuffleHandle(0, 1, shuffleDep), 0, // map id taskContext, @@ -405,11 +398,12 @@ public class UnsafeShuffleWriterSuite { @Test public void writeEnoughDataToTriggerSpill() throws Exception { - when(shuffleMemoryManager.tryToAcquire(anyLong())) - .then(returnsFirstArg()) // Allocate initial sort buffer - .then(returnsFirstArg()) // Allocate initial data page - .thenReturn(0L) // Deny request to allocate new data page - .then(returnsFirstArg()); // Grant new sort buffer and data page. + taskMemoryManager = spy(taskMemoryManager); + doCallRealMethod() // initialize sort buffer + .doCallRealMethod() // allocate initial data page + .doReturn(0L) // deny request to allocate new page + .doCallRealMethod() // grant new sort buffer and data page + .when(taskMemoryManager).acquireExecutionMemory(anyLong()); final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList>(); final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128]; @@ -417,7 +411,7 @@ public class UnsafeShuffleWriterSuite { dataToWrite.add(new Tuple2(i, bigByteArray)); } writer.write(dataToWrite.iterator()); - verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong()); assertEquals(2, spillFilesCreated.size()); writer.stop(true); readRecordsFromFile(); @@ -432,18 +426,19 @@ public class UnsafeShuffleWriterSuite { @Test public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { - when(shuffleMemoryManager.tryToAcquire(anyLong())) - .then(returnsFirstArg()) // Allocate initial sort buffer - .then(returnsFirstArg()) // Allocate initial data page - .thenReturn(0L) // Deny request to grow sort buffer - .then(returnsFirstArg()); // Grant new sort buffer and data page. 
+ taskMemoryManager = spy(taskMemoryManager); + doCallRealMethod() // initialize sort buffer + .doCallRealMethod() // allocate initial data page + .doReturn(0L) // deny request to allocate new page + .doCallRealMethod() // grant new sort buffer and data page + .when(taskMemoryManager).acquireExecutionMemory(anyLong()); final UnsafeShuffleWriter writer = createWriter(false); - final ArrayList> dataToWrite = new ArrayList>(); + final ArrayList> dataToWrite = new ArrayList<>(); for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { dataToWrite.add(new Tuple2(i, i)); } writer.write(dataToWrite.iterator()); - verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong()); + verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong()); assertEquals(2, spillFilesCreated.size()); writer.stop(true); readRecordsFromFile(); @@ -509,13 +504,13 @@ public class UnsafeShuffleWriterSuite { final long recordLengthBytes = 8; final long pageSizeBytes = 256; final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; - when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); + taskMemoryManager = spy(taskMemoryManager); + when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes); final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( blockManager, shuffleBlockResolver, taskMemoryManager, - shuffleMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index ab480b60ad..6e52496cf9 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -21,15 +21,13 @@ import java.lang.Exception; import java.nio.ByteBuffer; import java.util.*; +import org.apache.spark.memory.TaskMemoryManager; import org.junit.*; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import static org.hamcrest.Matchers.greaterThan; import static org.junit.Assert.*; -import static org.mockito.AdditionalMatchers.geq; -import static org.mockito.Mockito.*; -import org.apache.spark.shuffle.ShuffleMemoryManager; +import org.apache.spark.SparkConf; +import org.apache.spark.memory.GrantEverythingMemoryManager; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.memory.*; import org.apache.spark.unsafe.Platform; @@ -39,42 +37,29 @@ public abstract class AbstractBytesToBytesMapSuite { private final Random rand = new Random(42); - private ShuffleMemoryManager shuffleMemoryManager; + private GrantEverythingMemoryManager memoryManager; private TaskMemoryManager taskMemoryManager; - private TaskMemoryManager sizeLimitedTaskMemoryManager; private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes @Before public void setup() { - shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, PAGE_SIZE_BYTES); - taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator())); - // Mocked memory manager for tests that check the maximum array size, since actually allocating - // such large arrays will cause us to run out of memory in our tests. 
- sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class); - when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer( - new Answer() { - @Override - public MemoryBlock answer(InvocationOnMock invocation) throws Throwable { - if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) { - throw new OutOfMemoryError("Requested array size exceeds VM limit"); - } - return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]); - } - } - ); + memoryManager = + new GrantEverythingMemoryManager( + new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator())); + taskMemoryManager = new TaskMemoryManager(memoryManager, 0); } @After public void tearDown() { Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); - if (shuffleMemoryManager != null) { - long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); - shuffleMemoryManager = null; - Assert.assertEquals(0L, leakedShuffleMemory); + if (taskMemoryManager != null) { + long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask(); + taskMemoryManager = null; + Assert.assertEquals(0L, leakedMemory); } } - protected abstract MemoryAllocator getMemoryAllocator(); + protected abstract boolean useOffHeapMemoryAllocator(); private static byte[] getByteArray(MemoryLocation loc, int size) { final byte[] arr = new byte[size]; @@ -110,8 +95,7 @@ public abstract class AbstractBytesToBytesMapSuite { @Test public void emptyMap() { - BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES); try { Assert.assertEquals(0, map.numElements()); final int keyLengthInWords = 10; @@ -126,8 +110,7 @@ public abstract class AbstractBytesToBytesMapSuite { @Test public void setAndRetrieveAKey() { - BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES); final int recordLengthWords = 10; final int recordLengthBytes = recordLengthWords * 8; final byte[] keyData = getRandomByteArray(recordLengthWords); @@ -179,8 +162,7 @@ public abstract class AbstractBytesToBytesMapSuite { private void iteratorTestBase(boolean destructive) throws Exception { final int size = 4096; - BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES); try { for (long i = 0; i < size; i++) { final long[] value = new long[] { i }; @@ -265,8 +247,8 @@ public abstract class AbstractBytesToBytesMapSuite { final int NUM_ENTRIES = 1000 * 1000; final int KEY_LENGTH = 24; final int VALUE_LENGTH = 40; - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); + final BytesToBytesMap map = + new BytesToBytesMap(taskMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES); // Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte // pages won't be evenly-divisible by records of this size, which will cause us to waste some // space at the end of the page. 
This is necessary in order for us to take the end-of-record @@ -335,9 +317,7 @@ public abstract class AbstractBytesToBytesMapSuite { // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. final Map expected = new HashMap(); - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES); - + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size, PAGE_SIZE_BYTES); try { // Fill the map to 90% full so that we can trigger probing for (int i = 0; i < size * 0.9; i++) { @@ -386,8 +366,7 @@ public abstract class AbstractBytesToBytesMapSuite { @Test public void randomizedTestWithRecordsLargerThanPageSize() { final long pageSizeBytes = 128; - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes); + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, pageSizeBytes); // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays // into ByteBuffers in order to use them as keys here. final Map expected = new HashMap(); @@ -436,9 +415,9 @@ public abstract class AbstractBytesToBytesMapSuite { @Test public void failureToAllocateFirstPage() { - shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024); - BytesToBytesMap map = - new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + memoryManager.markExecutionAsOutOfMemory(); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES); + memoryManager.markExecutionAsOutOfMemory(); try { final long[] emptyArray = new long[0]; final BytesToBytesMap.Location loc = @@ -454,12 +433,14 @@ public abstract class AbstractBytesToBytesMapSuite { @Test public void failureToGrow() { - shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024 * 10); - BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024); + BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, 1024); try { boolean success = true; int i; - for (i = 0; i < 1024; i++) { + for (i = 0; i < 127; i++) { + if (i > 0) { + memoryManager.markExecutionAsOutOfMemory(); + } final long[] arr = new long[]{i}; final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8); success = @@ -478,7 +459,7 @@ public abstract class AbstractBytesToBytesMapSuite { @Test public void initialCapacityBoundsChecking() { try { - new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES); + new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception @@ -486,36 +467,13 @@ public abstract class AbstractBytesToBytesMapSuite { try { new BytesToBytesMap( - sizeLimitedTaskMemoryManager, - shuffleMemoryManager, + taskMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES); Assert.fail("Expected IllegalArgumentException to be thrown"); } catch (IllegalArgumentException e) { // expected exception } - - // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager - // Can allocate _at_ the max capacity - // BytesToBytesMap map = new BytesToBytesMap( - // sizeLimitedTaskMemoryManager, - // shuffleMemoryManager, - // BytesToBytesMap.MAX_CAPACITY, - // PAGE_SIZE_BYTES); - // map.free(); - } - - // Ignored because this can OOM now that we allocate the long 
array w/o a TaskMemoryManager - @Ignore - public void resizingLargeMap() { - // As long as a map's capacity is below the max, we should be able to resize up to the max - BytesToBytesMap map = new BytesToBytesMap( - sizeLimitedTaskMemoryManager, - shuffleMemoryManager, - BytesToBytesMap.MAX_CAPACITY - 64, - PAGE_SIZE_BYTES); - map.growAndRehash(); - map.free(); } @Test @@ -523,8 +481,7 @@ public abstract class AbstractBytesToBytesMapSuite { final long recordLengthBytes = 24; final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes; - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes); + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes); // Since BytesToBytesMap is append-only, we expect the total memory consumption to be // monotonically increasing. More specifically, every time we allocate a new page it @@ -564,8 +521,7 @@ public abstract class AbstractBytesToBytesMapSuite { @Test public void testAcquirePageInConstructor() { - final BytesToBytesMap map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES); + final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES); assertEquals(1, map.getNumDataPages()); map.free(); } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java index 5a10de49f5..f0bad4d760 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java @@ -17,13 +17,10 @@ package org.apache.spark.unsafe.map; -import org.apache.spark.unsafe.memory.MemoryAllocator; - public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite { @Override - protected MemoryAllocator getMemoryAllocator() { - return MemoryAllocator.UNSAFE; + protected boolean useOffHeapMemoryAllocator() { + return true; } - } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java index 12cc9b25d9..d76bb4fd05 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java @@ -17,13 +17,10 @@ package org.apache.spark.unsafe.map; -import org.apache.spark.unsafe.memory.MemoryAllocator; - public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite { @Override - protected MemoryAllocator getMemoryAllocator() { - return MemoryAllocator.HEAP; + protected boolean useOffHeapMemoryAllocator() { + return false; } - } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a5bbaa95fa..94d50b94fd 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -46,20 +46,19 @@ import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.executor.TaskMetrics; +import 
org.apache.spark.memory.GrantEverythingMemoryManager; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.storage.*; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.Utils; public class UnsafeExternalSorterSuite { final LinkedList spillFilesCreated = new LinkedList(); - final TaskMemoryManager taskMemoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final GrantEverythingMemoryManager memoryManager = + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = new PrefixComparator() { @Override @@ -82,7 +81,6 @@ public class UnsafeExternalSorterSuite { SparkConf sparkConf; File tempDir; - ShuffleMemoryManager shuffleMemoryManager; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @@ -102,7 +100,6 @@ public class UnsafeExternalSorterSuite { MockitoAnnotations.initMocks(this); sparkConf = new SparkConf(); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); - shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, pageSizeBytes); spillFilesCreated.clear(); taskContext = mock(TaskContext.class); when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); @@ -143,13 +140,7 @@ public class UnsafeExternalSorterSuite { @After public void tearDown() { try { - long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory(); - if (shuffleMemoryManager != null) { - long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask(); - shuffleMemoryManager = null; - assertEquals(0L, leakedShuffleMemory); - } - assertEquals(0, leakedUnsafeMemory); + assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory()); } finally { Utils.deleteRecursively(tempDir); tempDir = null; @@ -178,7 +169,6 @@ public class UnsafeExternalSorterSuite { private UnsafeExternalSorter newSorter() throws IOException { return UnsafeExternalSorter.create( taskMemoryManager, - shuffleMemoryManager, blockManager, taskContext, recordComparator, @@ -236,12 +226,16 @@ public class UnsafeExternalSorterSuite { @Test public void spillingOccursInResponseToMemoryPressure() throws Exception { - shuffleMemoryManager = ShuffleMemoryManager.create(pageSizeBytes * 2, pageSizeBytes); final UnsafeExternalSorter sorter = newSorter(); - final int numRecords = (int) pageSizeBytes / 4; - for (int i = 0; i <= numRecords; i++) { + // This should be enough records to completely fill up a data page: + final int numRecords = (int) (pageSizeBytes / (4 + 4)); + for (int i = 0; i < numRecords; i++) { insertNumber(sorter, numRecords - i); } + assertEquals(1, sorter.getNumberOfAllocatedPages()); + memoryManager.markExecutionAsOutOfMemory(); + // The insertion of this record should trigger a spill: + insertNumber(sorter, 0); // Ensure that spill files were created assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1)); // Read back the 
sorted data: @@ -255,6 +249,7 @@ public class UnsafeExternalSorterSuite { assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset())); i++; } + assertEquals(numRecords + 1, i); sorter.cleanupResources(); assertSpillFilesWereCleanedUp(); } @@ -323,7 +318,6 @@ public class UnsafeExternalSorterSuite { final long numRecordsPerPage = pageSizeBytes / recordLengthBytes; final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( taskMemoryManager, - shuffleMemoryManager, blockManager, taskContext, recordComparator, diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 778e813df6..d5de56a051 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -26,11 +26,11 @@ import static org.junit.Assert.*; import static org.mockito.Mockito.mock; import org.apache.spark.HashPartitioner; +import org.apache.spark.SparkConf; import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.memory.ExecutorMemoryManager; -import org.apache.spark.unsafe.memory.MemoryAllocator; +import org.apache.spark.memory.GrantEverythingMemoryManager; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; public class UnsafeInMemorySorterSuite { @@ -43,7 +43,8 @@ public class UnsafeInMemorySorterSuite { @Test public void testSortingEmptyInput() { final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)), + new TaskMemoryManager( + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0), mock(RecordComparator.class), mock(PrefixComparator.class), 100); @@ -64,8 +65,8 @@ public class UnsafeInMemorySorterSuite { "Lychee", "Mango" }; - final TaskMemoryManager memoryManager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); + final TaskMemoryManager memoryManager = new TaskMemoryManager( + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048); final Object baseObject = dataPage.getBaseObject(); // Write the records into the data page: diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index f58756e6f6..0242cbc924 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // cause is preserved val thrownDueToTaskFailure = intercept[SparkException] { sc.parallelize(Seq(0)).mapPartitions { iter => - TaskContext.get().taskMemoryManager().allocate(128) + TaskContext.get().taskMemoryManager().allocatePage(128) throw new Exception("intentional task failure") iter }.count() @@ -159,7 +159,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext { // If the task succeeded but memory was leaked, then the task should fail due to that leak val thrownDueToMemoryLeak = intercept[SparkException] { sc.parallelize(Seq(0)).mapPartitions { iter => - TaskContext.get().taskMemoryManager().allocate(128) + 
TaskContext.get().taskMemoryManager().allocatePage(128) iter }.count() } diff --git a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala new file mode 100644 index 0000000000..fe102d8aeb --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import scala.collection.mutable + +import org.apache.spark.SparkConf +import org.apache.spark.storage.{BlockStatus, BlockId} + +class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) { + private[memory] override def doAcquireExecutionMemory( + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + if (oom) { + oom = false + 0 + } else { + _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory + numBytes + } + } + override def acquireStorageMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true + override def releaseStorageMemory(numBytes: Long): Unit = { } + override def maxExecutionMemory: Long = Long.MaxValue + override def maxStorageMemory: Long = Long.MaxValue + + private var oom = false + + def markExecutionAsOutOfMemory(): Unit = { + oom = true + } +} diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 36e4566310..1265087743 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -19,10 +19,14 @@ package org.apache.spark.memory import java.util.concurrent.atomic.AtomicLong +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} + import org.mockito.Matchers.{any, anyLong} import org.mockito.Mockito.{mock, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite import org.apache.spark.storage.MemoryStore @@ -126,6 +130,136 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { assert(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED, "ensure free space should not have been called!") } + + /** + * Create a MemoryManager with the specified execution memory limit and no storage memory. 
+ */ + protected def createMemoryManager(maxExecutionMemory: Long): MemoryManager + + // -- Tests of sharing of execution memory between tasks ---------------------------------------- + // Prior to Spark 1.6, these tests were part of ShuffleMemoryManagerSuite. + + implicit val ec = ExecutionContext.global + + test("single task requesting execution memory") { + val manager = createMemoryManager(1000L) + val taskMemoryManager = new TaskMemoryManager(manager, 0) + + assert(taskMemoryManager.acquireExecutionMemory(100L) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(200L) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L) + + taskMemoryManager.releaseExecutionMemory(500L) + assert(taskMemoryManager.acquireExecutionMemory(300L) === 300L) + assert(taskMemoryManager.acquireExecutionMemory(300L) === 200L) + + taskMemoryManager.cleanUpAllAllocatedMemory() + assert(taskMemoryManager.acquireExecutionMemory(1000L) === 1000L) + assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L) + } + + test("two tasks requesting full execution memory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // Have both tasks request 500 bytes, then wait until both requests have been granted: + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) } + assert(Await.result(t1Result1, futureTimeout) === 500L) + assert(Await.result(t2Result1, futureTimeout) === 500L) + + // Have both tasks each request 500 bytes more; both should immediately return 0 as they are + // both now at 1 / N + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) } + assert(Await.result(t1Result2, 200.millis) === 0L) + assert(Await.result(t2Result2, 200.millis) === 0L) + } + + test("two tasks cannot grow past 1 / N of execution memory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // Have both tasks request 250 bytes, then wait until both requests have been granted: + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) } + assert(Await.result(t1Result1, futureTimeout) === 250L) + assert(Await.result(t2Result1, futureTimeout) === 250L) + + // Have both tasks each request 500 bytes more. 
+ // We should only grant 250 bytes to each of them on this second request + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) } + assert(Await.result(t1Result2, futureTimeout) === 250L) + assert(Await.result(t2Result2, futureTimeout) === 250L) + } + + test("tasks can block to get at least 1 / 2N of execution memory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) } + assert(Await.result(t1Result1, futureTimeout) === 1000L) + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) } + // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult + // to make sure the other thread blocks for some time otherwise. + Thread.sleep(300) + t1MemManager.releaseExecutionMemory(250L) + // The memory freed from t1 should now be granted to t2. + assert(Await.result(t2Result1, futureTimeout) === 250L) + // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory. + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L) } + assert(Await.result(t2Result2, 200.millis) === 0L) + } + + test("TaskMemoryManager.cleanUpAllAllocatedMemory") { + val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) } + assert(Await.result(t1Result1, futureTimeout) === 1000L) + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) } + // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult + // to make sure the other thread blocks for some time otherwise. + Thread.sleep(300) + // t1 releases all of its memory, so t2 should be able to grab all of the memory + t1MemManager.cleanUpAllAllocatedMemory() + assert(Await.result(t2Result1, futureTimeout) === 500L) + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) } + assert(Await.result(t2Result2, futureTimeout) === 500L) + val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L) } + assert(Await.result(t2Result3, 200.millis) === 0L) + } + + test("tasks should not be granted a negative amount of execution memory") { + // This is a regression test for SPARK-4715. 
+ val memoryManager = createMemoryManager(1000L) + val t1MemManager = new TaskMemoryManager(memoryManager, 1) + val t2MemManager = new TaskMemoryManager(memoryManager, 2) + val futureTimeout: Duration = 20.seconds + + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L) } + assert(Await.result(t1Result1, futureTimeout) === 700L) + + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L) } + assert(Await.result(t2Result1, futureTimeout) === 300L) + + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L) } + assert(Await.result(t1Result2, 200.millis) === 0L) + } } private object MemoryManagerSuite { diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala new file mode 100644 index 0000000000..4b4c3b0311 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} + +/** + * Helper methods for mocking out memory-management-related classes in tests. 
+ */ +object MemoryTestingUtils { + def fakeTaskContext(env: SparkEnv): TaskContext = { + val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) + new TaskContextImpl( + stageId = 0, + partitionId = 0, + taskAttemptId = 0, + attemptNumber = 0, + taskMemoryManager = taskMemoryManager, + metricsSystem = env.metricsSystem, + internalAccumulators = Seq.empty) + } +} diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 6cae1f871e..885c450d6d 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -36,27 +36,35 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { maxExecutionMem: Long, maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = { val mm = new StaticMemoryManager( - conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem) + conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem, numCores = 1) val ms = makeMemoryStore(mm) (mm, ms) } + override protected def createMemoryManager(maxMemory: Long): MemoryManager = { + new StaticMemoryManager( + conf, + maxExecutionMemory = maxMemory, + maxStorageMemory = 0, + numCores = 1) + } + test("basic execution memory") { val maxExecutionMem = 1000L val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue) assert(mm.executionMemoryUsed === 0L) - assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) // Acquire up to the max - assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L) assert(mm.executionMemoryUsed === maxExecutionMem) - assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L) assert(mm.executionMemoryUsed === maxExecutionMem) mm.releaseExecutionMemory(800L) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired mm.releaseExecutionMemory(maxExecutionMem) @@ -108,10 +116,10 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { val dummyBlock = TestBlockId("ain't nobody love like you do") val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem) // Only execution memory should increase - assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 100L) - assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 100L) + assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 200L) // Only storage memory should increase diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index e7baa50dc2..0c97f2bd89 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -34,11 +34,15 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes * Make a [[UnifiedMemoryManager]] and a [[MemoryStore]] with limited class dependencies. */ private def makeThings(maxMemory: Long): (UnifiedMemoryManager, MemoryStore) = { - val mm = new UnifiedMemoryManager(conf, maxMemory) + val mm = new UnifiedMemoryManager(conf, maxMemory, numCores = 1) val ms = makeMemoryStore(mm) (mm, ms) } + override protected def createMemoryManager(maxMemory: Long): MemoryManager = { + new UnifiedMemoryManager(conf, maxMemory, numCores = 1) + } + private def getStorageRegionSize(mm: UnifiedMemoryManager): Long = { mm invokePrivate PrivateMethod[Long]('storageRegionSize)() } @@ -56,18 +60,18 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes val maxMemory = 1000L val (mm, _) = makeThings(maxMemory) assert(mm.executionMemoryUsed === 0L) - assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) // Acquire up to the max - assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L) assert(mm.executionMemoryUsed === maxMemory) - assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L) assert(mm.executionMemoryUsed === maxMemory) mm.releaseExecutionMemory(800L) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired mm.releaseExecutionMemory(maxMemory) @@ -132,12 +136,12 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes require(mm.storageMemoryUsed > storageRegionSize, s"bad test: storage memory used should exceed the storage region") // Execution needs to request 250 bytes to evict storage memory - assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) assert(mm.executionMemoryUsed === 100L) assert(mm.storageMemoryUsed === 750L) assertEnsureFreeSpaceNotCalled(ms) // Execution wants 200 bytes but only 150 are free, so storage is evicted - assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L) + assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L) assertEnsureFreeSpaceCalled(ms, 200L) assert(mm.executionMemoryUsed === 300L) mm.releaseAllStorageMemory() @@ -151,7 +155,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes s"bad test: storage memory used should be within the storage region") // Execution cannot evict storage because the latter is within the storage fraction, // so grant only what's remaining without evicting anything, i.e. 
1000 - 300 - 400 = 300 - assert(mm.acquireExecutionMemory(400L, evictedBlocks) === 300L) + assert(mm.doAcquireExecutionMemory(400L, evictedBlocks) === 300L) assert(mm.executionMemoryUsed === 600L) assert(mm.storageMemoryUsed === 400L) assertEnsureFreeSpaceNotCalled(ms) @@ -170,7 +174,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes require(executionRegionSize === expectedExecutionRegionSize, "bad test: storage region size is unexpected") // Acquire enough execution memory to exceed the execution region - assert(mm.acquireExecutionMemory(800L, evictedBlocks) === 800L) + assert(mm.doAcquireExecutionMemory(800L, evictedBlocks) === 800L) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 0L) assertEnsureFreeSpaceNotCalled(ms) @@ -188,7 +192,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes mm.releaseExecutionMemory(maxMemory) mm.releaseStorageMemory(maxMemory) // Acquire some execution memory again, but this time keep it within the execution region - assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L) + assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 0L) assertEnsureFreeSpaceNotCalled(ms) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala deleted file mode 100644 index 5877aa042d..0000000000 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle - -import java.util.concurrent.CountDownLatch -import java.util.concurrent.atomic.AtomicInteger - -import org.mockito.Mockito._ -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ - -import org.apache.spark.{SparkFunSuite, TaskContext} -import org.apache.spark.executor.TaskMetrics - -class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { - - val nextTaskAttemptId = new AtomicInteger() - - /** Launch a thread with the given body block and return it. 
*/ - private def startThread(name: String)(body: => Unit): Thread = { - val thread = new Thread("ShuffleMemorySuite " + name) { - override def run() { - try { - val taskAttemptId = nextTaskAttemptId.getAndIncrement - val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS) - val taskMetrics = new TaskMetrics - when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId) - when(mockTaskContext.taskMetrics()).thenReturn(taskMetrics) - TaskContext.setTaskContext(mockTaskContext) - body - } finally { - TaskContext.unset() - } - } - } - thread.start() - thread - } - - test("single task requesting memory") { - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - - assert(manager.tryToAcquire(100L) === 100L) - assert(manager.tryToAcquire(400L) === 400L) - assert(manager.tryToAcquire(400L) === 400L) - assert(manager.tryToAcquire(200L) === 100L) - assert(manager.tryToAcquire(100L) === 0L) - assert(manager.tryToAcquire(100L) === 0L) - - manager.release(500L) - assert(manager.tryToAcquire(300L) === 300L) - assert(manager.tryToAcquire(300L) === 200L) - - manager.releaseMemoryForThisTask() - assert(manager.tryToAcquire(1000L) === 1000L) - assert(manager.tryToAcquire(100L) === 0L) - } - - test("two threads requesting full memory") { - // Two threads request 500 bytes first, wait for each other to get it, and then request - // 500 more; we should immediately return 0 as both are now at 1 / N - - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - - class State { - var t1Result1 = -1L - var t2Result1 = -1L - var t1Result2 = -1L - var t2Result2 = -1L - } - val state = new State - - val t1 = startThread("t1") { - val r1 = manager.tryToAcquire(500L) - state.synchronized { - state.t1Result1 = r1 - state.notifyAll() - while (state.t2Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t1Result2 = r2 } - } - - val t2 = startThread("t2") { - val r1 = manager.tryToAcquire(500L) - state.synchronized { - state.t2Result1 = r1 - state.notifyAll() - while (state.t1Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t2Result2 = r2 } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - assert(state.t1Result1 === 500L) - assert(state.t2Result1 === 500L) - assert(state.t1Result2 === 0L) - assert(state.t2Result2 === 0L) - } - - - test("tasks cannot grow past 1 / N") { - // Two tasks request 250 bytes first, wait for each other to get it, and then request - // 500 more; we should only grant 250 bytes to each of them on this second request - - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - - class State { - var t1Result1 = -1L - var t2Result1 = -1L - var t1Result2 = -1L - var t2Result2 = -1L - } - val state = new State - - val t1 = startThread("t1") { - val r1 = manager.tryToAcquire(250L) - state.synchronized { - state.t1Result1 = r1 - state.notifyAll() - while (state.t2Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t1Result2 = r2 } - } - - val t2 = startThread("t2") { - val r1 = manager.tryToAcquire(250L) - state.synchronized { - state.t2Result1 = r1 - state.notifyAll() - while (state.t1Result1 === -1L) { - state.wait() - } - } - val r2 = manager.tryToAcquire(500L) - state.synchronized { state.t2Result2 = r2 } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - assert(state.t1Result1 === 250L) - assert(state.t2Result1 === 250L) - 
assert(state.t1Result2 === 250L) - assert(state.t2Result2 === 250L) - } - - test("tasks can block to get at least 1 / 2N memory") { - // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps - // for a bit and releases 250 bytes, which should then be granted to t2. Further requests - // by t2 will return false right away because it now has 1 / 2N of the memory. - - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - - class State { - var t1Requested = false - var t2Requested = false - var t1Result = -1L - var t2Result = -1L - var t2Result2 = -1L - var t2WaitTime = 0L - } - val state = new State - - val t1 = startThread("t1") { - state.synchronized { - state.t1Result = manager.tryToAcquire(1000L) - state.t1Requested = true - state.notifyAll() - while (!state.t2Requested) { - state.wait() - } - } - // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other thread blocks for some time otherwise - Thread.sleep(300) - manager.release(250L) - } - - val t2 = startThread("t2") { - state.synchronized { - while (!state.t1Requested) { - state.wait() - } - state.t2Requested = true - state.notifyAll() - } - val startTime = System.currentTimeMillis() - val result = manager.tryToAcquire(250L) - val endTime = System.currentTimeMillis() - state.synchronized { - state.t2Result = result - // A second call should return 0 because we're now already at 1 / 2N - state.t2Result2 = manager.tryToAcquire(100L) - state.t2WaitTime = endTime - startTime - } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - // Both threads should've been able to acquire their memory; the second one will have waited - // until the first one acquired 1000 bytes and then released 250 - state.synchronized { - assert(state.t1Result === 1000L, "t1 could not allocate memory") - assert(state.t2Result === 250L, "t2 could not allocate memory") - assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") - assert(state.t2Result2 === 0L, "t1 got extra memory the second time") - } - } - - test("releaseMemoryForThisTask") { - // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps - // for a bit and releases all its memory. t2 should now be able to grab all the memory. 
- - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - - class State { - var t1Requested = false - var t2Requested = false - var t1Result = -1L - var t2Result1 = -1L - var t2Result2 = -1L - var t2Result3 = -1L - var t2WaitTime = 0L - } - val state = new State - - val t1 = startThread("t1") { - state.synchronized { - state.t1Result = manager.tryToAcquire(1000L) - state.t1Requested = true - state.notifyAll() - while (!state.t2Requested) { - state.wait() - } - } - // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other task blocks for some time otherwise - Thread.sleep(300) - manager.releaseMemoryForThisTask() - } - - val t2 = startThread("t2") { - state.synchronized { - while (!state.t1Requested) { - state.wait() - } - state.t2Requested = true - state.notifyAll() - } - val startTime = System.currentTimeMillis() - val r1 = manager.tryToAcquire(500L) - val endTime = System.currentTimeMillis() - val r2 = manager.tryToAcquire(500L) - val r3 = manager.tryToAcquire(500L) - state.synchronized { - state.t2Result1 = r1 - state.t2Result2 = r2 - state.t2Result3 = r3 - state.t2WaitTime = endTime - startTime - } - } - - failAfter(20 seconds) { - t1.join() - t2.join() - } - - // Both tasks should've been able to acquire their memory; the second one will have waited - // until the first one acquired 1000 bytes and then released all of it - state.synchronized { - assert(state.t1Result === 1000L, "t1 could not allocate memory") - assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time") - assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time") - assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})") - assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})") - } - } - - test("tasks should not be granted a negative size") { - val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L) - manager.tryToAcquire(700L) - - val latch = new CountDownLatch(1) - startThread("t1") { - manager.tryToAcquire(300L) - latch.countDown() - } - latch.await() // Wait until `t1` calls `tryToAcquire` - - val granted = manager.tryToAcquire(300L) - assert(0 === granted, "granted is negative") - } -} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index cc44c676b2..6e3f500e15 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -61,7 +61,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val store = new BlockManager(name, rpcEnv, master, serializer, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(store.memoryStore) @@ -261,7 +261,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - 
val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000, numCores = 1) val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, conf, memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) memManager.setMemoryStore(failableStore.memoryStore) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f3fab33ca2..d49015afcd 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -68,7 +68,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) @@ -823,7 +823,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memoryManager = new StaticMemoryManager(conf, Long.MaxValue, 1200) + val memoryManager = new StaticMemoryManager( + conf, + maxExecutionMemory = Long.MaxValue, + maxStorageMemory = 1200, + numCores = 1) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), conf, memoryManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 5cb506ea21..dc3185a6d5 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark._ import org.apache.spark.io.CompressionCodec - +import org.apache.spark.memory.MemoryTestingUtils class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { import TestUtils.{assertNotSpilled, assertSpilled} @@ -32,8 +32,11 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] = buf1 ++= buf2 - private def createExternalMap[T] = new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]]( - createCombiner[T], mergeValue[T], mergeCombiners[T]) + private def createExternalMap[T] = { + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]]( + createCombiner[T], mergeValue[T], mergeCombiners[T], context = context) + } private def createSparkConf(loadDefaults: Boolean, codec: Option[String] = None): SparkConf = { val conf = new SparkConf(loadDefaults) @@ -49,23 +52,27 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { conf } - test("simple 
insert") { + test("single insert insert") { val conf = createSparkConf(loadDefaults = false) sc = new SparkContext("local", "test", conf) val map = createExternalMap[Int] - - // Single insert map.insert(1, 10) - var it = map.iterator + val it = map.iterator assert(it.hasNext) val kv = it.next() assert(kv._1 === 1 && kv._2 === ArrayBuffer[Int](10)) assert(!it.hasNext) + sc.stop() + } - // Multiple insert + test("multiple insert") { + val conf = createSparkConf(loadDefaults = false) + sc = new SparkContext("local", "test", conf) + val map = createExternalMap[Int] + map.insert(1, 10) map.insert(2, 20) map.insert(3, 30) - it = map.iterator + val it = map.iterator assert(it.hasNext) assert(it.toSet === Set[(Int, ArrayBuffer[Int])]( (1, ArrayBuffer[Int](10)), @@ -144,39 +151,22 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test", conf) val map = createExternalMap[Int] + val nullInt = null.asInstanceOf[Int] map.insert(1, 5) map.insert(2, 6) map.insert(3, 7) - assert(map.size === 3) - assert(map.iterator.toSet === Set[(Int, Seq[Int])]( - (1, Seq[Int](5)), - (2, Seq[Int](6)), - (3, Seq[Int](7)) - )) - - // Null keys - val nullInt = null.asInstanceOf[Int] + map.insert(4, nullInt) map.insert(nullInt, 8) - assert(map.size === 4) - assert(map.iterator.toSet === Set[(Int, Seq[Int])]( + map.insert(nullInt, nullInt) + val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.sorted)) + assert(result === Set[(Int, Seq[Int])]( (1, Seq[Int](5)), (2, Seq[Int](6)), (3, Seq[Int](7)), - (nullInt, Seq[Int](8)) + (4, Seq[Int](nullInt)), + (nullInt, Seq[Int](nullInt, 8)) )) - // Null values - map.insert(4, nullInt) - map.insert(nullInt, nullInt) - assert(map.size === 5) - val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet)) - assert(result === Set[(Int, Set[Int])]( - (1, Set[Int](5)), - (2, Set[Int](6)), - (3, Set[Int](7)), - (4, Set[Int](nullInt)), - (nullInt, Set[Int](nullInt, 8)) - )) sc.stop() } @@ -344,7 +334,9 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = true) conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) - val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val map = + new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _, context = context) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). 
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index e2cb791771..d7b2d07a40 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util.collection +import org.apache.spark.memory.MemoryTestingUtils + import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -98,6 +100,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = true, kryo = false) conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -109,7 +112,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner _, mergeValue _, mergeCombiners _) val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( - Some(agg), None, None, None) + context, Some(agg), None, None, None) val collisionPairs = Seq( ("Aa", "BB"), // 2112 @@ -158,8 +161,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = true, kryo = false) conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _) - val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None) + val sorter = new ExternalSorter[FixedHashObject, Int, Int](context, Some(agg), None, None, None) // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes // problems if the map fails to group together the objects with the same code (SPARK-2043). 
val toInsert = for (i <- 1 to 10; j <- 1 to size) yield (FixedHashObject(j, j % 2), 1) @@ -180,6 +184,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = true, kryo = false) conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i) def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i @@ -188,7 +193,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { } val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners) - val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None) + val sorter = + new ExternalSorter[Int, Int, ArrayBuffer[Int]](context, Some(agg), None, None, None) sorter.insertAll( (1 to size).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue))) assert(sorter.numSpills > 0, "sorter did not spill") @@ -204,6 +210,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val conf = createSparkConf(loadDefaults = true, kryo = false) conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i) def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i @@ -214,7 +221,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner, mergeValue, mergeCombiners) val sorter = new ExternalSorter[String, String, ArrayBuffer[String]]( - Some(agg), None, None, None) + context, Some(agg), None, None, None) sorter.insertAll((1 to size).iterator.map(i => (i.toString, i.toString)) ++ Iterator( (null.asInstanceOf[String], "1"), @@ -271,31 +278,32 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { private def emptyDataStream(conf: SparkConf) { conf.set("spark.shuffle.manager", "sort") sc = new SparkContext("local", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] // Both aggregator and ordering val sorter = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(3)), Some(ord), None) + context, Some(agg), Some(new HashPartitioner(3)), Some(ord), None) assert(sorter.iterator.toSeq === Seq()) sorter.stop() // Only aggregator val sorter2 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(3)), None, None) + context, Some(agg), Some(new HashPartitioner(3)), None, None) assert(sorter2.iterator.toSeq === Seq()) sorter2.stop() // Only ordering val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) + context, None, Some(new HashPartitioner(3)), Some(ord), None) assert(sorter3.iterator.toSeq === Seq()) sorter3.stop() // Neither aggregator nor ordering val sorter4 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), None, None) + context, None, Some(new HashPartitioner(3)), None, None) assert(sorter4.iterator.toSeq === Seq()) sorter4.stop() } @@ -303,6 +311,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { 
private def fewElementsPerPartition(conf: SparkConf) { conf.set("spark.shuffle.manager", "sort") sc = new SparkContext("local", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j) val ord = implicitly[Ordering[Int]] @@ -313,28 +322,28 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { // Both aggregator and ordering val sorter = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(7)), Some(ord), None) + context, Some(agg), Some(new HashPartitioner(7)), Some(ord), None) sorter.insertAll(elements.iterator) assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter.stop() // Only aggregator val sorter2 = new ExternalSorter[Int, Int, Int]( - Some(agg), Some(new HashPartitioner(7)), None, None) + context, Some(agg), Some(new HashPartitioner(7)), None, None) sorter2.insertAll(elements.iterator) assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter2.stop() // Only ordering val sorter3 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), Some(ord), None) + context, None, Some(new HashPartitioner(7)), Some(ord), None) sorter3.insertAll(elements.iterator) assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter3.stop() // Neither aggregator nor ordering val sorter4 = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), None, None) + context, None, Some(new HashPartitioner(7)), None, None) sorter4.insertAll(elements.iterator) assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected) sorter4.stop() @@ -345,12 +354,13 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { conf.set("spark.shuffle.manager", "sort") conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString) sc = new SparkContext("local", "test", conf) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val ord = implicitly[Ordering[Int]] val elements = Iterator((1, 1), (5, 5)) ++ (0 until size).iterator.map(x => (2, 2)) val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(7)), Some(ord), None) + context, None, Some(new HashPartitioner(7)), Some(ord), None) sorter.insertAll(elements) assert(sorter.numSpills > 0, "sorter did not spill") val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList)) @@ -432,8 +442,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val diskBlockManager = sc.env.blockManager.diskBlockManager val ord = implicitly[Ordering[Int]] val expectedSize = if (withFailures) size - 1 else size + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val sorter = new ExternalSorter[Int, Int, Int]( - None, Some(new HashPartitioner(3)), Some(ord), None) + context, None, Some(new HashPartitioner(3)), Some(ord), None) if (withFailures) { intercept[SparkException] { sorter.insertAll((0 until size).iterator.map { i => @@ -501,7 +512,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { None } val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None - val sorter = new ExternalSorter[Int, Int, Int](agg, Some(new HashPartitioner(3)), ord, None) + val context = MemoryTestingUtils.fakeTaskContext(sc.env) + val sorter = + new ExternalSorter[Int, Int, Int](context, agg, Some(new HashPartitioner(3)), ord, None) sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) 
}) if (withSpilling) { assert(sorter.numSpills > 0, "sorter did not spill") @@ -538,8 +551,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { val testData = Array.tabulate(size) { _ => rand.nextInt().toString } + val context = MemoryTestingUtils.fakeTaskContext(sc.env) val sorter1 = new ExternalSorter[String, String, String]( - None, None, Some(wrongOrdering), None) + context, None, None, Some(wrongOrdering), None) val thrown = intercept[IllegalArgumentException] { sorter1.insertAll(testData.iterator.map(i => (i, i))) assert(sorter1.numSpills > 0, "sorter did not spill") @@ -561,7 +575,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { createCombiner, mergeValue, mergeCombiners) val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]]( - Some(agg), None, None, None) + context, Some(agg), None, None, None) sorter2.insertAll(testData.iterator.map(i => (i, i))) assert(sorter2.numSpills > 0, "sorter did not spill") diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 7d94e0566f..810c74fd2f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -67,7 +67,6 @@ final class UnsafeExternalRowSorter { final TaskContext taskContext = TaskContext.get(); sorter = UnsafeExternalSorter.create( taskContext.taskMemoryManager(), - sparkEnv.shuffleMemoryManager(), sparkEnv.blockManager(), taskContext, new RowComparator(ordering, schema.length()), diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 09511ff35f..82c645df28 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -22,7 +22,6 @@ import java.io.IOException; import com.google.common.annotations.VisibleForTesting; import org.apache.spark.SparkEnv; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; @@ -32,7 +31,7 @@ import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; /** * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width. @@ -88,8 +87,6 @@ public final class UnsafeFixedWidthAggregationMap { * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. * @param groupingKeySchema the schema of the grouping key, used for row conversion. * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures. - * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with - * other tasks. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). * @param pageSizeBytes the data page size, in bytes; limits the maximum record size. 
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) @@ -99,15 +96,14 @@ public final class UnsafeFixedWidthAggregationMap { StructType aggregationBufferSchema, StructType groupingKeySchema, TaskMemoryManager taskMemoryManager, - ShuffleMemoryManager shuffleMemoryManager, int initialCapacity, long pageSizeBytes, boolean enablePerfMetrics) { this.aggregationBufferSchema = aggregationBufferSchema; this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap( - taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); + this.map = + new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics); this.enablePerfMetrics = enablePerfMetrics; // Initialize the buffer for aggregation value @@ -256,7 +252,7 @@ public final class UnsafeFixedWidthAggregationMap { public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException { UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter( groupingKeySchema, aggregationBufferSchema, - SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map); + SparkEnv.get().blockManager(), map.getPageSizeBytes(), map); return sorter; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 9df5780e4f..46301f0042 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -24,7 +24,6 @@ import javax.annotation.Nullable; import com.google.common.annotations.VisibleForTesting; import org.apache.spark.TaskContext; -import org.apache.spark.shuffle.ShuffleMemoryManager; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering; import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering; @@ -34,7 +33,7 @@ import org.apache.spark.unsafe.KVIterator; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.unsafe.memory.TaskMemoryManager; +import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.collection.unsafe.sort.*; /** @@ -50,14 +49,19 @@ public final class UnsafeKVExternalSorter { private final UnsafeExternalRowSorter.PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; - public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, - BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes) - throws IOException { - this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null); + public UnsafeKVExternalSorter( + StructType keySchema, + StructType valueSchema, + BlockManager blockManager, + long pageSizeBytes) throws IOException { + this(keySchema, valueSchema, blockManager, pageSizeBytes, null); } - public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema, - BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes, + public UnsafeKVExternalSorter( + StructType keySchema, + StructType valueSchema, + BlockManager blockManager, + long pageSizeBytes, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = 
keySchema; this.valueSchema = valueSchema; @@ -73,7 +77,6 @@ public final class UnsafeKVExternalSorter { if (map == null) { sorter = UnsafeExternalSorter.create( taskMemoryManager, - shuffleMemoryManager, blockManager, taskContext, recordComparator, @@ -115,7 +118,6 @@ public final class UnsafeKVExternalSorter { sorter = UnsafeExternalSorter.createWithExistingInMemorySorter( taskContext.taskMemoryManager(), - shuffleMemoryManager, blockManager, taskContext, new KVComparator(ordering, keySchema.length()), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 7cd0f7b81e..fb2fc98e34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate import scala.collection.mutable.ArrayBuffer import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext} +import org.apache.spark.{InternalAccumulator, Logging, TaskContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType * * This iterator first uses hash-based aggregation to process input rows. It uses * a hash map to store groups and their corresponding aggregation buffers. If we - * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]], + * this map cannot allocate memory from memory manager, * it switches to sort-based aggregation. The process of the switch has the following step: * - Step 1: Sort all entries of the hash map based on values of grouping expressions and * spill them to disk. 
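With ShuffleMemoryManager gone, call sites obtain both execution memory and the page size from the task's TaskMemoryManager, as the following hunk shows for TungstenAggregationIterator. A sketch of the new construction pattern, assuming it runs inside a task (so TaskContext.get() is non-null); `createAggregationMap` is a hypothetical helper and its schema/buffer parameters are placeholders supplied by the caller:

    import org.apache.spark.TaskContext
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap
    import org.apache.spark.sql.types.StructType

    def createAggregationMap(
        emptyBuffer: InternalRow,
        bufferSchema: StructType,
        keySchema: StructType): UnsafeFixedWidthAggregationMap = {
      val taskMemoryManager = TaskContext.get().taskMemoryManager()
      new UnsafeFixedWidthAggregationMap(
        emptyBuffer,
        bufferSchema,
        keySchema,
        taskMemoryManager,                 // single manager replaces the old (task, shuffle) pair
        1024 * 16,                         // initial capacity, mirroring the call site below
        taskMemoryManager.pageSizeBytes,   // page size now comes from TaskMemoryManager
        false)                             // perf metrics disabled
    }
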
@@ -480,10 +480,9 @@ class TungstenAggregationIterator( initialAggregationBuffer, StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)), StructType.fromAttributes(groupingExpressions.map(_.toAttribute)), - TaskContext.get.taskMemoryManager(), - SparkEnv.get.shuffleMemoryManager, + TaskContext.get().taskMemoryManager(), 1024 * 16, // initial capacity - SparkEnv.get.shuffleMemoryManager.pageSizeBytes, + TaskContext.get().taskMemoryManager().pageSizeBytes, false // disable tracking of performance metrics ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index cfd64c1d9e..1b59b19d94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -344,8 +344,7 @@ private[sql] class DynamicPartitionWriterContainer( StructType.fromAttributes(partitionColumns), StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, - SparkEnv.get.shuffleMemoryManager, - SparkEnv.get.shuffleMemoryManager.pageSizeBytes) + TaskContext.get().taskMemoryManager().pageSizeBytes) sorter.insertKV(currentKey, getOutputRow(inputRow)) } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index bc255b2750..cc8abb1ba4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} -import org.apache.spark.shuffle.ShuffleMemoryManager +import org.apache.spark.memory.{TaskMemoryManager, StaticMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.local.LocalNode import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap -import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} +import org.apache.spark.unsafe.memory.MemoryLocation import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.{SparkConf, SparkEnv} @@ -320,21 +320,20 @@ private[joins] final class UnsafeHashedRelation( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { val nKeys = in.readInt() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory - val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) + // TODO(josh): This needs to be revisited before we merge this patch; making this change now + // so that tests compile: + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.unsafe.offHeap", "false"), Long.MaxValue, Long.MaxValue, 1), 0) - val pageSizeBytes = Option(SparkEnv.get).map(_.shuffleMemoryManager.pageSizeBytes) + val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new 
SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) - // Dummy shuffle memory manager which always grants all memory allocation requests. - // We use this because it doesn't make sense count shared broadcast variables' memory usage - // towards individual tasks' quotas. In the future, we should devise a better way of handling - // this. - val shuffleMemoryManager = - ShuffleMemoryManager.create(maxMemory = Long.MaxValue, pageSizeBytes = pageSizeBytes) + // TODO(josh): We won't need this dummy memory manager after future refactorings; revisit + // during code review binaryMap = new BytesToBytesMap( taskMemoryManager, - shuffleMemoryManager, (nKeys * 1.5 + 1).toInt, // reduce hash collision pageSizeBytes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index 9385e5734d..dd92dda480 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -49,7 +49,8 @@ case class Sort( protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { child.execute().mapPartitions( { iterator => val ordering = newOrdering(sortOrder, child.output) - val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering)) + val sorter = new ExternalSorter[InternalRow, Null, InternalRow]( + TaskContext.get(), ordering = Some(ordering)) sorter.insertAll(iterator.map(r => (r.copy(), null))) val baseIterator = sorter.iterator.map(_._1) val context = TaskContext.get() @@ -124,7 +125,7 @@ case class TungstenSort( } } - val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes val sorter = new UnsafeExternalRowSorter( schema, ordering, prefixComparator, prefixComputer, pageSize) if (testSpillFrequency > 0) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala deleted file mode 100644 index c4358f409b..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.collection.mutable - -import org.apache.spark.memory.MemoryManager -import org.apache.spark.shuffle.ShuffleMemoryManager -import org.apache.spark.storage.{BlockId, BlockStatus} - - -/** - * A [[ShuffleMemoryManager]] that can be controlled to run out of memory. 
- */ -class TestShuffleMemoryManager - extends ShuffleMemoryManager(new GrantEverythingMemoryManager, 4 * 1024 * 1024) { - private var oom = false - - override def tryToAcquire(numBytes: Long): Long = { - if (oom) { - oom = false - 0 - } else { - // Uncomment the following to trace memory allocations. - // println(s"tryToAcquire $numBytes in " + - // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) - val acquired = super.tryToAcquire(numBytes) - acquired - } - } - - override def release(numBytes: Long): Unit = { - // Uncomment the following to trace memory releases. - // println(s"release $numBytes in " + - // Thread.currentThread().getStackTrace.mkString("", "\n -", "")) - super.release(numBytes) - } - - def markAsOutOfMemory(): Unit = { - oom = true - } -} - -private class GrantEverythingMemoryManager extends MemoryManager { - override def acquireExecutionMemory( - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = numBytes - override def acquireStorageMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true - override def acquireUnrollMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true - override def releaseExecutionMemory(numBytes: Long): Unit = { } - override def releaseStorageMemory(numBytes: Long): Unit = { } - override def maxExecutionMemory: Long = Long.MaxValue - override def maxStorageMemory: Long = Long.MaxValue -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 1739798a24..dbf4863b76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,13 +23,12 @@ import scala.util.{Try, Random} import org.scalatest.Matchers -import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite} -import org.apache.spark.shuffle.ShuffleMemoryManager +import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String /** @@ -49,23 +48,22 @@ class UnsafeFixedWidthAggregationMapSuite private def emptyAggregationBuffer: InternalRow = InternalRow(0) private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes + private var memoryManager: GrantEverythingMemoryManager = null private var taskMemoryManager: TaskMemoryManager = null - private var shuffleMemoryManager: TestShuffleMemoryManager = null def testWithMemoryLeakDetection(name: String)(f: => Unit) { def cleanup(): Unit = { if (taskMemoryManager != null) { - val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask() assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0) - assert(leakedShuffleMemory === 0) taskMemoryManager = null } TaskContext.unset() } test(name) { - taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - 
shuffleMemoryManager = new TestShuffleMemoryManager + val conf = new SparkConf().set("spark.unsafe.offHeap", "false") + memoryManager = new GrantEverythingMemoryManager(conf) + taskMemoryManager = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, @@ -110,7 +108,6 @@ class UnsafeFixedWidthAggregationMapSuite aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 1024, // initial capacity, PAGE_SIZE_BYTES, false // disable perf metrics @@ -125,7 +122,6 @@ class UnsafeFixedWidthAggregationMapSuite aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 1024, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -153,7 +149,6 @@ class UnsafeFixedWidthAggregationMapSuite aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 128, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -176,14 +171,13 @@ class UnsafeFixedWidthAggregationMapSuite testWithMemoryLeakDetection("test external sorting") { // Memory consumption in the beginning of the task. - val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask() val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 128, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -200,7 +194,7 @@ class UnsafeFixedWidthAggregationMapSuite val sorter = map.destructAndCreateExternalSorter() withClue(s"destructAndCreateExternalSorter should release memory used by the map") { - assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption) + assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption) } // Add more keys to the sorter and make sure the results come out sorted. @@ -214,7 +208,7 @@ class UnsafeFixedWidthAggregationMapSuite sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) if ((i % 100) == 0) { - shuffleMemoryManager.markAsOutOfMemory() + memoryManager.markExecutionAsOutOfMemory() sorter.closeCurrentPage() } } @@ -238,7 +232,6 @@ class UnsafeFixedWidthAggregationMapSuite aggBufferSchema, groupKeySchema, taskMemoryManager, - shuffleMemoryManager, 128, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -258,7 +251,7 @@ class UnsafeFixedWidthAggregationMapSuite sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v)) if ((i % 100) == 0) { - shuffleMemoryManager.markAsOutOfMemory() + memoryManager.markExecutionAsOutOfMemory() sorter.closeCurrentPage() } } @@ -281,14 +274,13 @@ class UnsafeFixedWidthAggregationMapSuite testWithMemoryLeakDetection("test external sorting with empty records") { // Memory consumption in the beginning of the task. 
- val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask() + val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask() val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, StructType(Nil), StructType(Nil), taskMemoryManager, - shuffleMemoryManager, 128, // initial capacity PAGE_SIZE_BYTES, false // disable perf metrics @@ -303,7 +295,7 @@ class UnsafeFixedWidthAggregationMapSuite val sorter = map.destructAndCreateExternalSorter() withClue(s"destructAndCreateExternalSorter should release memory used by the map") { - assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption) + assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption) } // Add more keys to the sorter and make sure the results come out sorted. @@ -311,7 +303,7 @@ class UnsafeFixedWidthAggregationMapSuite sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0)) if ((i % 100) == 0) { - shuffleMemoryManager.markAsOutOfMemory() + memoryManager.markExecutionAsOutOfMemory() sorter.closeCurrentPage() } } @@ -332,34 +324,28 @@ class UnsafeFixedWidthAggregationMapSuite } testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") { - val smm = ShuffleMemoryManager.createForTesting(65536) val pageSize = 4096 val map = new UnsafeFixedWidthAggregationMap( emptyAggregationBuffer, aggBufferSchema, groupKeySchema, taskMemoryManager, - smm, 128, // initial capacity pageSize, false // disable perf metrics ) - // Insert into the map until we've run out of space val rand = new Random(42) - var hasSpace = true - while (hasSpace) { + for (i <- 1 to 100) { val str = rand.nextString(1024) val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) - if (buf == null) { - hasSpace = false - } else { - buf.setInt(0, str.length) - } + buf.setInt(0, str.length) } - - // Ensure we're actually maxed out by asserting that we can't acquire even just 1 byte - assert(smm.tryToAcquire(1) === 0) + // Simulate running out of space + memoryManager.markExecutionAsOutOfMemory() + val str = rand.nextString(1024) + val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str))) + assert(buf == null) // Convert the map into a sorter. 
This used to fail before the fix for SPARK-10474 // because we would try to acquire space for the in-memory sorter pointer array before diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index d3be568a87..13dc1754c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.execution import scala.util.Random import org.apache.spark._ +import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} /** * Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data. @@ -108,9 +108,9 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { inputData: Seq[(InternalRow, InternalRow)], pageSize: Long, spill: Boolean): Unit = { - - val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)) - val shuffleMemMgr = new TestShuffleMemoryManager + val memoryManager = + new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")) + val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, partitionId = 0, @@ -121,14 +121,14 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { internalAccumulators = Seq.empty)) val sorter = new UnsafeKVExternalSorter( - keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, pageSize) + keySchema, valueSchema, SparkEnv.get.blockManager, pageSize) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow]) // 1% chance we will spill if (rand.nextDouble() < 0.01 && spill) { - shuffleMemMgr.markAsOutOfMemory() + memoryManager.markExecutionAsOutOfMemory() sorter.closeCurrentPage() } } @@ -170,12 +170,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { assert(out.sorted(kvOrdering) === inputData.sorted(kvOrdering)) // Make sure there is no memory leak - val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory - if (shuffleMemMgr != null) { - val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask() - assert(0L === leakedShuffleMemory) - } - assert(0 === leakedUnsafeMemory) + assert(0 === taskMemMgr.cleanUpAllAllocatedMemory) TaskContext.unset() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 1680d7e0a8..d32572b54b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream} import 
org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.collection.ExternalSorter @@ -112,7 +113,12 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { val data = (1 to 10000).iterator.map { i => (i, converter(Row(i))) } + val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) + val taskContext = new TaskContextImpl( + 0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc)) + val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( + taskContext, partitioner = Some(new HashPartitioner(10)), serializer = Some(new UnsafeRowSerializer(numFields = 1))) @@ -122,10 +128,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { assert(sorter.numSpills > 0) // Merging spilled files should not throw assertion error - val taskContext = - new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc)) taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics) - sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile) + sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile) } { // Clean up if (sc != null) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index cc0ac1b07c..475037bd45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -18,16 +18,16 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark._ +import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.unsafe.memory.TaskMemoryManager class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext { test("memory acquired on construction") { - val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager) + val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.memoryManager, 0) val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty) TaskContext.setTaskContext(taskContext) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index b2b6848719..c17fb72381 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -254,7 +254,7 @@ class ReceivedBlockHandlerSuite maxMem: Long, conf: SparkConf, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { - val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem) + val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, conf, memManager, mapOutputTracker, shuffleManager, 
transfer, securityMgr, 0) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java deleted file mode 100644 index cbbe859462..0000000000 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import java.lang.ref.WeakReference; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.Map; -import javax.annotation.concurrent.GuardedBy; - -/** - * Manages memory for an executor. Individual operators / tasks allocate memory through - * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager. - */ -public class ExecutorMemoryManager { - - /** - * Allocator, exposed for enabling untracked allocations of temporary data structures. - */ - public final MemoryAllocator allocator; - - /** - * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe. - */ - final boolean inHeap; - - @GuardedBy("this") - private final Map>> bufferPoolsBySize = - new HashMap>>(); - - private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; - - /** - * Construct a new ExecutorMemoryManager. - * - * @param allocator the allocator that will be used - */ - public ExecutorMemoryManager(MemoryAllocator allocator) { - this.inHeap = allocator instanceof HeapMemoryAllocator; - this.allocator = allocator; - } - - /** - * Returns true if allocations of the given size should go through the pooling mechanism and - * false otherwise. - */ - private boolean shouldPool(long size) { - // Very small allocations are less likely to benefit from pooling. - // At some point, we should explore supporting pooling for off-heap memory, but for now we'll - // ignore that case in the interest of simplicity. - return size >= POOLING_THRESHOLD_BYTES && allocator instanceof HeapMemoryAllocator; - } - - /** - * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed - * to be zeroed out (call `zero()` on the result if this is necessary). 
- */ - MemoryBlock allocate(long size) throws OutOfMemoryError { - if (shouldPool(size)) { - synchronized (this) { - final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size); - if (pool != null) { - while (!pool.isEmpty()) { - final WeakReference<MemoryBlock> blockReference = pool.pop(); - final MemoryBlock memory = blockReference.get(); - if (memory != null) { - assert (memory.size() == size); - return memory; - } - } - bufferPoolsBySize.remove(size); - } - } - return allocator.allocate(size); - } else { - return allocator.allocate(size); - } - } - - void free(MemoryBlock memory) { - final long size = memory.size(); - if (shouldPool(size)) { - synchronized (this) { - LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size); - if (pool == null) { - pool = new LinkedList<WeakReference<MemoryBlock>>(); - bufferPoolsBySize.put(size, pool); - } - pool.add(new WeakReference<MemoryBlock>(memory)); - } - } else { - allocator.free(memory); - } - } - -} diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 6722301df1..ebe90d9e63 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -17,22 +17,71 @@ package org.apache.spark.unsafe.memory; +import javax.annotation.concurrent.GuardedBy; +import java.lang.ref.WeakReference; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; + /** * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array. */ public class HeapMemoryAllocator implements MemoryAllocator { + @GuardedBy("this") + private final Map<Long, LinkedList<WeakReference<MemoryBlock>>> bufferPoolsBySize = + new HashMap<>(); + + private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024; + + /** + * Returns true if allocations of the given size should go through the pooling mechanism and + * false otherwise. + */ + private boolean shouldPool(long size) { + // Very small allocations are less likely to benefit from pooling.
+ return size >= POOLING_THRESHOLD_BYTES; + } + @Override public MemoryBlock allocate(long size) throws OutOfMemoryError { if (size % 8 != 0) { throw new IllegalArgumentException("Size " + size + " was not a multiple of 8"); } + if (shouldPool(size)) { + synchronized (this) { + final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size); + if (pool != null) { + while (!pool.isEmpty()) { + final WeakReference<MemoryBlock> blockReference = pool.pop(); + final MemoryBlock memory = blockReference.get(); + if (memory != null) { + assert (memory.size() == size); + return memory; + } + } + bufferPoolsBySize.remove(size); + } + } + } long[] array = new long[(int) (size / 8)]; return MemoryBlock.fromLongArray(array); } @Override public void free(MemoryBlock memory) { - // Do nothing + final long size = memory.size(); + if (shouldPool(size)) { + synchronized (this) { + LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size); + if (pool == null) { + pool = new LinkedList<>(); + bufferPoolsBySize.put(size, pool); + } + pool.add(new WeakReference<>(memory)); + } + } else { + // Do nothing + } } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java index dd75820834..e3e7947115 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java @@ -30,9 +30,10 @@ public class MemoryBlock extends MemoryLocation { /** * Optional page number; used when this MemoryBlock represents a page allocated by a - * MemoryManager. This is package-private and is modified by MemoryManager. + * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager, + * which lives in a different package. */ - int pageNumber = -1; + public int pageNumber = -1; public MemoryBlock(@Nullable Object obj, long offset, long length) { super(obj, offset); diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java deleted file mode 100644 index 97b2c93f0d..0000000000 --- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import java.util.*; - -import com.google.common.annotations.VisibleForTesting; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Manages the memory allocated by an individual task. - *

- * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs. - * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is - * addressed by the combination of a base Object reference and a 64-bit offset within that object. - * This is a problem when we want to store pointers to data structures inside of other structures, - * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits - * to address memory, we can't just store the address of the base object since it's not guaranteed - * to remain stable as the heap gets reorganized due to GC. - *

- * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap - * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to - * store a "page number" and the lower 51 bits to store an offset within this page. These page - * numbers are used to index into a "page table" array inside of the MemoryManager in order to - * retrieve the base object. - *

- * This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the - * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is - * approximately 35 terabytes of memory. - */ -public class TaskMemoryManager { - - private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class); - - /** The number of bits used to address the page table. */ - private static final int PAGE_NUMBER_BITS = 13; - - /** The number of bits used to encode offsets in data pages. */ - @VisibleForTesting - static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS; // 51 - - /** The number of entries in the page table. */ - private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS; - - /** - * Maximum supported data page size (in bytes). In principle, the maximum addressable page size is - * (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's maximum page - * size is limited by the maximum amount of data that can be stored in a long[] array, which is - * (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes. - */ - public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L; - - /** Bit mask for the lower 51 bits of a long. */ - private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL; - - /** Bit mask for the upper 13 bits of a long */ - private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS; - - /** - * Similar to an operating system's page table, this array maps page numbers into base object - * pointers, allowing us to translate between the hashtable's internal 64-bit address - * representation and the baseObject+offset representation which we use to support both in- and - * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`. - * When using an in-heap allocator, the entries in this map will point to pages' base objects. - * Entries are added to this map as new data pages are allocated. - */ - private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE]; - - /** - * Bitmap for tracking free pages. - */ - private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); - - /** - * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean - * up leaked memory. - */ - private final HashSet allocatedNonPageMemory = new HashSet(); - - private final ExecutorMemoryManager executorMemoryManager; - - /** - * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods - * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, - * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. - */ - private final boolean inHeap; - - /** - * Construct a new MemoryManager. - */ - public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) { - this.inHeap = executorMemoryManager.inHeap; - this.executorMemoryManager = executorMemoryManager; - } - - /** - * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is - * intended for allocating large blocks of memory that will be shared between operators. 
- */ - public MemoryBlock allocatePage(long size) { - if (size > MAXIMUM_PAGE_SIZE_BYTES) { - throw new IllegalArgumentException( - "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); - } - - final int pageNumber; - synchronized (this) { - pageNumber = allocatedPages.nextClearBit(0); - if (pageNumber >= PAGE_TABLE_SIZE) { - throw new IllegalStateException( - "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); - } - allocatedPages.set(pageNumber); - } - final MemoryBlock page = executorMemoryManager.allocate(size); - page.pageNumber = pageNumber; - pageTable[pageNumber] = page; - if (logger.isTraceEnabled()) { - logger.trace("Allocate page number {} ({} bytes)", pageNumber, size); - } - return page; - } - - /** - * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}. - */ - public void freePage(MemoryBlock page) { - assert (page.pageNumber != -1) : - "Called freePage() on memory that wasn't allocated with allocatePage()"; - assert(allocatedPages.get(page.pageNumber)); - pageTable[page.pageNumber] = null; - synchronized (this) { - allocatedPages.clear(page.pageNumber); - } - if (logger.isTraceEnabled()) { - logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); - } - // Cannot access a page once it's freed. - executorMemoryManager.free(page); - } - - /** - * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed - * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended - * to be used for allocating operators' internal data structures. For data pages that you want to - * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since - * that will enable intra-memory pointers (see - * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's - * top-level Javadoc for more details). - */ - public MemoryBlock allocate(long size) throws OutOfMemoryError { - assert(size > 0) : "Size must be positive, but got " + size; - final MemoryBlock memory = executorMemoryManager.allocate(size); - synchronized(allocatedNonPageMemory) { - allocatedNonPageMemory.add(memory); - } - return memory; - } - - /** - * Free memory allocated by {@link TaskMemoryManager#allocate(long)}. - */ - public void free(MemoryBlock memory) { - assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()"; - executorMemoryManager.free(memory); - synchronized(allocatedNonPageMemory) { - final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory); - assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!"; - } - } - - /** - * Given a memory page and offset within that page, encode this address into a 64-bit long. - * This address will remain valid as long as the corresponding page has not been freed. - * - * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}. - * @param offsetInPage an offset in this page which incorporates the base offset. In other words, - * this should be the value that you would pass as the base offset into an - * UNSAFE call (e.g. page.baseOffset() + something). - * @return an encoded page address. - */ - public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { - if (!inHeap) { - // In off-heap mode, an offset is an absolute address that may require a full 64 bits to - // encode. 
Due to our page size limitation, though, we can convert this into an offset that's - // relative to the page's base offset; this relative offset will fit in 51 bits. - offsetInPage -= page.getBaseOffset(); - } - return encodePageNumberAndOffset(page.pageNumber, offsetInPage); - } - - @VisibleForTesting - public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) { - assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page"; - return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS); - } - - @VisibleForTesting - public static int decodePageNumber(long pagePlusOffsetAddress) { - return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS); - } - - private static long decodeOffset(long pagePlusOffsetAddress) { - return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS); - } - - /** - * Get the page associated with an address encoded by - * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} - */ - public Object getPage(long pagePlusOffsetAddress) { - if (inHeap) { - final int pageNumber = decodePageNumber(pagePlusOffsetAddress); - assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - final MemoryBlock page = pageTable[pageNumber]; - assert (page != null); - assert (page.getBaseObject() != null); - return page.getBaseObject(); - } else { - return null; - } - } - - /** - * Get the offset associated with an address encoded by - * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} - */ - public long getOffsetInPage(long pagePlusOffsetAddress) { - final long offsetInPage = decodeOffset(pagePlusOffsetAddress); - if (inHeap) { - return offsetInPage; - } else { - // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we - // converted the absolute address into a relative address. Here, we invert that operation: - final int pageNumber = decodePageNumber(pagePlusOffsetAddress); - assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); - final MemoryBlock page = pageTable[pageNumber]; - assert (page != null); - return page.getBaseOffset() + offsetInPage; - } - } - - /** - * Clean up all allocated memory and pages. Returns the number of bytes freed. A non-zero return - * value can be used to detect memory leaks. - */ - public long cleanUpAllAllocatedMemory() { - long freedBytes = 0; - for (MemoryBlock page : pageTable) { - if (page != null) { - freedBytes += page.size(); - freePage(page); - } - } - - synchronized (allocatedNonPageMemory) { - final Iterator iter = allocatedNonPageMemory.iterator(); - while (iter.hasNext()) { - final MemoryBlock memory = iter.next(); - freedBytes += memory.size(); - // We don't call free() here because that calls Set.remove, which would lead to a - // ConcurrentModificationException here. - executorMemoryManager.free(memory); - iter.remove(); - } - } - return freedBytes; - } -} diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java deleted file mode 100644 index 06fb081183..0000000000 --- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.unsafe.memory; - -import org.junit.Assert; -import org.junit.Test; - -public class TaskMemoryManagerSuite { - - @Test - public void leakedNonPageMemoryIsDetected() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - manager.allocate(1024); // leak memory - Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory()); - } - - @Test - public void leakedPageMemoryIsDetected() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - manager.allocatePage(4096); // leak memory - Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); - } - - @Test - public void encodePageNumberAndOffsetOffHeap() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE)); - final MemoryBlock dataPage = manager.allocatePage(256); - // In off-heap mode, an offset is an absolute address that may require more than 51 bits to - // encode. This test exercises that corner-case: - final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10); - final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset); - Assert.assertEquals(null, manager.getPage(encodedAddress)); - Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress)); - } - - @Test - public void encodePageNumberAndOffsetOnHeap() { - final TaskMemoryManager manager = - new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)); - final MemoryBlock dataPage = manager.allocatePage(256); - final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64); - Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress)); - Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress)); - } - -} -- cgit v1.2.3
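The address-encoding scheme documented in the removed TaskMemoryManager above (a 13-bit page number plus a 51-bit in-page offset packed into one 64-bit long) carries over to the relocated org.apache.spark.memory.TaskMemoryManager in spark-core. As a worked example of that arithmetic, here is a small standalone sketch that mirrors encodePageNumberAndOffset, decodePageNumber, and decodeOffset using the same constants; it is illustrative only and does not call the Spark classes:

    // Sketch (Scala): encode a (pageNumber, offsetInPage) pair into one 64-bit long.
    object PageAddressSketch {
      val PAGE_NUMBER_BITS = 13
      val OFFSET_BITS = 64 - PAGE_NUMBER_BITS          // 51
      val MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL
      val MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS

      def encode(pageNumber: Int, offsetInPage: Long): Long =
        (pageNumber.toLong << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS)

      def decodePageNumber(address: Long): Int =
        ((address & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS).toInt

      def decodeOffset(address: Long): Long =
        address & MASK_LONG_LOWER_51_BITS

      def main(args: Array[String]): Unit = {
        val addr = encode(pageNumber = 5, offsetInPage = 2500L)
        assert(decodePageNumber(addr) == 5)    // recovered from the upper 13 bits
        assert(decodeOffset(addr) == 2500L)    // recovered from the lower 51 bits
      }
    }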