aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-10-25 21:19:52 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-10-25 21:19:52 -0700
commit85e654c5ec87e666a8845bfd77185c1ea57b268a (patch)
tree2beadbc8fbb54369325970a4e2c7189506efad89
parent63accc79625d8a03d0624717af5e1d81b18a6da3 (diff)
downloadspark-85e654c5ec87e666a8845bfd77185c1ea57b268a.tar.gz
spark-85e654c5ec87e666a8845bfd77185c1ea57b268a.tar.bz2
spark-85e654c5ec87e666a8845bfd77185c1ea57b268a.zip
[SPARK-10984] Simplify *MemoryManager class structure
This patch refactors the MemoryManager class structure. After #9000, Spark had the following classes: - MemoryManager - StaticMemoryManager - ExecutorMemoryManager - TaskMemoryManager - ShuffleMemoryManager This is fairly confusing. To simplify things, this patch consolidates several of these classes: - ShuffleMemoryManager and ExecutorMemoryManager were merged into MemoryManager. - TaskMemoryManager is moved into Spark Core. **Key changes and tasks**: - [x] Merge ExecutorMemoryManager into MemoryManager. - [x] Move pooling logic into Allocator. - [x] Move TaskMemoryManager from `spark-unsafe` to `spark-core`. - [x] Refactor the existing Tungsten TaskMemoryManager interactions so Tungsten code use only this and not both this and ShuffleMemoryManager. - [x] Refactor non-Tungsten code to use the TaskMemoryManager instead of ShuffleMemoryManager. - [x] Merge ShuffleMemoryManager into MemoryManager. - [x] Move code - [x] ~~Simplify 1/n calculation.~~ **Will defer to followup, since this needs more work.** - [x] Port ShuffleMemoryManagerSuite tests. - [x] Move classes from `unsafe` package to `memory` package. - [ ] Figure out how to handle the hacky use of the memory managers in HashedRelation's broadcast variable construction. - [x] Test porting and cleanup: several tests relied on mock functionality (such as `TestShuffleMemoryManager.markAsOutOfMemory`) which has been changed or broken during the memory manager consolidation - [x] AbstractBytesToBytesMapSuite - [x] UnsafeExternalSorterSuite - [x] UnsafeFixedWidthAggregationMapSuite - [x] UnsafeKVExternalSorterSuite **Compatiblity notes**: - This patch introduces breaking changes in `ExternalAppendOnlyMap`, which is marked as `DevloperAPI` (likely for legacy reasons): this class now cannot be used outside of a task. Author: Josh Rosen <joshrosen@databricks.com> Closes #9127 from JoshRosen/SPARK-10984.
-rw-r--r--core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java (renamed from unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java)111
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java4
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java57
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java7
-rw-r--r--core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java36
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java4
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java51
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java2
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContext.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/TaskContextImpl.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/memory/MemoryManager.scala197
-rw-r--r--core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/Task.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala209
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala49
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/Spillable.scala16
-rw-r--r--core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java (renamed from unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java)25
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java12
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java9
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java53
-rw-r--r--core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java108
-rw-r--r--core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java7
-rw-r--r--core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java7
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java34
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java13
-rw-r--r--core/src/test/scala/org/apache/spark/FailureSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala (renamed from sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala)51
-rw-r--r--core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala134
-rw-r--r--core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala37
-rw-r--r--core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala24
-rw-r--r--core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala326
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala8
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala60
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala48
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java1
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java12
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala9
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala21
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala5
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala54
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala19
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala4
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala2
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java111
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java51
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java5
58 files changed, 888 insertions, 1255 deletions
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 97b2c93f0d..7b31c90dac 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.memory;
+package org.apache.spark.memory;
import java.util.*;
@@ -23,6 +23,8 @@ import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
/**
* Manages the memory allocated by an individual task.
* <p>
@@ -87,13 +89,9 @@ public class TaskMemoryManager {
*/
private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
- /**
- * Tracks memory allocated with {@link TaskMemoryManager#allocate(long)}, used to detect / clean
- * up leaked memory.
- */
- private final HashSet<MemoryBlock> allocatedNonPageMemory = new HashSet<MemoryBlock>();
+ private final MemoryManager memoryManager;
- private final ExecutorMemoryManager executorMemoryManager;
+ private final long taskAttemptId;
/**
* Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods
@@ -103,16 +101,38 @@ public class TaskMemoryManager {
private final boolean inHeap;
/**
- * Construct a new MemoryManager.
+ * Construct a new TaskMemoryManager.
*/
- public TaskMemoryManager(ExecutorMemoryManager executorMemoryManager) {
- this.inHeap = executorMemoryManager.inHeap;
- this.executorMemoryManager = executorMemoryManager;
+ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
+ this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap();
+ this.memoryManager = memoryManager;
+ this.taskAttemptId = taskAttemptId;
+ }
+
+ /**
+ * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
+ * @return number of bytes successfully granted (<= N).
+ */
+ public long acquireExecutionMemory(long size) {
+ return memoryManager.acquireExecutionMemory(size, taskAttemptId);
+ }
+
+ /**
+ * Release N bytes of execution memory.
+ */
+ public void releaseExecutionMemory(long size) {
+ memoryManager.releaseExecutionMemory(size, taskAttemptId);
+ }
+
+ public long pageSizeBytes() {
+ return memoryManager.pageSizeBytes();
}
/**
* Allocate a block of memory that will be tracked in the MemoryManager's page table; this is
- * intended for allocating large blocks of memory that will be shared between operators.
+ * intended for allocating large blocks of Tungsten memory that will be shared between operators.
+ *
+ * Returns `null` if there was not enough memory to allocate the page.
*/
public MemoryBlock allocatePage(long size) {
if (size > MAXIMUM_PAGE_SIZE_BYTES) {
@@ -129,7 +149,15 @@ public class TaskMemoryManager {
}
allocatedPages.set(pageNumber);
}
- final MemoryBlock page = executorMemoryManager.allocate(size);
+ final long acquiredExecutionMemory = acquireExecutionMemory(size);
+ if (acquiredExecutionMemory != size) {
+ releaseExecutionMemory(acquiredExecutionMemory);
+ synchronized (this) {
+ allocatedPages.clear(pageNumber);
+ }
+ return null;
+ }
+ final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size);
page.pageNumber = pageNumber;
pageTable[pageNumber] = page;
if (logger.isTraceEnabled()) {
@@ -152,45 +180,16 @@ public class TaskMemoryManager {
if (logger.isTraceEnabled()) {
logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
}
- // Cannot access a page once it's freed.
- executorMemoryManager.free(page);
- }
-
- /**
- * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
- * to be zeroed out (call `zero()` on the result if this is necessary). This method is intended
- * to be used for allocating operators' internal data structures. For data pages that you want to
- * exchange between operators, consider using {@link TaskMemoryManager#allocatePage(long)}, since
- * that will enable intra-memory pointers (see
- * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} and this class's
- * top-level Javadoc for more details).
- */
- public MemoryBlock allocate(long size) throws OutOfMemoryError {
- assert(size > 0) : "Size must be positive, but got " + size;
- final MemoryBlock memory = executorMemoryManager.allocate(size);
- synchronized(allocatedNonPageMemory) {
- allocatedNonPageMemory.add(memory);
- }
- return memory;
- }
-
- /**
- * Free memory allocated by {@link TaskMemoryManager#allocate(long)}.
- */
- public void free(MemoryBlock memory) {
- assert (memory.pageNumber == -1) : "Should call freePage() for pages, not free()";
- executorMemoryManager.free(memory);
- synchronized(allocatedNonPageMemory) {
- final boolean wasAlreadyRemoved = !allocatedNonPageMemory.remove(memory);
- assert (!wasAlreadyRemoved) : "Called free() on memory that was already freed!";
- }
+ long pageSize = page.size();
+ memoryManager.tungstenMemoryAllocator().free(page);
+ releaseExecutionMemory(pageSize);
}
/**
* Given a memory page and offset within that page, encode this address into a 64-bit long.
* This address will remain valid as long as the corresponding page has not been freed.
*
- * @param page a data page allocated by {@link TaskMemoryManager#allocate(long)}.
+ * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/
* @param offsetInPage an offset in this page which incorporates the base offset. In other words,
* this should be the value that you would pass as the base offset into an
* UNSAFE call (e.g. page.baseOffset() + something).
@@ -270,17 +269,15 @@ public class TaskMemoryManager {
}
}
- synchronized (allocatedNonPageMemory) {
- final Iterator<MemoryBlock> iter = allocatedNonPageMemory.iterator();
- while (iter.hasNext()) {
- final MemoryBlock memory = iter.next();
- freedBytes += memory.size();
- // We don't call free() here because that calls Set.remove, which would lead to a
- // ConcurrentModificationException here.
- executorMemoryManager.free(memory);
- iter.remove();
- }
- }
+ freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
+
return freedBytes;
}
+
+ /**
+ * Returns the memory consumption, in bytes, for the current task
+ */
+ public long getMemoryConsumptionForThisTask() {
+ return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId);
+ }
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
index c11711966f..f8f2b220e1 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.sort;
+import org.apache.spark.memory.TaskMemoryManager;
+
/**
* Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
* <p>
@@ -26,7 +28,7 @@ package org.apache.spark.shuffle.sort;
* </pre>
* This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that
* our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
- * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * 13-bit page numbers assigned by {@link TaskMemoryManager}), this
* implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
* <p>
* Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 85fdaa8115..f43236f41a 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -33,14 +33,13 @@ import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
/**
@@ -72,7 +71,6 @@ final class ShuffleExternalSorter {
@VisibleForTesting
final int maxRecordSizeBytes;
private final TaskMemoryManager taskMemoryManager;
- private final ShuffleMemoryManager shuffleMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
@@ -105,7 +103,6 @@ final class ShuffleExternalSorter {
public ShuffleExternalSorter(
TaskMemoryManager memoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
int initialSize,
@@ -113,7 +110,6 @@ final class ShuffleExternalSorter {
SparkConf conf,
ShuffleWriteMetrics writeMetrics) throws IOException {
this.taskMemoryManager = memoryManager;
- this.shuffleMemoryManager = shuffleMemoryManager;
this.blockManager = blockManager;
this.taskContext = taskContext;
this.initialSize = initialSize;
@@ -124,7 +120,7 @@ final class ShuffleExternalSorter {
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.pageSizeBytes = (int) Math.min(
- PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
+ PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes());
this.maxRecordSizeBytes = pageSizeBytes - 4;
this.writeMetrics = writeMetrics;
initializeForWriting();
@@ -140,9 +136,9 @@ final class ShuffleExternalSorter {
private void initializeForWriting() throws IOException {
// TODO: move this sizing calculation logic into a static method of sorter:
final long memoryRequested = initialSize * 8L;
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested);
if (memoryAcquired != memoryRequested) {
- shuffleMemoryManager.release(memoryAcquired);
+ taskMemoryManager.releaseExecutionMemory(memoryAcquired);
throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
}
@@ -272,6 +268,7 @@ final class ShuffleExternalSorter {
*/
@VisibleForTesting
void spill() throws IOException {
+ assert(inMemSorter != null);
logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
Utils.bytesToString(getMemoryUsage()),
@@ -281,7 +278,7 @@ final class ShuffleExternalSorter {
writeSortedFile(false);
final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
inMemSorter = null;
- shuffleMemoryManager.release(inMemSorterMemoryUsage);
+ taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage);
final long spillSize = freeMemory();
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
@@ -316,9 +313,13 @@ final class ShuffleExternalSorter {
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
taskMemoryManager.freePage(block);
- shuffleMemoryManager.release(block.size());
memoryFreed += block.size();
}
+ if (inMemSorter != null) {
+ long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+ inMemSorter = null;
+ taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
+ }
allocatedPages.clear();
currentPage = null;
currentPagePosition = -1;
@@ -337,8 +338,9 @@ final class ShuffleExternalSorter {
}
}
if (inMemSorter != null) {
- shuffleMemoryManager.release(inMemSorter.getMemoryUsage());
+ long sorterMemoryUsage = inMemSorter.getMemoryUsage();
inMemSorter = null;
+ taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
}
}
@@ -353,21 +355,20 @@ final class ShuffleExternalSorter {
logger.debug("Attempting to expand sort pointer array");
final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+ final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray);
if (memoryAcquired < memoryToGrowPointerArray) {
- shuffleMemoryManager.release(memoryAcquired);
+ taskMemoryManager.releaseExecutionMemory(memoryAcquired);
spill();
} else {
inMemSorter.expandPointerArray();
- shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+ taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage);
}
}
}
-
+
/**
* Allocates more memory in order to insert an additional record. This will request additional
- * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
- * obtained.
+ * memory from the memory manager and spill if the requested memory can not be obtained.
*
* @param requiredSpace the required space in the data page, in bytes, including space for storing
* the record size. This must be less than or equal to the page size (records
@@ -386,17 +387,14 @@ final class ShuffleExternalSorter {
throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
pageSizeBytes + ")");
} else {
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquired < pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquired);
+ currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ if (currentPage == null) {
spill();
- final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquiredAfterSpilling != pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ if (currentPage == null) {
throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
}
}
- currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
currentPagePosition = currentPage.getBaseOffset();
freeSpaceInCurrentPage = pageSizeBytes;
allocatedPages.add(currentPage);
@@ -430,17 +428,14 @@ final class ShuffleExternalSorter {
long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
// The record is larger than the page size, so allocate a special overflow page just to hold
// that record.
- final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
- if (memoryGranted != overflowPageSize) {
- shuffleMemoryManager.release(memoryGranted);
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ if (overflowPage == null) {
spill();
- final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
- if (memoryGrantedAfterSpill != overflowPageSize) {
- shuffleMemoryManager.release(memoryGrantedAfterSpill);
+ overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ if (overflowPage == null) {
throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
}
}
- MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
allocatedPages.add(overflowPage);
dataPage = overflowPage;
dataPagePosition = overflowPage.getBaseOffset();
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index e8f050cb2d..f6c5c944bd 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -49,12 +49,11 @@ import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
@Private
public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -69,7 +68,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final BlockManager blockManager;
private final IndexShuffleBlockResolver shuffleBlockResolver;
private final TaskMemoryManager memoryManager;
- private final ShuffleMemoryManager shuffleMemoryManager;
private final SerializerInstance serializer;
private final Partitioner partitioner;
private final ShuffleWriteMetrics writeMetrics;
@@ -103,7 +101,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
BlockManager blockManager,
IndexShuffleBlockResolver shuffleBlockResolver,
TaskMemoryManager memoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
SerializedShuffleHandle<K, V> handle,
int mapId,
TaskContext taskContext,
@@ -117,7 +114,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.blockManager = blockManager;
this.shuffleBlockResolver = shuffleBlockResolver;
this.memoryManager = memoryManager;
- this.shuffleMemoryManager = shuffleMemoryManager;
this.mapId = mapId;
final ShuffleDependency<K, V, V> dep = handle.dependency();
this.shuffleId = dep.shuffleId();
@@ -197,7 +193,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
assert (sorter == null);
sorter = new ShuffleExternalSorter(
memoryManager,
- shuffleMemoryManager,
blockManager,
taskContext,
INITIAL_SORT_BUFFER_SIZE,
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index b24eed3952..f035bdac81 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -26,7 +26,6 @@ import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.array.LongArray;
@@ -34,7 +33,7 @@ import org.apache.spark.unsafe.bitset.BitSet;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
/**
* An append-only hash map where keys and values are contiguous regions of bytes.
@@ -70,8 +69,6 @@ public final class BytesToBytesMap {
private final TaskMemoryManager taskMemoryManager;
- private final ShuffleMemoryManager shuffleMemoryManager;
-
/**
* A linked list for tracking all allocated data pages so that we can free all of our memory.
*/
@@ -169,13 +166,11 @@ public final class BytesToBytesMap {
public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
double loadFactor,
long pageSizeBytes,
boolean enablePerfMetrics) {
this.taskMemoryManager = taskMemoryManager;
- this.shuffleMemoryManager = shuffleMemoryManager;
this.loadFactor = loadFactor;
this.loc = new Location();
this.pageSizeBytes = pageSizeBytes;
@@ -201,21 +196,18 @@ public final class BytesToBytesMap {
public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes) {
- this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+ this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
}
public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes,
boolean enablePerfMetrics) {
this(
taskMemoryManager,
- shuffleMemoryManager,
initialCapacity,
0.70,
pageSizeBytes,
@@ -260,7 +252,6 @@ public final class BytesToBytesMap {
if (destructive && currentPage != null) {
dataPagesIterator.remove();
this.bmap.taskMemoryManager.freePage(currentPage);
- this.bmap.shuffleMemoryManager.release(currentPage.size());
}
currentPage = dataPagesIterator.next();
pageBaseObject = currentPage.getBaseObject();
@@ -572,14 +563,12 @@ public final class BytesToBytesMap {
if (useOverflowPage) {
// The record is larger than the page size, so allocate a special overflow page just to hold
// that record.
- final long memoryRequested = requiredSize + 8;
- final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested);
- if (memoryGranted != memoryRequested) {
- shuffleMemoryManager.release(memoryGranted);
- logger.debug("Failed to acquire {} bytes of memory", memoryRequested);
+ final long overflowPageSize = requiredSize + 8;
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ if (overflowPage == null) {
+ logger.debug("Failed to acquire {} bytes of memory", overflowPageSize);
return false;
}
- MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested);
dataPages.add(overflowPage);
dataPage = overflowPage;
dataPageBaseObject = overflowPage.getBaseObject();
@@ -655,17 +644,15 @@ public final class BytesToBytesMap {
}
/**
- * Acquire a new page from the {@link ShuffleMemoryManager}.
+ * Acquire a new page from the memory manager.
* @return whether there is enough space to allocate the new page.
*/
private boolean acquireNewPage() {
- final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryGranted != pageSizeBytes) {
- shuffleMemoryManager.release(memoryGranted);
+ MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ if (newPage == null) {
logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
return false;
}
- MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
dataPages.add(newPage);
pageCursor = 0;
currentDataPage = newPage;
@@ -705,7 +692,6 @@ public final class BytesToBytesMap {
MemoryBlock dataPage = dataPagesIterator.next();
dataPagesIterator.remove();
taskMemoryManager.freePage(dataPage);
- shuffleMemoryManager.release(dataPage.size());
}
assert(dataPages.isEmpty());
}
@@ -714,10 +700,6 @@ public final class BytesToBytesMap {
return taskMemoryManager;
}
- public ShuffleMemoryManager getShuffleMemoryManager() {
- return shuffleMemoryManager;
- }
-
public long getPageSizeBytes() {
return pageSizeBytes;
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
index 0c4ebde407..dbf6770e07 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordPointerAndKeyPrefix.java
@@ -17,9 +17,11 @@
package org.apache.spark.util.collection.unsafe.sort;
+import org.apache.spark.memory.TaskMemoryManager;
+
final class RecordPointerAndKeyPrefix {
/**
- * A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
+ * A pointer to a record; see {@link TaskMemoryManager} for a
* description of how these addresses are encoded.
*/
public long recordPointer;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 0a311d2d93..e317ea391c 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -32,12 +32,11 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
/**
@@ -52,7 +51,6 @@ public final class UnsafeExternalSorter {
private final RecordComparator recordComparator;
private final int initialSize;
private final TaskMemoryManager taskMemoryManager;
- private final ShuffleMemoryManager shuffleMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
private ShuffleWriteMetrics writeMetrics;
@@ -82,7 +80,6 @@ public final class UnsafeExternalSorter {
public static UnsafeExternalSorter createWithExistingInMemorySorter(
TaskMemoryManager taskMemoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
RecordComparator recordComparator,
@@ -90,26 +87,24 @@ public final class UnsafeExternalSorter {
int initialSize,
long pageSizeBytes,
UnsafeInMemorySorter inMemorySorter) throws IOException {
- return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+ return new UnsafeExternalSorter(taskMemoryManager, blockManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
}
public static UnsafeExternalSorter create(
TaskMemoryManager taskMemoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes) throws IOException {
- return new UnsafeExternalSorter(taskMemoryManager, shuffleMemoryManager, blockManager,
+ return new UnsafeExternalSorter(taskMemoryManager, blockManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
}
private UnsafeExternalSorter(
TaskMemoryManager taskMemoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
BlockManager blockManager,
TaskContext taskContext,
RecordComparator recordComparator,
@@ -118,7 +113,6 @@ public final class UnsafeExternalSorter {
long pageSizeBytes,
@Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
this.taskMemoryManager = taskMemoryManager;
- this.shuffleMemoryManager = shuffleMemoryManager;
this.blockManager = blockManager;
this.taskContext = taskContext;
this.recordComparator = recordComparator;
@@ -261,7 +255,6 @@ public final class UnsafeExternalSorter {
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
taskMemoryManager.freePage(block);
- shuffleMemoryManager.release(block.size());
memoryFreed += block.size();
}
// TODO: track in-memory sorter memory usage (SPARK-10474)
@@ -309,8 +302,7 @@ public final class UnsafeExternalSorter {
/**
* Allocates more memory in order to insert an additional record. This will request additional
- * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
- * obtained.
+ * memory from the memory manager and spill if the requested memory can not be obtained.
*
* @param requiredSpace the required space in the data page, in bytes, including space for storing
* the record size. This must be less than or equal to the page size (records
@@ -335,23 +327,20 @@ public final class UnsafeExternalSorter {
}
/**
- * Acquire a new page from the {@link ShuffleMemoryManager}.
+ * Acquire a new page from the memory manager.
*
* If there is not enough space to allocate the new page, spill all existing ones
* and try again. If there is still not enough space, report error to the caller.
*/
private void acquireNewPage() throws IOException {
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquired < pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquired);
+ currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ if (currentPage == null) {
spill();
- final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquiredAfterSpilling != pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ if (currentPage == null) {
throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
}
}
- currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
currentPagePosition = currentPage.getBaseOffset();
freeSpaceInCurrentPage = pageSizeBytes;
allocatedPages.add(currentPage);
@@ -379,17 +368,14 @@ public final class UnsafeExternalSorter {
long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
// The record is larger than the page size, so allocate a special overflow page just to hold
// that record.
- final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
- if (memoryGranted != overflowPageSize) {
- shuffleMemoryManager.release(memoryGranted);
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ if (overflowPage == null) {
spill();
- final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
- if (memoryGrantedAfterSpill != overflowPageSize) {
- shuffleMemoryManager.release(memoryGrantedAfterSpill);
+ overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ if (overflowPage == null) {
throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
}
}
- MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
allocatedPages.add(overflowPage);
dataPage = overflowPage;
dataPagePosition = overflowPage.getBaseOffset();
@@ -441,17 +427,14 @@ public final class UnsafeExternalSorter {
long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
// The record is larger than the page size, so allocate a special overflow page just to hold
// that record.
- final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
- if (memoryGranted != overflowPageSize) {
- shuffleMemoryManager.release(memoryGranted);
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ if (overflowPage == null) {
spill();
- final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
- if (memoryGrantedAfterSpill != overflowPageSize) {
- shuffleMemoryManager.release(memoryGrantedAfterSpill);
+ overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+ if (overflowPage == null) {
throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
}
}
- MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
allocatedPages.add(overflowPage);
dataPage = overflowPage;
dataPagePosition = overflowPage.getBaseOffset();
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index f7787e1019..5aad72c374 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -21,7 +21,7 @@ import java.util.Comparator;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.collection.Sorter;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
/**
* Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index b5c35c569e..398e093690 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -38,9 +38,8 @@ import org.apache.spark.rpc.akka.AkkaRpcEnv
import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
+import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
/**
@@ -70,10 +69,7 @@ class SparkEnv (
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
- // TODO: unify these *MemoryManager classes (SPARK-10984)
val memoryManager: MemoryManager,
- val shuffleMemoryManager: ShuffleMemoryManager,
- val executorMemoryManager: ExecutorMemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
@@ -340,13 +336,11 @@ object SparkEnv extends Logging {
val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false)
val memoryManager: MemoryManager =
if (useLegacyMemoryManager) {
- new StaticMemoryManager(conf)
+ new StaticMemoryManager(conf, numUsableCores)
} else {
- new UnifiedMemoryManager(conf)
+ new UnifiedMemoryManager(conf, numUsableCores)
}
- val shuffleMemoryManager = ShuffleMemoryManager.create(conf, memoryManager, numUsableCores)
-
val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores)
val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
@@ -405,15 +399,6 @@ object SparkEnv extends Logging {
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
- val executorMemoryManager: ExecutorMemoryManager = {
- val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
- MemoryAllocator.UNSAFE
- } else {
- MemoryAllocator.HEAP
- }
- new ExecutorMemoryManager(allocator)
- }
-
val envInstance = new SparkEnv(
executorId,
rpcEnv,
@@ -431,8 +416,6 @@ object SparkEnv extends Logging {
sparkFilesDir,
metricsSystem,
memoryManager,
- shuffleMemoryManager,
- executorMemoryManager,
outputCommitCoordinator,
conf)
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
index 63cca80b2d..af558d6e5b 100644
--- a/core/src/main/scala/org/apache/spark/TaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -21,8 +21,8 @@ import java.io.Serializable
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.source.Source
-import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.TaskCompletionListener
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 5df94c6d3a..f0ae83a934 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -20,9 +20,9 @@ package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
-import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
private[spark] class TaskContextImpl(
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c3491bb8b1..9e88d488c0 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -29,10 +29,10 @@ import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util._
/**
@@ -179,7 +179,7 @@ private[spark] class Executor(
}
override def run(): Unit = {
- val taskMemoryManager = new TaskMemoryManager(env.executorMemoryManager)
+ val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = env.closureSerializer.newInstance()
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
index 7168ac5491..6c9a71c385 100644
--- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
@@ -17,20 +17,38 @@
package org.apache.spark.memory
+import javax.annotation.concurrent.GuardedBy
+
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.Logging
-import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
+import com.google.common.annotations.VisibleForTesting
+import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging}
+import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.memory.MemoryAllocator
/**
* An abstract memory manager that enforces how memory is shared between execution and storage.
*
* In this context, execution memory refers to that used for computation in shuffles, joins,
* sorts and aggregations, while storage memory refers to that used for caching and propagating
- * internal data across the cluster. There exists one of these per JVM.
+ * internal data across the cluster. There exists one MemoryManager per JVM.
+ *
+ * The MemoryManager abstract base class itself implements policies for sharing execution memory
+ * between tasks; it tries to ensure that each task gets a reasonable share of memory, instead of
+ * some task ramping up to a large amount first and then causing others to spill to disk repeatedly.
+ * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
+ * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
+ * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
+ * this set changes. This is all done by synchronizing access to mutable state and using wait() and
+ * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across
+ * tasks was performed by the ShuffleMemoryManager.
*/
-private[spark] abstract class MemoryManager extends Logging {
+private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) extends Logging {
+
+ // -- Methods related to memory allocation policies and bookkeeping ------------------------------
// The memory store used to evict cached blocks
private var _memoryStore: MemoryStore = _
@@ -42,8 +60,10 @@ private[spark] abstract class MemoryManager extends Logging {
}
// Amount of execution/storage memory in use, accesses must be synchronized on `this`
- protected var _executionMemoryUsed: Long = 0
- protected var _storageMemoryUsed: Long = 0
+ @GuardedBy("this") protected var _executionMemoryUsed: Long = 0
+ @GuardedBy("this") protected var _storageMemoryUsed: Long = 0
+ // Map from taskAttemptId -> memory consumption in bytes
+ @GuardedBy("this") private val executionMemoryForTask = new mutable.HashMap[Long, Long]()
/**
* Set the [[MemoryStore]] used by this manager to evict cached blocks.
@@ -66,15 +86,6 @@ private[spark] abstract class MemoryManager extends Logging {
// TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985)
/**
- * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
- * Blocks evicted in the process, if any, are added to `evictedBlocks`.
- * @return number of bytes successfully granted (<= N).
- */
- def acquireExecutionMemory(
- numBytes: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long
-
- /**
* Acquire N bytes of memory to cache the given block, evicting existing ones if necessary.
* Blocks evicted in the process, if any, are added to `evictedBlocks`.
* @return whether all N bytes were successfully granted.
@@ -102,9 +113,92 @@ private[spark] abstract class MemoryManager extends Logging {
}
/**
- * Release N bytes of execution memory.
+ * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
+ * Blocks evicted in the process, if any, are added to `evictedBlocks`.
+ * @return number of bytes successfully granted (<= N).
+ */
+ @VisibleForTesting
+ private[memory] def doAcquireExecutionMemory(
+ numBytes: Long,
+ evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long
+
+ /**
+ * Try to acquire up to `numBytes` of execution memory for the current task and return the number
+ * of bytes obtained, or 0 if none can be allocated.
+ *
+ * This call may block until there is enough free memory in some situations, to make sure each
+ * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of
+ * active tasks) before it is forced to spill. This can happen if the number of tasks increase
+ * but an older task had a lot of memory already.
+ *
+ * Subclasses should override `doAcquireExecutionMemory` in order to customize the policies
+ * that control global sharing of memory between execution and storage.
*/
- def releaseExecutionMemory(numBytes: Long): Unit = synchronized {
+ private[memory]
+ final def acquireExecutionMemory(numBytes: Long, taskAttemptId: Long): Long = synchronized {
+ assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
+
+ // Add this task to the taskMemory map just so we can keep an accurate count of the number
+ // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
+ if (!executionMemoryForTask.contains(taskAttemptId)) {
+ executionMemoryForTask(taskAttemptId) = 0L
+ // This will later cause waiting tasks to wake up and check numTasks again
+ notifyAll()
+ }
+
+ // Once the cross-task memory allocation policy has decided to grant more memory to a task,
+ // this method is called in order to actually obtain that execution memory, potentially
+ // triggering eviction of storage memory:
+ def acquire(toGrant: Long): Long = synchronized {
+ val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+ val acquired = doAcquireExecutionMemory(toGrant, evictedBlocks)
+ // Register evicted blocks, if any, with the active task metrics
+ Option(TaskContext.get()).foreach { tc =>
+ val metrics = tc.taskMetrics()
+ val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
+ metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq)
+ }
+ executionMemoryForTask(taskAttemptId) += acquired
+ acquired
+ }
+
+ // Keep looping until we're either sure that we don't want to grant this request (because this
+ // task would have more than 1 / numActiveTasks of the memory) or we have enough free
+ // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
+ // TODO: simplify this to limit each task to its own slot
+ while (true) {
+ val numActiveTasks = executionMemoryForTask.keys.size
+ val curMem = executionMemoryForTask(taskAttemptId)
+ val freeMemory = maxExecutionMemory - executionMemoryForTask.values.sum
+
+ // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
+ // don't let it be negative
+ val maxToGrant =
+ math.min(numBytes, math.max(0, (maxExecutionMemory / numActiveTasks) - curMem))
+ // Only give it as much memory as is free, which might be none if it reached 1 / numTasks
+ val toGrant = math.min(maxToGrant, freeMemory)
+
+ if (curMem < maxExecutionMemory / (2 * numActiveTasks)) {
+ // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
+ // if we can't give it this much now, wait for other tasks to free up memory
+ // (this happens if older tasks allocated lots of memory before N grew)
+ if (
+ freeMemory >= math.min(maxToGrant, maxExecutionMemory / (2 * numActiveTasks) - curMem)) {
+ return acquire(toGrant)
+ } else {
+ logInfo(
+ s"TID $taskAttemptId waiting for at least 1/2N of execution memory pool to be free")
+ wait()
+ }
+ } else {
+ return acquire(toGrant)
+ }
+ }
+ 0L // Never reached
+ }
+
+ @VisibleForTesting
+ private[memory] def releaseExecutionMemory(numBytes: Long): Unit = synchronized {
if (numBytes > _executionMemoryUsed) {
logWarning(s"Attempted to release $numBytes bytes of execution " +
s"memory when we only have ${_executionMemoryUsed} bytes")
@@ -115,6 +209,36 @@ private[spark] abstract class MemoryManager extends Logging {
}
/**
+ * Release numBytes of execution memory belonging to the given task.
+ */
+ private[memory]
+ final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized {
+ val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L)
+ if (curMem < numBytes) {
+ throw new SparkException(
+ s"Internal error: release called on $numBytes bytes but task only has $curMem")
+ }
+ if (executionMemoryForTask.contains(taskAttemptId)) {
+ executionMemoryForTask(taskAttemptId) -= numBytes
+ if (executionMemoryForTask(taskAttemptId) <= 0) {
+ executionMemoryForTask.remove(taskAttemptId)
+ }
+ releaseExecutionMemory(numBytes)
+ }
+ notifyAll() // Notify waiters in acquireExecutionMemory() that memory has been freed
+ }
+
+ /**
+ * Release all memory for the given task and mark it as inactive (e.g. when a task ends).
+ * @return the number of bytes freed.
+ */
+ private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = synchronized {
+ val numBytesToFree = getExecutionMemoryUsageForTask(taskAttemptId)
+ releaseExecutionMemory(numBytesToFree, taskAttemptId)
+ numBytesToFree
+ }
+
+ /**
* Release N bytes of storage memory.
*/
def releaseStorageMemory(numBytes: Long): Unit = synchronized {
@@ -155,4 +279,43 @@ private[spark] abstract class MemoryManager extends Logging {
_storageMemoryUsed
}
+ /**
+ * Returns the execution memory consumption, in bytes, for the given task.
+ */
+ private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = synchronized {
+ executionMemoryForTask.getOrElse(taskAttemptId, 0L)
+ }
+
+ // -- Fields related to Tungsten managed memory -------------------------------------------------
+
+ /**
+ * The default page size, in bytes.
+ *
+ * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value
+ * by looking at the number of cores available to the process, and the total amount of memory,
+ * and then divide it by a factor of safety.
+ */
+ val pageSizeBytes: Long = {
+ val minPageSize = 1L * 1024 * 1024 // 1MB
+ val maxPageSize = 64L * minPageSize // 64MB
+ val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
+ // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case
+ val safetyFactor = 16
+ val size = ByteArrayMethods.nextPowerOf2(maxExecutionMemory / cores / safetyFactor)
+ val default = math.min(maxPageSize, math.max(minPageSize, size))
+ conf.getSizeAsBytes("spark.buffer.pageSize", default)
+ }
+
+ /**
+ * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using
+ * sun.misc.Unsafe.
+ */
+ final val tungstenMemoryIsAllocatedInHeap: Boolean =
+ !conf.getBoolean("spark.unsafe.offHeap", false)
+
+ /**
+ * Allocates memory for use by Unsafe/Tungsten code.
+ */
+ private[memory] final val tungstenMemoryAllocator: MemoryAllocator =
+ if (tungstenMemoryIsAllocatedInHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE
}
diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
index fa44f37234..9c2c2e90a2 100644
--- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala
@@ -33,14 +33,16 @@ import org.apache.spark.storage.{BlockId, BlockStatus}
private[spark] class StaticMemoryManager(
conf: SparkConf,
override val maxExecutionMemory: Long,
- override val maxStorageMemory: Long)
- extends MemoryManager {
+ override val maxStorageMemory: Long,
+ numCores: Int)
+ extends MemoryManager(conf, numCores) {
- def this(conf: SparkConf) {
+ def this(conf: SparkConf, numCores: Int) {
this(
conf,
StaticMemoryManager.getMaxExecutionMemory(conf),
- StaticMemoryManager.getMaxStorageMemory(conf))
+ StaticMemoryManager.getMaxStorageMemory(conf),
+ numCores)
}
// Max number of bytes worth of blocks to evict when unrolling
@@ -52,7 +54,7 @@ private[spark] class StaticMemoryManager(
* Acquire N bytes of memory for execution.
* @return number of bytes successfully granted (<= N).
*/
- override def acquireExecutionMemory(
+ override def doAcquireExecutionMemory(
numBytes: Long,
evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
assert(numBytes >= 0)
diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
index 5bf78d5b67..a3093030a0 100644
--- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
@@ -42,10 +42,14 @@ import org.apache.spark.storage.{BlockStatus, BlockId}
* up most of the storage space, in which case the new blocks will be evicted immediately
* according to their respective storage levels.
*/
-private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) extends MemoryManager {
+private[spark] class UnifiedMemoryManager(
+ conf: SparkConf,
+ maxMemory: Long,
+ numCores: Int)
+ extends MemoryManager(conf, numCores) {
- def this(conf: SparkConf) {
- this(conf, UnifiedMemoryManager.getMaxMemory(conf))
+ def this(conf: SparkConf, numCores: Int) {
+ this(conf, UnifiedMemoryManager.getMaxMemory(conf), numCores)
}
/**
@@ -91,7 +95,7 @@ private[spark] class UnifiedMemoryManager(conf: SparkConf, maxMemory: Long) exte
* Blocks evicted in the process, if any, are added to `evictedBlocks`.
* @return number of bytes successfully granted (<= N).
*/
- override def acquireExecutionMemory(
+ private[memory] override def doAcquireExecutionMemory(
numBytes: Long,
evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
assert(numBytes >= 0)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 9edf9f048f..4fb32ba8cb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -25,8 +25,8 @@ import scala.collection.mutable.HashMap
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext}
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.Utils
@@ -90,10 +90,6 @@ private[spark] abstract class Task[T](
context.markTaskCompleted()
try {
Utils.tryLogNonFatalError {
- // Release memory used by this thread for shuffles
- SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask()
- }
- Utils.tryLogNonFatalError {
// Release memory used by this thread for unrolling blocks
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask()
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 7c3e2b5a37..b0abda4a81 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -98,13 +98,14 @@ private[spark] class BlockStoreShuffleReader[K, C](
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
- val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
+ val sorter =
+ new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.internalMetricsToAccumulators(
InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
- sorter.iterator
+ CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
deleted file mode 100644
index 9bd18da47f..0000000000
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ /dev/null
@@ -1,209 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle
-
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
-import com.google.common.annotations.VisibleForTesting
-
-import org.apache.spark._
-import org.apache.spark.memory.{StaticMemoryManager, MemoryManager}
-import org.apache.spark.storage.{BlockId, BlockStatus}
-import org.apache.spark.unsafe.array.ByteArrayMethods
-
-/**
- * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling
- * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory
- * from this pool and release it as it spills data out. When a task ends, all its memory will be
- * released by the Executor.
- *
- * This class tries to ensure that each task gets a reasonable share of memory, instead of some
- * task ramping up to a large amount first and then causing others to spill to disk repeatedly.
- * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory
- * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
- * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever
- * this set changes. This is all done by synchronizing access to `memoryManager` to mutate state
- * and using wait() and notifyAll() to signal changes.
- *
- * Use `ShuffleMemoryManager.create()` factory method to create a new instance.
- *
- * @param memoryManager the interface through which this manager acquires execution memory
- * @param pageSizeBytes number of bytes for each page, by default.
- */
-private[spark]
-class ShuffleMemoryManager protected (
- memoryManager: MemoryManager,
- val pageSizeBytes: Long)
- extends Logging {
-
- private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes
-
- private def currentTaskAttemptId(): Long = {
- // In case this is called on the driver, return an invalid task attempt id.
- Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L)
- }
-
- /**
- * Try to acquire up to numBytes memory for the current task, and return the number of bytes
- * obtained, or 0 if none can be allocated. This call may block until there is enough free memory
- * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the
- * total memory pool (where N is the # of active tasks) before it is forced to spill. This can
- * happen if the number of tasks increases but an older task had a lot of memory already.
- */
- def tryToAcquire(numBytes: Long): Long = memoryManager.synchronized {
- val taskAttemptId = currentTaskAttemptId()
- assert(numBytes > 0, "invalid number of bytes requested: " + numBytes)
-
- // Add this task to the taskMemory map just so we can keep an accurate count of the number
- // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire
- if (!taskMemory.contains(taskAttemptId)) {
- taskMemory(taskAttemptId) = 0L
- // This will later cause waiting tasks to wake up and check numTasks again
- memoryManager.notifyAll()
- }
-
- // Keep looping until we're either sure that we don't want to grant this request (because this
- // task would have more than 1 / numActiveTasks of the memory) or we have enough free
- // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
- // TODO: simplify this to limit each task to its own slot
- while (true) {
- val numActiveTasks = taskMemory.keys.size
- val curMem = taskMemory(taskAttemptId)
- val maxMemory = memoryManager.maxExecutionMemory
- val freeMemory = maxMemory - taskMemory.values.sum
-
- // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks;
- // don't let it be negative
- val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem))
- // Only give it as much memory as is free, which might be none if it reached 1 / numTasks
- val toGrant = math.min(maxToGrant, freeMemory)
-
- if (curMem < maxMemory / (2 * numActiveTasks)) {
- // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
- // if we can't give it this much now, wait for other tasks to free up memory
- // (this happens if older tasks allocated lots of memory before N grew)
- if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) {
- return acquire(toGrant)
- } else {
- logInfo(
- s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
- memoryManager.wait()
- }
- } else {
- return acquire(toGrant)
- }
- }
- 0L // Never reached
- }
-
- /**
- * Acquire N bytes of execution memory from the memory manager for the current task.
- * @return number of bytes actually acquired (<= N).
- */
- private def acquire(numBytes: Long): Long = memoryManager.synchronized {
- val taskAttemptId = currentTaskAttemptId()
- val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
- val acquired = memoryManager.acquireExecutionMemory(numBytes, evictedBlocks)
- // Register evicted blocks, if any, with the active task metrics
- // TODO: just do this in `acquireExecutionMemory` (SPARK-10985)
- Option(TaskContext.get()).foreach { tc =>
- val metrics = tc.taskMetrics()
- val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
- metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq)
- }
- taskMemory(taskAttemptId) += acquired
- acquired
- }
-
- /** Release numBytes bytes for the current task. */
- def release(numBytes: Long): Unit = memoryManager.synchronized {
- val taskAttemptId = currentTaskAttemptId()
- val curMem = taskMemory.getOrElse(taskAttemptId, 0L)
- if (curMem < numBytes) {
- throw new SparkException(
- s"Internal error: release called on $numBytes bytes but task only has $curMem")
- }
- if (taskMemory.contains(taskAttemptId)) {
- taskMemory(taskAttemptId) -= numBytes
- memoryManager.releaseExecutionMemory(numBytes)
- }
- memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed
- }
-
- /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */
- def releaseMemoryForThisTask(): Unit = memoryManager.synchronized {
- val taskAttemptId = currentTaskAttemptId()
- taskMemory.remove(taskAttemptId).foreach { numBytes =>
- memoryManager.releaseExecutionMemory(numBytes)
- }
- memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed
- }
-
- /** Returns the memory consumption, in bytes, for the current task */
- def getMemoryConsumptionForThisTask(): Long = memoryManager.synchronized {
- val taskAttemptId = currentTaskAttemptId()
- taskMemory.getOrElse(taskAttemptId, 0L)
- }
-}
-
-
-private[spark] object ShuffleMemoryManager {
-
- def create(
- conf: SparkConf,
- memoryManager: MemoryManager,
- numCores: Int): ShuffleMemoryManager = {
- val maxMemory = memoryManager.maxExecutionMemory
- val pageSize = ShuffleMemoryManager.getPageSize(conf, maxMemory, numCores)
- new ShuffleMemoryManager(memoryManager, pageSize)
- }
-
- /**
- * Create a dummy [[ShuffleMemoryManager]] with the specified capacity and page size.
- */
- def create(maxMemory: Long, pageSizeBytes: Long): ShuffleMemoryManager = {
- val conf = new SparkConf
- val memoryManager = new StaticMemoryManager(
- conf, maxExecutionMemory = maxMemory, maxStorageMemory = Long.MaxValue)
- new ShuffleMemoryManager(memoryManager, pageSizeBytes)
- }
-
- @VisibleForTesting
- def createForTesting(maxMemory: Long): ShuffleMemoryManager = {
- create(maxMemory, 4 * 1024 * 1024)
- }
-
- /**
- * Sets the page size, in bytes.
- *
- * If user didn't explicitly set "spark.buffer.pageSize", we figure out the default value
- * by looking at the number of cores available to the process, and the total amount of memory,
- * and then divide it by a factor of safety.
- */
- private def getPageSize(conf: SparkConf, maxMemory: Long, numCores: Int): Long = {
- val minPageSize = 1L * 1024 * 1024 // 1MB
- val maxPageSize = 64L * minPageSize // 64MB
- val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors()
- // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case
- val safetyFactor = 16
- val size = ByteArrayMethods.nextPowerOf2(maxMemory / cores / safetyFactor)
- val default = math.min(maxPageSize, math.max(minPageSize, size))
- conf.getSizeAsBytes("spark.buffer.pageSize", default)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 1105167d39..66b6bbc61f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -133,7 +133,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
env.blockManager,
shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
context.taskMemoryManager(),
- env.shuffleMemoryManager,
unsafeShuffleHandle,
mapId,
context,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index bbd9c1ab53..808317b017 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -52,13 +52,13 @@ private[spark] class SortShuffleWriter[K, V, C](
sorter = if (dep.mapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
- dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+ context, dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
} else {
// In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
// care whether the keys get sorted in each partition; that will be done on the reduce side
// if the operation being run is sortByKey.
new ExternalSorter[K, V, V](
- aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
+ context, aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
sorter.insertAll(records)
@@ -67,7 +67,7 @@ private[spark] class SortShuffleWriter[K, V, C](
// (see SPARK-3570).
val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
- val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
+ val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index cfa58f5ef4..f6d81ee5bf 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -28,8 +28,10 @@ import com.google.common.io.ByteStreams
import org.apache.spark.{Logging, SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.serializer.{DeserializationStream, Serializer}
import org.apache.spark.storage.{BlockId, BlockManager}
+import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
import org.apache.spark.executor.ShuffleWriteMetrics
@@ -55,12 +57,30 @@ class ExternalAppendOnlyMap[K, V, C](
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
serializer: Serializer = SparkEnv.get.serializer,
- blockManager: BlockManager = SparkEnv.get.blockManager)
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ context: TaskContext = TaskContext.get())
extends Iterable[(K, C)]
with Serializable
with Logging
with Spillable[SizeTracker] {
+ if (context == null) {
+ throw new IllegalStateException(
+ "Spillable collections should not be instantiated outside of tasks")
+ }
+
+ // Backwards-compatibility constructor for binary compatibility
+ def this(
+ createCombiner: V => C,
+ mergeValue: (C, V) => C,
+ mergeCombiners: (C, C) => C,
+ serializer: Serializer,
+ blockManager: BlockManager) {
+ this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get())
+ }
+
+ override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()
+
private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
private val sparkConf = SparkEnv.get.conf
@@ -118,6 +138,10 @@ class ExternalAppendOnlyMap[K, V, C](
* The shuffle memory usage of the first trackMemoryThreshold entries is not tracked.
*/
def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
+ if (currentMap == null) {
+ throw new IllegalStateException(
+ "Cannot insert new elements into a map after calling iterator")
+ }
// An update function for the map that we reuse across entries to avoid allocating
// a new closure each time
var curEntry: Product2[K, V] = null
@@ -215,17 +239,26 @@ class ExternalAppendOnlyMap[K, V, C](
}
/**
- * Return an iterator that merges the in-memory map with the spilled maps.
+ * Return a destructive iterator that merges the in-memory map with the spilled maps.
* If no spill has occurred, simply return the in-memory map's iterator.
*/
override def iterator: Iterator[(K, C)] = {
+ if (currentMap == null) {
+ throw new IllegalStateException(
+ "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
+ }
if (spilledMaps.isEmpty) {
- currentMap.iterator
+ CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())
} else {
new ExternalIterator()
}
}
+ private def freeCurrentMap(): Unit = {
+ currentMap = null // So that the memory can be garbage-collected
+ releaseMemory()
+ }
+
/**
* An iterator that sort-merges (K, C) pairs from the in-memory map and the spilled maps
*/
@@ -237,7 +270,8 @@ class ExternalAppendOnlyMap[K, V, C](
// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
- private val sortedMap = currentMap.destructiveSortedIterator(keyComparator)
+ private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
+ currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)
inputStreams.foreach { it =>
@@ -493,12 +527,7 @@ class ExternalAppendOnlyMap[K, V, C](
}
}
- val context = TaskContext.get()
- // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in
- // a TaskContext.
- if (context != null) {
- context.addTaskCompletionListener(context => cleanup())
- }
+ context.addTaskCompletionListener(context => cleanup())
}
/** Convenience function to hash the given (K, C) pair by the key. */
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index c48c453a90..a44e72b7c1 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -27,6 +27,7 @@ import com.google.common.annotations.VisibleForTesting
import com.google.common.io.ByteStreams
import org.apache.spark._
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.serializer._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
@@ -87,6 +88,7 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
* - Users are expected to call stop() at the end to delete all the intermediate files.
*/
private[spark] class ExternalSorter[K, V, C](
+ context: TaskContext,
aggregator: Option[Aggregator[K, V, C]] = None,
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
@@ -94,6 +96,8 @@ private[spark] class ExternalSorter[K, V, C](
extends Logging
with Spillable[WritablePartitionedPairCollection[K, C]] {
+ override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()
+
private val conf = SparkEnv.get.conf
private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
@@ -640,7 +644,6 @@ private[spark] class ExternalSorter[K, V, C](
*/
def writePartitionedFile(
blockId: BlockId,
- context: TaskContext,
outputFile: File): Array[Long] = {
// Track location of each range in the output file
@@ -686,8 +689,11 @@ private[spark] class ExternalSorter[K, V, C](
}
def stop(): Unit = {
+ map = null // So that the memory can be garbage-collected
+ buffer = null // So that the memory can be garbage-collected
spills.foreach(s => s.file.delete())
spills.clear()
+ releaseMemory()
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index d2a68ca7a3..a76891acf0 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -17,8 +17,8 @@
package org.apache.spark.util.collection
-import org.apache.spark.Logging
-import org.apache.spark.SparkEnv
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.{Logging, SparkEnv}
/**
* Spills contents of an in-memory collection to disk when the memory threshold
@@ -40,7 +40,7 @@ private[spark] trait Spillable[C] extends Logging {
protected def addElementsRead(): Unit = { _elementsRead += 1 }
// Memory manager that can be used to acquire/release memory
- private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
+ protected[this] def taskMemoryManager: TaskMemoryManager
// Initial threshold for the size of a collection before we start tracking its memory usage
// For testing only
@@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging {
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
- val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
+ val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest)
myMemoryThreshold += granted
// If we were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold), spill the current collection
@@ -92,7 +92,7 @@ private[spark] trait Spillable[C] extends Logging {
spill(collection)
_elementsRead = 0
_memoryBytesSpilled += currentMemory
- releaseMemoryForThisThread()
+ releaseMemory()
}
shouldSpill
}
@@ -103,11 +103,11 @@ private[spark] trait Spillable[C] extends Logging {
def memoryBytesSpilled: Long = _memoryBytesSpilled
/**
- * Release our memory back to the shuffle pool so that other threads can grab it.
+ * Release our memory back to the execution pool so that other tasks can grab it.
*/
- private def releaseMemoryForThisThread(): Unit = {
+ def releaseMemory(): Unit = {
// The amount we requested does not include the initial memory tracking threshold
- shuffleMemoryManager.release(myMemoryThreshold - initialMemoryThreshold)
+ taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold)
myMemoryThreshold = initialMemoryThreshold
}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index 06fb081183..f381db0c62 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -15,33 +15,28 @@
* limitations under the License.
*/
-package org.apache.spark.unsafe.memory;
+package org.apache.spark.memory;
import org.junit.Assert;
import org.junit.Test;
-public class TaskMemoryManagerSuite {
+import org.apache.spark.SparkConf;
+import org.apache.spark.unsafe.memory.MemoryBlock;
- @Test
- public void leakedNonPageMemoryIsDetected() {
- final TaskMemoryManager manager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
- manager.allocate(1024); // leak memory
- Assert.assertEquals(1024, manager.cleanUpAllAllocatedMemory());
- }
+public class TaskMemoryManagerSuite {
@Test
public void leakedPageMemoryIsDetected() {
- final TaskMemoryManager manager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final TaskMemoryManager manager = new TaskMemoryManager(
+ new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
manager.allocatePage(4096); // leak memory
Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
}
@Test
public void encodePageNumberAndOffsetOffHeap() {
- final TaskMemoryManager manager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+ final TaskMemoryManager manager = new TaskMemoryManager(
+ new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
final MemoryBlock dataPage = manager.allocatePage(256);
// In off-heap mode, an offset is an absolute address that may require more than 51 bits to
// encode. This test exercises that corner-case:
@@ -53,8 +48,8 @@ public class TaskMemoryManagerSuite {
@Test
public void encodePageNumberAndOffsetOnHeap() {
- final TaskMemoryManager manager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final TaskMemoryManager manager = new TaskMemoryManager(
+ new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
final MemoryBlock dataPage = manager.allocatePage(256);
final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
index 232ae4d926..7fb2f92ca8 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -21,18 +21,19 @@ import org.apache.spark.shuffle.sort.PackedRecordPointer;
import org.junit.Test;
import static org.junit.Assert.*;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
public class PackedRecordPointerSuite {
@Test
public void heap() {
+ final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
final TaskMemoryManager memoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
final MemoryBlock page0 = memoryManager.allocatePage(128);
final MemoryBlock page1 = memoryManager.allocatePage(128);
final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
@@ -49,8 +50,9 @@ public class PackedRecordPointerSuite {
@Test
public void offHeap() {
+ final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true");
final TaskMemoryManager memoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+ new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
final MemoryBlock page0 = memoryManager.allocatePage(128);
final MemoryBlock page1 = memoryManager.allocatePage(128);
final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 1ef3c5ff64..5049a5306f 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -24,11 +24,11 @@ import org.junit.Assert;
import org.junit.Test;
import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
public class ShuffleInMemorySorterSuite {
@@ -58,8 +58,9 @@ public class ShuffleInMemorySorterSuite {
"Lychee",
"Mango"
};
+ final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
final TaskMemoryManager memoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
final MemoryBlock dataPage = memoryManager.allocatePage(2048);
final Object baseObject = dataPage.getBaseObject();
final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 29d9823b1f..d65926949c 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -39,7 +39,6 @@ import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;
@@ -54,19 +53,15 @@ import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.serializer.*;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
import org.apache.spark.storage.*;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
public class UnsafeShuffleWriterSuite {
static final int NUM_PARTITITONS = 4;
- final TaskMemoryManager taskMemoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ TaskMemoryManager taskMemoryManager;
final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
File mergedOutputFile;
File tempDir;
@@ -76,7 +71,6 @@ public class UnsafeShuffleWriterSuite {
final Serializer serializer = new KryoSerializer(new SparkConf());
TaskMetrics taskMetrics;
- @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
@Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
@@ -111,11 +105,11 @@ public class UnsafeShuffleWriterSuite {
mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
partitionSizesInMergedFile = null;
spillFilesCreated.clear();
- conf = new SparkConf().set("spark.buffer.pageSize", "128m");
+ conf = new SparkConf()
+ .set("spark.buffer.pageSize", "128m")
+ .set("spark.unsafe.offHeap", "false");
taskMetrics = new TaskMetrics();
-
- when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
- when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
+ taskMemoryManager = new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
when(blockManager.getDiskWriter(
@@ -203,7 +197,6 @@ public class UnsafeShuffleWriterSuite {
blockManager,
shuffleBlockResolver,
taskMemoryManager,
- shuffleMemoryManager,
new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
0, // map id
taskContext,
@@ -405,11 +398,12 @@ public class UnsafeShuffleWriterSuite {
@Test
public void writeEnoughDataToTriggerSpill() throws Exception {
- when(shuffleMemoryManager.tryToAcquire(anyLong()))
- .then(returnsFirstArg()) // Allocate initial sort buffer
- .then(returnsFirstArg()) // Allocate initial data page
- .thenReturn(0L) // Deny request to allocate new data page
- .then(returnsFirstArg()); // Grant new sort buffer and data page.
+ taskMemoryManager = spy(taskMemoryManager);
+ doCallRealMethod() // initialize sort buffer
+ .doCallRealMethod() // allocate initial data page
+ .doReturn(0L) // deny request to allocate new page
+ .doCallRealMethod() // grant new sort buffer and data page
+ .when(taskMemoryManager).acquireExecutionMemory(anyLong());
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
@@ -417,7 +411,7 @@ public class UnsafeShuffleWriterSuite {
dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
}
writer.write(dataToWrite.iterator());
- verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+ verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
assertEquals(2, spillFilesCreated.size());
writer.stop(true);
readRecordsFromFile();
@@ -432,18 +426,19 @@ public class UnsafeShuffleWriterSuite {
@Test
public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
- when(shuffleMemoryManager.tryToAcquire(anyLong()))
- .then(returnsFirstArg()) // Allocate initial sort buffer
- .then(returnsFirstArg()) // Allocate initial data page
- .thenReturn(0L) // Deny request to grow sort buffer
- .then(returnsFirstArg()); // Grant new sort buffer and data page.
+ taskMemoryManager = spy(taskMemoryManager);
+ doCallRealMethod() // initialize sort buffer
+ .doCallRealMethod() // allocate initial data page
+ .doReturn(0L) // deny request to allocate new page
+ .doCallRealMethod() // grant new sort buffer and data page
+ .when(taskMemoryManager).acquireExecutionMemory(anyLong());
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
- final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
dataToWrite.add(new Tuple2<Object, Object>(i, i));
}
writer.write(dataToWrite.iterator());
- verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+ verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
assertEquals(2, spillFilesCreated.size());
writer.stop(true);
readRecordsFromFile();
@@ -509,13 +504,13 @@ public class UnsafeShuffleWriterSuite {
final long recordLengthBytes = 8;
final long pageSizeBytes = 256;
final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
- when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
+ taskMemoryManager = spy(taskMemoryManager);
+ when(taskMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
final UnsafeShuffleWriter<Object, Object> writer =
new UnsafeShuffleWriter<Object, Object>(
blockManager,
shuffleBlockResolver,
taskMemoryManager,
- shuffleMemoryManager,
new SerializedShuffleHandle<>(0, 1, shuffleDep),
0, // map id
taskContext,
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index ab480b60ad..6e52496cf9 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -21,15 +21,13 @@ import java.lang.Exception;
import java.nio.ByteBuffer;
import java.util.*;
+import org.apache.spark.memory.TaskMemoryManager;
import org.junit.*;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.*;
-import static org.mockito.AdditionalMatchers.geq;
-import static org.mockito.Mockito.*;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.memory.*;
import org.apache.spark.unsafe.Platform;
@@ -39,42 +37,29 @@ public abstract class AbstractBytesToBytesMapSuite {
private final Random rand = new Random(42);
- private ShuffleMemoryManager shuffleMemoryManager;
+ private GrantEverythingMemoryManager memoryManager;
private TaskMemoryManager taskMemoryManager;
- private TaskMemoryManager sizeLimitedTaskMemoryManager;
private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
@Before
public void setup() {
- shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, PAGE_SIZE_BYTES);
- taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
- // Mocked memory manager for tests that check the maximum array size, since actually allocating
- // such large arrays will cause us to run out of memory in our tests.
- sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class);
- when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer(
- new Answer<MemoryBlock>() {
- @Override
- public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
- if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
- throw new OutOfMemoryError("Requested array size exceeds VM limit");
- }
- return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]);
- }
- }
- );
+ memoryManager =
+ new GrantEverythingMemoryManager(
+ new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()));
+ taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
}
@After
public void tearDown() {
Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
- if (shuffleMemoryManager != null) {
- long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
- shuffleMemoryManager = null;
- Assert.assertEquals(0L, leakedShuffleMemory);
+ if (taskMemoryManager != null) {
+ long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
+ taskMemoryManager = null;
+ Assert.assertEquals(0L, leakedMemory);
}
}
- protected abstract MemoryAllocator getMemoryAllocator();
+ protected abstract boolean useOffHeapMemoryAllocator();
private static byte[] getByteArray(MemoryLocation loc, int size) {
final byte[] arr = new byte[size];
@@ -110,8 +95,7 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void emptyMap() {
- BytesToBytesMap map = new BytesToBytesMap(
- taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
+ BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
try {
Assert.assertEquals(0, map.numElements());
final int keyLengthInWords = 10;
@@ -126,8 +110,7 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void setAndRetrieveAKey() {
- BytesToBytesMap map = new BytesToBytesMap(
- taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
+ BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
final int recordLengthWords = 10;
final int recordLengthBytes = recordLengthWords * 8;
final byte[] keyData = getRandomByteArray(recordLengthWords);
@@ -179,8 +162,7 @@ public abstract class AbstractBytesToBytesMapSuite {
private void iteratorTestBase(boolean destructive) throws Exception {
final int size = 4096;
- BytesToBytesMap map = new BytesToBytesMap(
- taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES);
+ BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size / 2, PAGE_SIZE_BYTES);
try {
for (long i = 0; i < size; i++) {
final long[] value = new long[] { i };
@@ -265,8 +247,8 @@ public abstract class AbstractBytesToBytesMapSuite {
final int NUM_ENTRIES = 1000 * 1000;
final int KEY_LENGTH = 24;
final int VALUE_LENGTH = 40;
- final BytesToBytesMap map = new BytesToBytesMap(
- taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
+ final BytesToBytesMap map =
+ new BytesToBytesMap(taskMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
// Each record will take 8 + 24 + 40 = 72 bytes of space in the data page. Our 64-megabyte
// pages won't be evenly-divisible by records of this size, which will cause us to waste some
// space at the end of the page. This is necessary in order for us to take the end-of-record
@@ -335,9 +317,7 @@ public abstract class AbstractBytesToBytesMapSuite {
// Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
// into ByteBuffers in order to use them as keys here.
final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
- final BytesToBytesMap map = new BytesToBytesMap(
- taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES);
-
+ final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, size, PAGE_SIZE_BYTES);
try {
// Fill the map to 90% full so that we can trigger probing
for (int i = 0; i < size * 0.9; i++) {
@@ -386,8 +366,7 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void randomizedTestWithRecordsLargerThanPageSize() {
final long pageSizeBytes = 128;
- final BytesToBytesMap map = new BytesToBytesMap(
- taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes);
+ final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, pageSizeBytes);
// Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
// into ByteBuffers in order to use them as keys here.
final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
@@ -436,9 +415,9 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void failureToAllocateFirstPage() {
- shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024);
- BytesToBytesMap map =
- new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+ memoryManager.markExecutionAsOutOfMemory();
+ BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
+ memoryManager.markExecutionAsOutOfMemory();
try {
final long[] emptyArray = new long[0];
final BytesToBytesMap.Location loc =
@@ -454,12 +433,14 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void failureToGrow() {
- shuffleMemoryManager = ShuffleMemoryManager.createForTesting(1024 * 10);
- BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024);
+ BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, 1024);
try {
boolean success = true;
int i;
- for (i = 0; i < 1024; i++) {
+ for (i = 0; i < 127; i++) {
+ if (i > 0) {
+ memoryManager.markExecutionAsOutOfMemory();
+ }
final long[] arr = new long[]{i};
final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
success =
@@ -478,7 +459,7 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void initialCapacityBoundsChecking() {
try {
- new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES);
+ new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES);
Assert.fail("Expected IllegalArgumentException to be thrown");
} catch (IllegalArgumentException e) {
// expected exception
@@ -486,36 +467,13 @@ public abstract class AbstractBytesToBytesMapSuite {
try {
new BytesToBytesMap(
- sizeLimitedTaskMemoryManager,
- shuffleMemoryManager,
+ taskMemoryManager,
BytesToBytesMap.MAX_CAPACITY + 1,
PAGE_SIZE_BYTES);
Assert.fail("Expected IllegalArgumentException to be thrown");
} catch (IllegalArgumentException e) {
// expected exception
}
-
- // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
- // Can allocate _at_ the max capacity
- // BytesToBytesMap map = new BytesToBytesMap(
- // sizeLimitedTaskMemoryManager,
- // shuffleMemoryManager,
- // BytesToBytesMap.MAX_CAPACITY,
- // PAGE_SIZE_BYTES);
- // map.free();
- }
-
- // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
- @Ignore
- public void resizingLargeMap() {
- // As long as a map's capacity is below the max, we should be able to resize up to the max
- BytesToBytesMap map = new BytesToBytesMap(
- sizeLimitedTaskMemoryManager,
- shuffleMemoryManager,
- BytesToBytesMap.MAX_CAPACITY - 64,
- PAGE_SIZE_BYTES);
- map.growAndRehash();
- map.free();
}
@Test
@@ -523,8 +481,7 @@ public abstract class AbstractBytesToBytesMapSuite {
final long recordLengthBytes = 24;
final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
- final BytesToBytesMap map = new BytesToBytesMap(
- taskMemoryManager, shuffleMemoryManager, 1024, pageSizeBytes);
+ final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes);
// Since BytesToBytesMap is append-only, we expect the total memory consumption to be
// monotonically increasing. More specifically, every time we allocate a new page it
@@ -564,8 +521,7 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void testAcquirePageInConstructor() {
- final BytesToBytesMap map = new BytesToBytesMap(
- taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+ final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
assertEquals(1, map.getNumDataPages());
map.free();
}
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
index 5a10de49f5..f0bad4d760 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
@@ -17,13 +17,10 @@
package org.apache.spark.unsafe.map;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-
public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite {
@Override
- protected MemoryAllocator getMemoryAllocator() {
- return MemoryAllocator.UNSAFE;
+ protected boolean useOffHeapMemoryAllocator() {
+ return true;
}
-
}
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
index 12cc9b25d9..d76bb4fd05 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
@@ -17,13 +17,10 @@
package org.apache.spark.unsafe.map;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-
public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite {
@Override
- protected MemoryAllocator getMemoryAllocator() {
- return MemoryAllocator.HEAP;
+ protected boolean useOffHeapMemoryAllocator() {
+ return false;
}
-
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index a5bbaa95fa..94d50b94fd 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -46,20 +46,19 @@ import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
public class UnsafeExternalSorterSuite {
final LinkedList<File> spillFilesCreated = new LinkedList<File>();
- final TaskMemoryManager taskMemoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final GrantEverythingMemoryManager memoryManager =
+ new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
+ final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
final PrefixComparator prefixComparator = new PrefixComparator() {
@Override
@@ -82,7 +81,6 @@ public class UnsafeExternalSorterSuite {
SparkConf sparkConf;
File tempDir;
- ShuffleMemoryManager shuffleMemoryManager;
@Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
@@ -102,7 +100,6 @@ public class UnsafeExternalSorterSuite {
MockitoAnnotations.initMocks(this);
sparkConf = new SparkConf();
tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
- shuffleMemoryManager = ShuffleMemoryManager.create(Long.MAX_VALUE, pageSizeBytes);
spillFilesCreated.clear();
taskContext = mock(TaskContext.class);
when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
@@ -143,13 +140,7 @@ public class UnsafeExternalSorterSuite {
@After
public void tearDown() {
try {
- long leakedUnsafeMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
- if (shuffleMemoryManager != null) {
- long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
- shuffleMemoryManager = null;
- assertEquals(0L, leakedShuffleMemory);
- }
- assertEquals(0, leakedUnsafeMemory);
+ assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
} finally {
Utils.deleteRecursively(tempDir);
tempDir = null;
@@ -178,7 +169,6 @@ public class UnsafeExternalSorterSuite {
private UnsafeExternalSorter newSorter() throws IOException {
return UnsafeExternalSorter.create(
taskMemoryManager,
- shuffleMemoryManager,
blockManager,
taskContext,
recordComparator,
@@ -236,12 +226,16 @@ public class UnsafeExternalSorterSuite {
@Test
public void spillingOccursInResponseToMemoryPressure() throws Exception {
- shuffleMemoryManager = ShuffleMemoryManager.create(pageSizeBytes * 2, pageSizeBytes);
final UnsafeExternalSorter sorter = newSorter();
- final int numRecords = (int) pageSizeBytes / 4;
- for (int i = 0; i <= numRecords; i++) {
+ // This should be enough records to completely fill up a data page:
+ final int numRecords = (int) (pageSizeBytes / (4 + 4));
+ for (int i = 0; i < numRecords; i++) {
insertNumber(sorter, numRecords - i);
}
+ assertEquals(1, sorter.getNumberOfAllocatedPages());
+ memoryManager.markExecutionAsOutOfMemory();
+ // The insertion of this record should trigger a spill:
+ insertNumber(sorter, 0);
// Ensure that spill files were created
assertThat(tempDir.listFiles().length, greaterThanOrEqualTo(1));
// Read back the sorted data:
@@ -255,6 +249,7 @@ public class UnsafeExternalSorterSuite {
assertEquals(i, Platform.getInt(iter.getBaseObject(), iter.getBaseOffset()));
i++;
}
+ assertEquals(numRecords + 1, i);
sorter.cleanupResources();
assertSpillFilesWereCleanedUp();
}
@@ -323,7 +318,6 @@ public class UnsafeExternalSorterSuite {
final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
final UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
taskMemoryManager,
- shuffleMemoryManager,
blockManager,
taskContext,
recordComparator,
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 778e813df6..d5de56a051 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -26,11 +26,11 @@ import static org.junit.Assert.*;
import static org.mockito.Mockito.mock;
import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.memory.GrantEverythingMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
public class UnsafeInMemorySorterSuite {
@@ -43,7 +43,8 @@ public class UnsafeInMemorySorterSuite {
@Test
public void testSortingEmptyInput() {
final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP)),
+ new TaskMemoryManager(
+ new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
mock(RecordComparator.class),
mock(PrefixComparator.class),
100);
@@ -64,8 +65,8 @@ public class UnsafeInMemorySorterSuite {
"Lychee",
"Mango"
};
- final TaskMemoryManager memoryManager =
- new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+ final TaskMemoryManager memoryManager = new TaskMemoryManager(
+ new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
final MemoryBlock dataPage = memoryManager.allocatePage(2048);
final Object baseObject = dataPage.getBaseObject();
// Write the records into the data page:
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index f58756e6f6..0242cbc924 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
// cause is preserved
val thrownDueToTaskFailure = intercept[SparkException] {
sc.parallelize(Seq(0)).mapPartitions { iter =>
- TaskContext.get().taskMemoryManager().allocate(128)
+ TaskContext.get().taskMemoryManager().allocatePage(128)
throw new Exception("intentional task failure")
iter
}.count()
@@ -159,7 +159,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
// If the task succeeded but memory was leaked, then the task should fail due to that leak
val thrownDueToMemoryLeak = intercept[SparkException] {
sc.parallelize(Seq(0)).mapPartitions { iter =>
- TaskContext.get().taskMemoryManager().allocate(128)
+ TaskContext.get().taskMemoryManager().allocatePage(128)
iter
}.count()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
index c4358f409b..fe102d8aeb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TestShuffleMemoryManager.scala
+++ b/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
@@ -15,51 +15,25 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution
+package org.apache.spark.memory
import scala.collection.mutable
-import org.apache.spark.memory.MemoryManager
-import org.apache.spark.shuffle.ShuffleMemoryManager
-import org.apache.spark.storage.{BlockId, BlockStatus}
+import org.apache.spark.SparkConf
+import org.apache.spark.storage.{BlockStatus, BlockId}
-
-/**
- * A [[ShuffleMemoryManager]] that can be controlled to run out of memory.
- */
-class TestShuffleMemoryManager
- extends ShuffleMemoryManager(new GrantEverythingMemoryManager, 4 * 1024 * 1024) {
- private var oom = false
-
- override def tryToAcquire(numBytes: Long): Long = {
+class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
+ private[memory] override def doAcquireExecutionMemory(
+ numBytes: Long,
+ evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
if (oom) {
oom = false
0
} else {
- // Uncomment the following to trace memory allocations.
- // println(s"tryToAcquire $numBytes in " +
- // Thread.currentThread().getStackTrace.mkString("", "\n -", ""))
- val acquired = super.tryToAcquire(numBytes)
- acquired
+ _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory
+ numBytes
}
}
-
- override def release(numBytes: Long): Unit = {
- // Uncomment the following to trace memory releases.
- // println(s"release $numBytes in " +
- // Thread.currentThread().getStackTrace.mkString("", "\n -", ""))
- super.release(numBytes)
- }
-
- def markAsOutOfMemory(): Unit = {
- oom = true
- }
-}
-
-private class GrantEverythingMemoryManager extends MemoryManager {
- override def acquireExecutionMemory(
- numBytes: Long,
- evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = numBytes
override def acquireStorageMemory(
blockId: BlockId,
numBytes: Long,
@@ -68,8 +42,13 @@ private class GrantEverythingMemoryManager extends MemoryManager {
blockId: BlockId,
numBytes: Long,
evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
- override def releaseExecutionMemory(numBytes: Long): Unit = { }
override def releaseStorageMemory(numBytes: Long): Unit = { }
override def maxExecutionMemory: Long = Long.MaxValue
override def maxStorageMemory: Long = Long.MaxValue
+
+ private var oom = false
+
+ def markExecutionAsOutOfMemory(): Unit = {
+ oom = true
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
index 36e4566310..1265087743 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
@@ -19,10 +19,14 @@ package org.apache.spark.memory
import java.util.concurrent.atomic.AtomicLong
+import scala.concurrent.duration.Duration
+import scala.concurrent.{Await, ExecutionContext, Future}
+
import org.mockito.Matchers.{any, anyLong}
import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
+import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.MemoryStore
@@ -126,6 +130,136 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
assert(ensureFreeSpaceCalled.get() === DEFAULT_ENSURE_FREE_SPACE_CALLED,
"ensure free space should not have been called!")
}
+
+ /**
+ * Create a MemoryManager with the specified execution memory limit and no storage memory.
+ */
+ protected def createMemoryManager(maxExecutionMemory: Long): MemoryManager
+
+ // -- Tests of sharing of execution memory between tasks ----------------------------------------
+ // Prior to Spark 1.6, these tests were part of ShuffleMemoryManagerSuite.
+
+ implicit val ec = ExecutionContext.global
+
+ test("single task requesting execution memory") {
+ val manager = createMemoryManager(1000L)
+ val taskMemoryManager = new TaskMemoryManager(manager, 0)
+
+ assert(taskMemoryManager.acquireExecutionMemory(100L) === 100L)
+ assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
+ assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
+ assert(taskMemoryManager.acquireExecutionMemory(200L) === 100L)
+ assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+ assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+
+ taskMemoryManager.releaseExecutionMemory(500L)
+ assert(taskMemoryManager.acquireExecutionMemory(300L) === 300L)
+ assert(taskMemoryManager.acquireExecutionMemory(300L) === 200L)
+
+ taskMemoryManager.cleanUpAllAllocatedMemory()
+ assert(taskMemoryManager.acquireExecutionMemory(1000L) === 1000L)
+ assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+ }
+
+ test("two tasks requesting full execution memory") {
+ val memoryManager = createMemoryManager(1000L)
+ val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+ val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+ val futureTimeout: Duration = 20.seconds
+
+ // Have both tasks request 500 bytes, then wait until both requests have been granted:
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L) }
+ val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ assert(Await.result(t1Result1, futureTimeout) === 500L)
+ assert(Await.result(t2Result1, futureTimeout) === 500L)
+
+ // Have both tasks each request 500 bytes more; both should immediately return 0 as they are
+ // both now at 1 / N
+ val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
+ val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ assert(Await.result(t1Result2, 200.millis) === 0L)
+ assert(Await.result(t2Result2, 200.millis) === 0L)
+ }
+
+ test("two tasks cannot grow past 1 / N of execution memory") {
+ val memoryManager = createMemoryManager(1000L)
+ val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+ val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+ val futureTimeout: Duration = 20.seconds
+
+ // Have both tasks request 250 bytes, then wait until both requests have been granted:
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L) }
+ val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+ assert(Await.result(t1Result1, futureTimeout) === 250L)
+ assert(Await.result(t2Result1, futureTimeout) === 250L)
+
+ // Have both tasks each request 500 bytes more.
+ // We should only grant 250 bytes to each of them on this second request
+ val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
+ val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ assert(Await.result(t1Result2, futureTimeout) === 250L)
+ assert(Await.result(t2Result2, futureTimeout) === 250L)
+ }
+
+ test("tasks can block to get at least 1 / 2N of execution memory") {
+ val memoryManager = createMemoryManager(1000L)
+ val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+ val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+ val futureTimeout: Duration = 20.seconds
+
+ // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+ assert(Await.result(t1Result1, futureTimeout) === 1000L)
+ val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+ // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
+ // to make sure the other thread blocks for some time otherwise.
+ Thread.sleep(300)
+ t1MemManager.releaseExecutionMemory(250L)
+ // The memory freed from t1 should now be granted to t2.
+ assert(Await.result(t2Result1, futureTimeout) === 250L)
+ // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory.
+ val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L) }
+ assert(Await.result(t2Result2, 200.millis) === 0L)
+ }
+
+ test("TaskMemoryManager.cleanUpAllAllocatedMemory") {
+ val memoryManager = createMemoryManager(1000L)
+ val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+ val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+ val futureTimeout: Duration = 20.seconds
+
+ // t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+ assert(Await.result(t1Result1, futureTimeout) === 1000L)
+ val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
+ // to make sure the other thread blocks for some time otherwise.
+ Thread.sleep(300)
+ // t1 releases all of its memory, so t2 should be able to grab all of the memory
+ t1MemManager.cleanUpAllAllocatedMemory()
+ assert(Await.result(t2Result1, futureTimeout) === 500L)
+ val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ assert(Await.result(t2Result2, futureTimeout) === 500L)
+ val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ assert(Await.result(t2Result3, 200.millis) === 0L)
+ }
+
+ test("tasks should not be granted a negative amount of execution memory") {
+ // This is a regression test for SPARK-4715.
+ val memoryManager = createMemoryManager(1000L)
+ val t1MemManager = new TaskMemoryManager(memoryManager, 1)
+ val t2MemManager = new TaskMemoryManager(memoryManager, 2)
+ val futureTimeout: Duration = 20.seconds
+
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L) }
+ assert(Await.result(t1Result1, futureTimeout) === 700L)
+
+ val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L) }
+ assert(Await.result(t2Result1, futureTimeout) === 300L)
+
+ val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L) }
+ assert(Await.result(t1Result2, 200.millis) === 0L)
+ }
}
private object MemoryManagerSuite {
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
new file mode 100644
index 0000000000..4b4c3b0311
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory
+
+import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext}
+
+/**
+ * Helper methods for mocking out memory-management-related classes in tests.
+ */
+object MemoryTestingUtils {
+ def fakeTaskContext(env: SparkEnv): TaskContext = {
+ val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
+ new TaskContextImpl(
+ stageId = 0,
+ partitionId = 0,
+ taskAttemptId = 0,
+ attemptNumber = 0,
+ taskMemoryManager = taskMemoryManager,
+ metricsSystem = env.metricsSystem,
+ internalAccumulators = Seq.empty)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
index 6cae1f871e..885c450d6d 100644
--- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala
@@ -36,27 +36,35 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite {
maxExecutionMem: Long,
maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = {
val mm = new StaticMemoryManager(
- conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem)
+ conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem, numCores = 1)
val ms = makeMemoryStore(mm)
(mm, ms)
}
+ override protected def createMemoryManager(maxMemory: Long): MemoryManager = {
+ new StaticMemoryManager(
+ conf,
+ maxExecutionMemory = maxMemory,
+ maxStorageMemory = 0,
+ numCores = 1)
+ }
+
test("basic execution memory") {
val maxExecutionMem = 1000L
val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue)
assert(mm.executionMemoryUsed === 0L)
- assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L)
+ assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L)
assert(mm.executionMemoryUsed === 10L)
- assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+ assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
// Acquire up to the max
- assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L)
+ assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L)
assert(mm.executionMemoryUsed === maxExecutionMem)
- assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L)
+ assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L)
assert(mm.executionMemoryUsed === maxExecutionMem)
mm.releaseExecutionMemory(800L)
assert(mm.executionMemoryUsed === 200L)
// Acquire after release
- assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L)
+ assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L)
assert(mm.executionMemoryUsed === 201L)
// Release beyond what was acquired
mm.releaseExecutionMemory(maxExecutionMem)
@@ -108,10 +116,10 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite {
val dummyBlock = TestBlockId("ain't nobody love like you do")
val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem)
// Only execution memory should increase
- assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+ assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
assert(mm.storageMemoryUsed === 0L)
assert(mm.executionMemoryUsed === 100L)
- assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 100L)
+ assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 100L)
assert(mm.storageMemoryUsed === 0L)
assert(mm.executionMemoryUsed === 200L)
// Only storage memory should increase
diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
index e7baa50dc2..0c97f2bd89 100644
--- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
@@ -34,11 +34,15 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
* Make a [[UnifiedMemoryManager]] and a [[MemoryStore]] with limited class dependencies.
*/
private def makeThings(maxMemory: Long): (UnifiedMemoryManager, MemoryStore) = {
- val mm = new UnifiedMemoryManager(conf, maxMemory)
+ val mm = new UnifiedMemoryManager(conf, maxMemory, numCores = 1)
val ms = makeMemoryStore(mm)
(mm, ms)
}
+ override protected def createMemoryManager(maxMemory: Long): MemoryManager = {
+ new UnifiedMemoryManager(conf, maxMemory, numCores = 1)
+ }
+
private def getStorageRegionSize(mm: UnifiedMemoryManager): Long = {
mm invokePrivate PrivateMethod[Long]('storageRegionSize)()
}
@@ -56,18 +60,18 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
val maxMemory = 1000L
val (mm, _) = makeThings(maxMemory)
assert(mm.executionMemoryUsed === 0L)
- assert(mm.acquireExecutionMemory(10L, evictedBlocks) === 10L)
+ assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L)
assert(mm.executionMemoryUsed === 10L)
- assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+ assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
// Acquire up to the max
- assert(mm.acquireExecutionMemory(1000L, evictedBlocks) === 890L)
+ assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L)
assert(mm.executionMemoryUsed === maxMemory)
- assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 0L)
+ assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L)
assert(mm.executionMemoryUsed === maxMemory)
mm.releaseExecutionMemory(800L)
assert(mm.executionMemoryUsed === 200L)
// Acquire after release
- assert(mm.acquireExecutionMemory(1L, evictedBlocks) === 1L)
+ assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L)
assert(mm.executionMemoryUsed === 201L)
// Release beyond what was acquired
mm.releaseExecutionMemory(maxMemory)
@@ -132,12 +136,12 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
require(mm.storageMemoryUsed > storageRegionSize,
s"bad test: storage memory used should exceed the storage region")
// Execution needs to request 250 bytes to evict storage memory
- assert(mm.acquireExecutionMemory(100L, evictedBlocks) === 100L)
+ assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L)
assert(mm.executionMemoryUsed === 100L)
assert(mm.storageMemoryUsed === 750L)
assertEnsureFreeSpaceNotCalled(ms)
// Execution wants 200 bytes but only 150 are free, so storage is evicted
- assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L)
+ assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L)
assertEnsureFreeSpaceCalled(ms, 200L)
assert(mm.executionMemoryUsed === 300L)
mm.releaseAllStorageMemory()
@@ -151,7 +155,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
s"bad test: storage memory used should be within the storage region")
// Execution cannot evict storage because the latter is within the storage fraction,
// so grant only what's remaining without evicting anything, i.e. 1000 - 300 - 400 = 300
- assert(mm.acquireExecutionMemory(400L, evictedBlocks) === 300L)
+ assert(mm.doAcquireExecutionMemory(400L, evictedBlocks) === 300L)
assert(mm.executionMemoryUsed === 600L)
assert(mm.storageMemoryUsed === 400L)
assertEnsureFreeSpaceNotCalled(ms)
@@ -170,7 +174,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
require(executionRegionSize === expectedExecutionRegionSize,
"bad test: storage region size is unexpected")
// Acquire enough execution memory to exceed the execution region
- assert(mm.acquireExecutionMemory(800L, evictedBlocks) === 800L)
+ assert(mm.doAcquireExecutionMemory(800L, evictedBlocks) === 800L)
assert(mm.executionMemoryUsed === 800L)
assert(mm.storageMemoryUsed === 0L)
assertEnsureFreeSpaceNotCalled(ms)
@@ -188,7 +192,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
mm.releaseExecutionMemory(maxMemory)
mm.releaseStorageMemory(maxMemory)
// Acquire some execution memory again, but this time keep it within the execution region
- assert(mm.acquireExecutionMemory(200L, evictedBlocks) === 200L)
+ assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L)
assert(mm.executionMemoryUsed === 200L)
assert(mm.storageMemoryUsed === 0L)
assertEnsureFreeSpaceNotCalled(ms)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
deleted file mode 100644
index 5877aa042d..0000000000
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala
+++ /dev/null
@@ -1,326 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle
-
-import java.util.concurrent.CountDownLatch
-import java.util.concurrent.atomic.AtomicInteger
-
-import org.mockito.Mockito._
-import org.scalatest.concurrent.Timeouts
-import org.scalatest.time.SpanSugar._
-
-import org.apache.spark.{SparkFunSuite, TaskContext}
-import org.apache.spark.executor.TaskMetrics
-
-class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts {
-
- val nextTaskAttemptId = new AtomicInteger()
-
- /** Launch a thread with the given body block and return it. */
- private def startThread(name: String)(body: => Unit): Thread = {
- val thread = new Thread("ShuffleMemorySuite " + name) {
- override def run() {
- try {
- val taskAttemptId = nextTaskAttemptId.getAndIncrement
- val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS)
- val taskMetrics = new TaskMetrics
- when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId)
- when(mockTaskContext.taskMetrics()).thenReturn(taskMetrics)
- TaskContext.setTaskContext(mockTaskContext)
- body
- } finally {
- TaskContext.unset()
- }
- }
- }
- thread.start()
- thread
- }
-
- test("single task requesting memory") {
- val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
- assert(manager.tryToAcquire(100L) === 100L)
- assert(manager.tryToAcquire(400L) === 400L)
- assert(manager.tryToAcquire(400L) === 400L)
- assert(manager.tryToAcquire(200L) === 100L)
- assert(manager.tryToAcquire(100L) === 0L)
- assert(manager.tryToAcquire(100L) === 0L)
-
- manager.release(500L)
- assert(manager.tryToAcquire(300L) === 300L)
- assert(manager.tryToAcquire(300L) === 200L)
-
- manager.releaseMemoryForThisTask()
- assert(manager.tryToAcquire(1000L) === 1000L)
- assert(manager.tryToAcquire(100L) === 0L)
- }
-
- test("two threads requesting full memory") {
- // Two threads request 500 bytes first, wait for each other to get it, and then request
- // 500 more; we should immediately return 0 as both are now at 1 / N
-
- val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
- class State {
- var t1Result1 = -1L
- var t2Result1 = -1L
- var t1Result2 = -1L
- var t2Result2 = -1L
- }
- val state = new State
-
- val t1 = startThread("t1") {
- val r1 = manager.tryToAcquire(500L)
- state.synchronized {
- state.t1Result1 = r1
- state.notifyAll()
- while (state.t2Result1 === -1L) {
- state.wait()
- }
- }
- val r2 = manager.tryToAcquire(500L)
- state.synchronized { state.t1Result2 = r2 }
- }
-
- val t2 = startThread("t2") {
- val r1 = manager.tryToAcquire(500L)
- state.synchronized {
- state.t2Result1 = r1
- state.notifyAll()
- while (state.t1Result1 === -1L) {
- state.wait()
- }
- }
- val r2 = manager.tryToAcquire(500L)
- state.synchronized { state.t2Result2 = r2 }
- }
-
- failAfter(20 seconds) {
- t1.join()
- t2.join()
- }
-
- assert(state.t1Result1 === 500L)
- assert(state.t2Result1 === 500L)
- assert(state.t1Result2 === 0L)
- assert(state.t2Result2 === 0L)
- }
-
-
- test("tasks cannot grow past 1 / N") {
- // Two tasks request 250 bytes first, wait for each other to get it, and then request
- // 500 more; we should only grant 250 bytes to each of them on this second request
-
- val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
- class State {
- var t1Result1 = -1L
- var t2Result1 = -1L
- var t1Result2 = -1L
- var t2Result2 = -1L
- }
- val state = new State
-
- val t1 = startThread("t1") {
- val r1 = manager.tryToAcquire(250L)
- state.synchronized {
- state.t1Result1 = r1
- state.notifyAll()
- while (state.t2Result1 === -1L) {
- state.wait()
- }
- }
- val r2 = manager.tryToAcquire(500L)
- state.synchronized { state.t1Result2 = r2 }
- }
-
- val t2 = startThread("t2") {
- val r1 = manager.tryToAcquire(250L)
- state.synchronized {
- state.t2Result1 = r1
- state.notifyAll()
- while (state.t1Result1 === -1L) {
- state.wait()
- }
- }
- val r2 = manager.tryToAcquire(500L)
- state.synchronized { state.t2Result2 = r2 }
- }
-
- failAfter(20 seconds) {
- t1.join()
- t2.join()
- }
-
- assert(state.t1Result1 === 250L)
- assert(state.t2Result1 === 250L)
- assert(state.t1Result2 === 250L)
- assert(state.t2Result2 === 250L)
- }
-
- test("tasks can block to get at least 1 / 2N memory") {
- // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
- // for a bit and releases 250 bytes, which should then be granted to t2. Further requests
- // by t2 will return false right away because it now has 1 / 2N of the memory.
-
- val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
- class State {
- var t1Requested = false
- var t2Requested = false
- var t1Result = -1L
- var t2Result = -1L
- var t2Result2 = -1L
- var t2WaitTime = 0L
- }
- val state = new State
-
- val t1 = startThread("t1") {
- state.synchronized {
- state.t1Result = manager.tryToAcquire(1000L)
- state.t1Requested = true
- state.notifyAll()
- while (!state.t2Requested) {
- state.wait()
- }
- }
- // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
- // sure the other thread blocks for some time otherwise
- Thread.sleep(300)
- manager.release(250L)
- }
-
- val t2 = startThread("t2") {
- state.synchronized {
- while (!state.t1Requested) {
- state.wait()
- }
- state.t2Requested = true
- state.notifyAll()
- }
- val startTime = System.currentTimeMillis()
- val result = manager.tryToAcquire(250L)
- val endTime = System.currentTimeMillis()
- state.synchronized {
- state.t2Result = result
- // A second call should return 0 because we're now already at 1 / 2N
- state.t2Result2 = manager.tryToAcquire(100L)
- state.t2WaitTime = endTime - startTime
- }
- }
-
- failAfter(20 seconds) {
- t1.join()
- t2.join()
- }
-
- // Both threads should've been able to acquire their memory; the second one will have waited
- // until the first one acquired 1000 bytes and then released 250
- state.synchronized {
- assert(state.t1Result === 1000L, "t1 could not allocate memory")
- assert(state.t2Result === 250L, "t2 could not allocate memory")
- assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
- assert(state.t2Result2 === 0L, "t1 got extra memory the second time")
- }
- }
-
- test("releaseMemoryForThisTask") {
- // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps
- // for a bit and releases all its memory. t2 should now be able to grab all the memory.
-
- val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
-
- class State {
- var t1Requested = false
- var t2Requested = false
- var t1Result = -1L
- var t2Result1 = -1L
- var t2Result2 = -1L
- var t2Result3 = -1L
- var t2WaitTime = 0L
- }
- val state = new State
-
- val t1 = startThread("t1") {
- state.synchronized {
- state.t1Result = manager.tryToAcquire(1000L)
- state.t1Requested = true
- state.notifyAll()
- while (!state.t2Requested) {
- state.wait()
- }
- }
- // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make
- // sure the other task blocks for some time otherwise
- Thread.sleep(300)
- manager.releaseMemoryForThisTask()
- }
-
- val t2 = startThread("t2") {
- state.synchronized {
- while (!state.t1Requested) {
- state.wait()
- }
- state.t2Requested = true
- state.notifyAll()
- }
- val startTime = System.currentTimeMillis()
- val r1 = manager.tryToAcquire(500L)
- val endTime = System.currentTimeMillis()
- val r2 = manager.tryToAcquire(500L)
- val r3 = manager.tryToAcquire(500L)
- state.synchronized {
- state.t2Result1 = r1
- state.t2Result2 = r2
- state.t2Result3 = r3
- state.t2WaitTime = endTime - startTime
- }
- }
-
- failAfter(20 seconds) {
- t1.join()
- t2.join()
- }
-
- // Both tasks should've been able to acquire their memory; the second one will have waited
- // until the first one acquired 1000 bytes and then released all of it
- state.synchronized {
- assert(state.t1Result === 1000L, "t1 could not allocate memory")
- assert(state.t2Result1 === 500L, "t2 didn't get 500 bytes the first time")
- assert(state.t2Result2 === 500L, "t2 didn't get 500 bytes the second time")
- assert(state.t2Result3 === 0L, s"t2 got more bytes a third time (${state.t2Result3})")
- assert(state.t2WaitTime > 200, s"t2 waited less than 200 ms (${state.t2WaitTime})")
- }
- }
-
- test("tasks should not be granted a negative size") {
- val manager = ShuffleMemoryManager.createForTesting(maxMemory = 1000L)
- manager.tryToAcquire(700L)
-
- val latch = new CountDownLatch(1)
- startThread("t1") {
- manager.tryToAcquire(300L)
- latch.countDown()
- }
- latch.await() // Wait until `t1` calls `tryToAcquire`
-
- val granted = manager.tryToAcquire(300L)
- assert(0 === granted, "granted is negative")
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index cc44c676b2..6e3f500e15 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -61,7 +61,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
- val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
+ val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
val store = new BlockManager(name, rpcEnv, master, serializer, conf,
memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
memManager.setMemoryStore(store.memoryStore)
@@ -261,7 +261,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo
val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work
when(failableTransfer.hostName).thenReturn("some-hostname")
when(failableTransfer.port).thenReturn(1000)
- val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000)
+ val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000, numCores = 1)
val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, conf,
memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0)
memManager.setMemoryStore(failableStore.memoryStore)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index f3fab33ca2..d49015afcd 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -68,7 +68,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
- val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
+ val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf,
memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
memManager.setMemoryStore(blockManager.memoryStore)
@@ -823,7 +823,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
- val memoryManager = new StaticMemoryManager(conf, Long.MaxValue, 1200)
+ val memoryManager = new StaticMemoryManager(
+ conf,
+ maxExecutionMemory = Long.MaxValue,
+ maxStorageMemory = 1200,
+ numCores = 1)
store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
new JavaSerializer(conf), conf, memoryManager, mapOutputTracker,
shuffleManager, transfer, securityMgr, 0)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 5cb506ea21..dc3185a6d5 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark._
import org.apache.spark.io.CompressionCodec
-
+import org.apache.spark.memory.MemoryTestingUtils
class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
import TestUtils.{assertNotSpilled, assertSpilled}
@@ -32,8 +32,11 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
private def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] =
buf1 ++= buf2
- private def createExternalMap[T] = new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]](
- createCombiner[T], mergeValue[T], mergeCombiners[T])
+ private def createExternalMap[T] = {
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
+ new ExternalAppendOnlyMap[T, T, ArrayBuffer[T]](
+ createCombiner[T], mergeValue[T], mergeCombiners[T], context = context)
+ }
private def createSparkConf(loadDefaults: Boolean, codec: Option[String] = None): SparkConf = {
val conf = new SparkConf(loadDefaults)
@@ -49,23 +52,27 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
conf
}
- test("simple insert") {
+ test("single insert insert") {
val conf = createSparkConf(loadDefaults = false)
sc = new SparkContext("local", "test", conf)
val map = createExternalMap[Int]
-
- // Single insert
map.insert(1, 10)
- var it = map.iterator
+ val it = map.iterator
assert(it.hasNext)
val kv = it.next()
assert(kv._1 === 1 && kv._2 === ArrayBuffer[Int](10))
assert(!it.hasNext)
+ sc.stop()
+ }
- // Multiple insert
+ test("multiple insert") {
+ val conf = createSparkConf(loadDefaults = false)
+ sc = new SparkContext("local", "test", conf)
+ val map = createExternalMap[Int]
+ map.insert(1, 10)
map.insert(2, 20)
map.insert(3, 30)
- it = map.iterator
+ val it = map.iterator
assert(it.hasNext)
assert(it.toSet === Set[(Int, ArrayBuffer[Int])](
(1, ArrayBuffer[Int](10)),
@@ -144,39 +151,22 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
sc = new SparkContext("local", "test", conf)
val map = createExternalMap[Int]
+ val nullInt = null.asInstanceOf[Int]
map.insert(1, 5)
map.insert(2, 6)
map.insert(3, 7)
- assert(map.size === 3)
- assert(map.iterator.toSet === Set[(Int, Seq[Int])](
- (1, Seq[Int](5)),
- (2, Seq[Int](6)),
- (3, Seq[Int](7))
- ))
-
- // Null keys
- val nullInt = null.asInstanceOf[Int]
+ map.insert(4, nullInt)
map.insert(nullInt, 8)
- assert(map.size === 4)
- assert(map.iterator.toSet === Set[(Int, Seq[Int])](
+ map.insert(nullInt, nullInt)
+ val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.sorted))
+ assert(result === Set[(Int, Seq[Int])](
(1, Seq[Int](5)),
(2, Seq[Int](6)),
(3, Seq[Int](7)),
- (nullInt, Seq[Int](8))
+ (4, Seq[Int](nullInt)),
+ (nullInt, Seq[Int](nullInt, 8))
))
- // Null values
- map.insert(4, nullInt)
- map.insert(nullInt, nullInt)
- assert(map.size === 5)
- val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
- assert(result === Set[(Int, Set[Int])](
- (1, Set[Int](5)),
- (2, Set[Int](6)),
- (3, Set[Int](7)),
- (4, Set[Int](nullInt)),
- (nullInt, Set[Int](nullInt, 8))
- ))
sc.stop()
}
@@ -344,7 +334,9 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
val conf = createSparkConf(loadDefaults = true)
conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
- val map = new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
+ val map =
+ new ExternalAppendOnlyMap[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _, context = context)
// Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
// problems if the map fails to group together the objects with the same code (SPARK-2043).
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index e2cb791771..d7b2d07a40 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.util.collection
+import org.apache.spark.memory.MemoryTestingUtils
+
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
@@ -98,6 +100,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val conf = createSparkConf(loadDefaults = true, kryo = false)
conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i)
def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i
@@ -109,7 +112,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
createCombiner _, mergeValue _, mergeCombiners _)
val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
- Some(agg), None, None, None)
+ context, Some(agg), None, None, None)
val collisionPairs = Seq(
("Aa", "BB"), // 2112
@@ -158,8 +161,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val conf = createSparkConf(loadDefaults = true, kryo = false)
conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
- val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None)
+ val sorter = new ExternalSorter[FixedHashObject, Int, Int](context, Some(agg), None, None, None)
// Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
// problems if the map fails to group together the objects with the same code (SPARK-2043).
val toInsert = for (i <- 1 to 10; j <- 1 to size) yield (FixedHashObject(j, j % 2), 1)
@@ -180,6 +184,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val conf = createSparkConf(loadDefaults = true, kryo = false)
conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
def createCombiner(i: Int): ArrayBuffer[Int] = ArrayBuffer[Int](i)
def mergeValue(buffer: ArrayBuffer[Int], i: Int): ArrayBuffer[Int] = buffer += i
@@ -188,7 +193,8 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
}
val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
- val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None)
+ val sorter =
+ new ExternalSorter[Int, Int, ArrayBuffer[Int]](context, Some(agg), None, None, None)
sorter.insertAll(
(1 to size).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
assert(sorter.numSpills > 0, "sorter did not spill")
@@ -204,6 +210,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val conf = createSparkConf(loadDefaults = true, kryo = false)
conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
def createCombiner(i: String): ArrayBuffer[String] = ArrayBuffer[String](i)
def mergeValue(buffer: ArrayBuffer[String], i: String): ArrayBuffer[String] = buffer += i
@@ -214,7 +221,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
createCombiner, mergeValue, mergeCombiners)
val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
- Some(agg), None, None, None)
+ context, Some(agg), None, None, None)
sorter.insertAll((1 to size).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
(null.asInstanceOf[String], "1"),
@@ -271,31 +278,32 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
private def emptyDataStream(conf: SparkConf) {
conf.set("spark.shuffle.manager", "sort")
sc = new SparkContext("local", "test", conf)
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
val ord = implicitly[Ordering[Int]]
// Both aggregator and ordering
val sorter = new ExternalSorter[Int, Int, Int](
- Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+ context, Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
assert(sorter.iterator.toSeq === Seq())
sorter.stop()
// Only aggregator
val sorter2 = new ExternalSorter[Int, Int, Int](
- Some(agg), Some(new HashPartitioner(3)), None, None)
+ context, Some(agg), Some(new HashPartitioner(3)), None, None)
assert(sorter2.iterator.toSeq === Seq())
sorter2.stop()
// Only ordering
val sorter3 = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(3)), Some(ord), None)
+ context, None, Some(new HashPartitioner(3)), Some(ord), None)
assert(sorter3.iterator.toSeq === Seq())
sorter3.stop()
// Neither aggregator nor ordering
val sorter4 = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(3)), None, None)
+ context, None, Some(new HashPartitioner(3)), None, None)
assert(sorter4.iterator.toSeq === Seq())
sorter4.stop()
}
@@ -303,6 +311,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
private def fewElementsPerPartition(conf: SparkConf) {
conf.set("spark.shuffle.manager", "sort")
sc = new SparkContext("local", "test", conf)
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
val ord = implicitly[Ordering[Int]]
@@ -313,28 +322,28 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
// Both aggregator and ordering
val sorter = new ExternalSorter[Int, Int, Int](
- Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
+ context, Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
sorter.insertAll(elements.iterator)
assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
sorter.stop()
// Only aggregator
val sorter2 = new ExternalSorter[Int, Int, Int](
- Some(agg), Some(new HashPartitioner(7)), None, None)
+ context, Some(agg), Some(new HashPartitioner(7)), None, None)
sorter2.insertAll(elements.iterator)
assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
sorter2.stop()
// Only ordering
val sorter3 = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(7)), Some(ord), None)
+ context, None, Some(new HashPartitioner(7)), Some(ord), None)
sorter3.insertAll(elements.iterator)
assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
sorter3.stop()
// Neither aggregator nor ordering
val sorter4 = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(7)), None, None)
+ context, None, Some(new HashPartitioner(7)), None, None)
sorter4.insertAll(elements.iterator)
assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
sorter4.stop()
@@ -345,12 +354,13 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
conf.set("spark.shuffle.manager", "sort")
conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 2).toString)
sc = new SparkContext("local", "test", conf)
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val ord = implicitly[Ordering[Int]]
val elements = Iterator((1, 1), (5, 5)) ++ (0 until size).iterator.map(x => (2, 2))
val sorter = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(7)), Some(ord), None)
+ context, None, Some(new HashPartitioner(7)), Some(ord), None)
sorter.insertAll(elements)
assert(sorter.numSpills > 0, "sorter did not spill")
val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
@@ -432,8 +442,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val diskBlockManager = sc.env.blockManager.diskBlockManager
val ord = implicitly[Ordering[Int]]
val expectedSize = if (withFailures) size - 1 else size
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val sorter = new ExternalSorter[Int, Int, Int](
- None, Some(new HashPartitioner(3)), Some(ord), None)
+ context, None, Some(new HashPartitioner(3)), Some(ord), None)
if (withFailures) {
intercept[SparkException] {
sorter.insertAll((0 until size).iterator.map { i =>
@@ -501,7 +512,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
None
}
val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None
- val sorter = new ExternalSorter[Int, Int, Int](agg, Some(new HashPartitioner(3)), ord, None)
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
+ val sorter =
+ new ExternalSorter[Int, Int, Int](context, agg, Some(new HashPartitioner(3)), ord, None)
sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) })
if (withSpilling) {
assert(sorter.numSpills > 0, "sorter did not spill")
@@ -538,8 +551,9 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val testData = Array.tabulate(size) { _ => rand.nextInt().toString }
+ val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val sorter1 = new ExternalSorter[String, String, String](
- None, None, Some(wrongOrdering), None)
+ context, None, None, Some(wrongOrdering), None)
val thrown = intercept[IllegalArgumentException] {
sorter1.insertAll(testData.iterator.map(i => (i, i)))
assert(sorter1.numSpills > 0, "sorter did not spill")
@@ -561,7 +575,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
createCombiner, mergeValue, mergeCombiners)
val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]](
- Some(agg), None, None, None)
+ context, Some(agg), None, None, None)
sorter2.insertAll(testData.iterator.map(i => (i, i)))
assert(sorter2.numSpills > 0, "sorter did not spill")
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 7d94e0566f..810c74fd2f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -67,7 +67,6 @@ final class UnsafeExternalRowSorter {
final TaskContext taskContext = TaskContext.get();
sorter = UnsafeExternalSorter.create(
taskContext.taskMemoryManager(),
- sparkEnv.shuffleMemoryManager(),
sparkEnv.blockManager(),
taskContext,
new RowComparator(ordering, schema.length()),
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 09511ff35f..82c645df28 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -22,7 +22,6 @@ import java.io.IOException;
import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.SparkEnv;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -32,7 +31,7 @@ import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
/**
* Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -88,8 +87,6 @@ public final class UnsafeFixedWidthAggregationMap {
* @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
* @param groupingKeySchema the schema of the grouping key, used for row conversion.
* @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
- * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with
- * other tasks.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
* @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
@@ -99,15 +96,14 @@ public final class UnsafeFixedWidthAggregationMap {
StructType aggregationBufferSchema,
StructType groupingKeySchema,
TaskMemoryManager taskMemoryManager,
- ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes,
boolean enablePerfMetrics) {
this.aggregationBufferSchema = aggregationBufferSchema;
this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
this.groupingKeySchema = groupingKeySchema;
- this.map = new BytesToBytesMap(
- taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+ this.map =
+ new BytesToBytesMap(taskMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
this.enablePerfMetrics = enablePerfMetrics;
// Initialize the buffer for aggregation value
@@ -256,7 +252,7 @@ public final class UnsafeFixedWidthAggregationMap {
public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOException {
UnsafeKVExternalSorter sorter = new UnsafeKVExternalSorter(
groupingKeySchema, aggregationBufferSchema,
- SparkEnv.get().blockManager(), map.getShuffleMemoryManager(), map.getPageSizeBytes(), map);
+ SparkEnv.get().blockManager(), map.getPageSizeBytes(), map);
return sorter;
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 9df5780e4f..46301f0042 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -24,7 +24,6 @@ import javax.annotation.Nullable;
import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.TaskContext;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
@@ -34,7 +33,7 @@ import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.collection.unsafe.sort.*;
/**
@@ -50,14 +49,19 @@ public final class UnsafeKVExternalSorter {
private final UnsafeExternalRowSorter.PrefixComputer prefixComputer;
private final UnsafeExternalSorter sorter;
- public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
- BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes)
- throws IOException {
- this(keySchema, valueSchema, blockManager, shuffleMemoryManager, pageSizeBytes, null);
+ public UnsafeKVExternalSorter(
+ StructType keySchema,
+ StructType valueSchema,
+ BlockManager blockManager,
+ long pageSizeBytes) throws IOException {
+ this(keySchema, valueSchema, blockManager, pageSizeBytes, null);
}
- public UnsafeKVExternalSorter(StructType keySchema, StructType valueSchema,
- BlockManager blockManager, ShuffleMemoryManager shuffleMemoryManager, long pageSizeBytes,
+ public UnsafeKVExternalSorter(
+ StructType keySchema,
+ StructType valueSchema,
+ BlockManager blockManager,
+ long pageSizeBytes,
@Nullable BytesToBytesMap map) throws IOException {
this.keySchema = keySchema;
this.valueSchema = valueSchema;
@@ -73,7 +77,6 @@ public final class UnsafeKVExternalSorter {
if (map == null) {
sorter = UnsafeExternalSorter.create(
taskMemoryManager,
- shuffleMemoryManager,
blockManager,
taskContext,
recordComparator,
@@ -115,7 +118,6 @@ public final class UnsafeKVExternalSorter {
sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
taskContext.taskMemoryManager(),
- shuffleMemoryManager,
blockManager,
taskContext,
new KVComparator(ordering, keySchema.length()),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 7cd0f7b81e..fb2fc98e34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.unsafe.KVIterator
-import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext}
+import org.apache.spark.{InternalAccumulator, Logging, TaskContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType
*
* This iterator first uses hash-based aggregation to process input rows. It uses
* a hash map to store groups and their corresponding aggregation buffers. If we
- * this map cannot allocate memory from [[org.apache.spark.shuffle.ShuffleMemoryManager]],
+ * this map cannot allocate memory from memory manager,
* it switches to sort-based aggregation. The process of the switch has the following step:
* - Step 1: Sort all entries of the hash map based on values of grouping expressions and
* spill them to disk.
@@ -480,10 +480,9 @@ class TungstenAggregationIterator(
initialAggregationBuffer,
StructType.fromAttributes(allAggregateFunctions.flatMap(_.aggBufferAttributes)),
StructType.fromAttributes(groupingExpressions.map(_.toAttribute)),
- TaskContext.get.taskMemoryManager(),
- SparkEnv.get.shuffleMemoryManager,
+ TaskContext.get().taskMemoryManager(),
1024 * 16, // initial capacity
- SparkEnv.get.shuffleMemoryManager.pageSizeBytes,
+ TaskContext.get().taskMemoryManager().pageSizeBytes,
false // disable tracking of performance metrics
)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index cfd64c1d9e..1b59b19d94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -344,8 +344,7 @@ private[sql] class DynamicPartitionWriterContainer(
StructType.fromAttributes(partitionColumns),
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
- SparkEnv.get.shuffleMemoryManager,
- SparkEnv.get.shuffleMemoryManager.pageSizeBytes)
+ TaskContext.get().taskMemoryManager().pageSizeBytes)
sorter.insertKV(currentKey, getOutputRow(inputRow))
}
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index bc255b2750..cc8abb1ba4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -21,7 +21,7 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}
-import org.apache.spark.shuffle.ShuffleMemoryManager
+import org.apache.spark.memory.{TaskMemoryManager, StaticMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
@@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.local.LocalNode
import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.unsafe.memory.MemoryLocation
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.CompactBuffer
import org.apache.spark.{SparkConf, SparkEnv}
@@ -320,21 +320,20 @@ private[joins] final class UnsafeHashedRelation(
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val nKeys = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
- val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ // TODO(josh): This needs to be revisited before we merge this patch; making this change now
+ // so that tests compile:
+ val taskMemoryManager = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.unsafe.offHeap", "false"), Long.MaxValue, Long.MaxValue, 1), 0)
- val pageSizeBytes = Option(SparkEnv.get).map(_.shuffleMemoryManager.pageSizeBytes)
+ val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
.getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
- // Dummy shuffle memory manager which always grants all memory allocation requests.
- // We use this because it doesn't make sense count shared broadcast variables' memory usage
- // towards individual tasks' quotas. In the future, we should devise a better way of handling
- // this.
- val shuffleMemoryManager =
- ShuffleMemoryManager.create(maxMemory = Long.MaxValue, pageSizeBytes = pageSizeBytes)
+ // TODO(josh): We won't need this dummy memory manager after future refactorings; revisit
+ // during code review
binaryMap = new BytesToBytesMap(
taskMemoryManager,
- shuffleMemoryManager,
(nKeys * 1.5 + 1).toInt, // reduce hash collision
pageSizeBytes)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
index 9385e5734d..dd92dda480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala
@@ -49,7 +49,8 @@ case class Sort(
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
child.execute().mapPartitions( { iterator =>
val ordering = newOrdering(sortOrder, child.output)
- val sorter = new ExternalSorter[InternalRow, Null, InternalRow](ordering = Some(ordering))
+ val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
+ TaskContext.get(), ordering = Some(ordering))
sorter.insertAll(iterator.map(r => (r.copy(), null)))
val baseIterator = sorter.iterator.map(_._1)
val context = TaskContext.get()
@@ -124,7 +125,7 @@ case class TungstenSort(
}
}
- val pageSize = SparkEnv.get.shuffleMemoryManager.pageSizeBytes
+ val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val sorter = new UnsafeExternalRowSorter(
schema, ordering, prefixComparator, prefixComputer, pageSize)
if (testSpillFrequency > 0) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index 1739798a24..dbf4863b76 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -23,13 +23,12 @@ import scala.util.{Try, Random}
import org.scalatest.Matchers
-import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
-import org.apache.spark.shuffle.ShuffleMemoryManager
+import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite}
+import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -49,23 +48,22 @@ class UnsafeFixedWidthAggregationMapSuite
private def emptyAggregationBuffer: InternalRow = InternalRow(0)
private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
+ private var memoryManager: GrantEverythingMemoryManager = null
private var taskMemoryManager: TaskMemoryManager = null
- private var shuffleMemoryManager: TestShuffleMemoryManager = null
def testWithMemoryLeakDetection(name: String)(f: => Unit) {
def cleanup(): Unit = {
if (taskMemoryManager != null) {
- val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask()
assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
- assert(leakedShuffleMemory === 0)
taskMemoryManager = null
}
TaskContext.unset()
}
test(name) {
- taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
- shuffleMemoryManager = new TestShuffleMemoryManager
+ val conf = new SparkConf().set("spark.unsafe.offHeap", "false")
+ memoryManager = new GrantEverythingMemoryManager(conf)
+ taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
@@ -110,7 +108,6 @@ class UnsafeFixedWidthAggregationMapSuite
aggBufferSchema,
groupKeySchema,
taskMemoryManager,
- shuffleMemoryManager,
1024, // initial capacity,
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -125,7 +122,6 @@ class UnsafeFixedWidthAggregationMapSuite
aggBufferSchema,
groupKeySchema,
taskMemoryManager,
- shuffleMemoryManager,
1024, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -153,7 +149,6 @@ class UnsafeFixedWidthAggregationMapSuite
aggBufferSchema,
groupKeySchema,
taskMemoryManager,
- shuffleMemoryManager,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -176,14 +171,13 @@ class UnsafeFixedWidthAggregationMapSuite
testWithMemoryLeakDetection("test external sorting") {
// Memory consumption in the beginning of the task.
- val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+ val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask()
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
taskMemoryManager,
- shuffleMemoryManager,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -200,7 +194,7 @@ class UnsafeFixedWidthAggregationMapSuite
val sorter = map.destructAndCreateExternalSorter()
withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
- assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
+ assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
}
// Add more keys to the sorter and make sure the results come out sorted.
@@ -214,7 +208,7 @@ class UnsafeFixedWidthAggregationMapSuite
sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
if ((i % 100) == 0) {
- shuffleMemoryManager.markAsOutOfMemory()
+ memoryManager.markExecutionAsOutOfMemory()
sorter.closeCurrentPage()
}
}
@@ -238,7 +232,6 @@ class UnsafeFixedWidthAggregationMapSuite
aggBufferSchema,
groupKeySchema,
taskMemoryManager,
- shuffleMemoryManager,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -258,7 +251,7 @@ class UnsafeFixedWidthAggregationMapSuite
sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
if ((i % 100) == 0) {
- shuffleMemoryManager.markAsOutOfMemory()
+ memoryManager.markExecutionAsOutOfMemory()
sorter.closeCurrentPage()
}
}
@@ -281,14 +274,13 @@ class UnsafeFixedWidthAggregationMapSuite
testWithMemoryLeakDetection("test external sorting with empty records") {
// Memory consumption in the beginning of the task.
- val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
+ val initialMemoryConsumption = taskMemoryManager.getMemoryConsumptionForThisTask()
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
StructType(Nil),
StructType(Nil),
taskMemoryManager,
- shuffleMemoryManager,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -303,7 +295,7 @@ class UnsafeFixedWidthAggregationMapSuite
val sorter = map.destructAndCreateExternalSorter()
withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
- assert(shuffleMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
+ assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
}
// Add more keys to the sorter and make sure the results come out sorted.
@@ -311,7 +303,7 @@ class UnsafeFixedWidthAggregationMapSuite
sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0))
if ((i % 100) == 0) {
- shuffleMemoryManager.markAsOutOfMemory()
+ memoryManager.markExecutionAsOutOfMemory()
sorter.closeCurrentPage()
}
}
@@ -332,34 +324,28 @@ class UnsafeFixedWidthAggregationMapSuite
}
testWithMemoryLeakDetection("convert to external sorter under memory pressure (SPARK-10474)") {
- val smm = ShuffleMemoryManager.createForTesting(65536)
val pageSize = 4096
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
taskMemoryManager,
- smm,
128, // initial capacity
pageSize,
false // disable perf metrics
)
- // Insert into the map until we've run out of space
val rand = new Random(42)
- var hasSpace = true
- while (hasSpace) {
+ for (i <- 1 to 100) {
val str = rand.nextString(1024)
val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
- if (buf == null) {
- hasSpace = false
- } else {
- buf.setInt(0, str.length)
- }
+ buf.setInt(0, str.length)
}
-
- // Ensure we're actually maxed out by asserting that we can't acquire even just 1 byte
- assert(smm.tryToAcquire(1) === 0)
+ // Simulate running out of space
+ memoryManager.markExecutionAsOutOfMemory()
+ val str = rand.nextString(1024)
+ val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
+ assert(buf == null)
// Convert the map into a sorter. This used to fail before the fix for SPARK-10474
// because we would try to acquire space for the in-memory sorter pointer array before
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index d3be568a87..13dc1754c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -20,12 +20,12 @@ package org.apache.spark.sql.execution
import scala.util.Random
import org.apache.spark._
+import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
/**
* Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
@@ -108,9 +108,9 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
inputData: Seq[(InternalRow, InternalRow)],
pageSize: Long,
spill: Boolean): Unit = {
-
- val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
- val shuffleMemMgr = new TestShuffleMemoryManager
+ val memoryManager =
+ new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"))
+ val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
partitionId = 0,
@@ -121,14 +121,14 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
internalAccumulators = Seq.empty))
val sorter = new UnsafeKVExternalSorter(
- keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, pageSize)
+ keySchema, valueSchema, SparkEnv.get.blockManager, pageSize)
// Insert the keys and values into the sorter
inputData.foreach { case (k, v) =>
sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
// 1% chance we will spill
if (rand.nextDouble() < 0.01 && spill) {
- shuffleMemMgr.markAsOutOfMemory()
+ memoryManager.markExecutionAsOutOfMemory()
sorter.closeCurrentPage()
}
}
@@ -170,12 +170,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
assert(out.sorted(kvOrdering) === inputData.sorted(kvOrdering))
// Make sure there is no memory leak
- val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory
- if (shuffleMemMgr != null) {
- val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask()
- assert(0L === leakedShuffleMemory)
- }
- assert(0 === leakedUnsafeMemory)
+ assert(0 === taskMemMgr.cleanUpAllAllocatedMemory)
TaskContext.unset()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 1680d7e0a8..d32572b54b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter
@@ -112,7 +113,12 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
val data = (1 to 10000).iterator.map { i =>
(i, converter(Row(i)))
}
+ val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0)
+ val taskContext = new TaskContextImpl(
+ 0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc))
+
val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+ taskContext,
partitioner = Some(new HashPartitioner(10)),
serializer = Some(new UnsafeRowSerializer(numFields = 1)))
@@ -122,10 +128,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
assert(sorter.numSpills > 0)
// Merging spilled files should not throw assertion error
- val taskContext =
- new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
- sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+ sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), outputFile)
} {
// Clean up
if (sc != null) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index cc0ac1b07c..475037bd45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -18,16 +18,16 @@
package org.apache.spark.sql.execution.aggregate
import org.apache.spark._
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.unsafe.memory.TaskMemoryManager
class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
test("memory acquired on construction") {
- val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
+ val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.memoryManager, 0)
val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
TaskContext.setTaskContext(taskContext)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index b2b6848719..c17fb72381 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -254,7 +254,7 @@ class ReceivedBlockHandlerSuite
maxMem: Long,
conf: SparkConf,
name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
- val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem)
+ val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializer, conf,
memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java
deleted file mode 100644
index cbbe859462..0000000000
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/ExecutorMemoryManager.java
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.unsafe.memory;
-
-import java.lang.ref.WeakReference;
-import java.util.HashMap;
-import java.util.LinkedList;
-import java.util.Map;
-import javax.annotation.concurrent.GuardedBy;
-
-/**
- * Manages memory for an executor. Individual operators / tasks allocate memory through
- * {@link TaskMemoryManager} objects, which obtain their memory from ExecutorMemoryManager.
- */
-public class ExecutorMemoryManager {
-
- /**
- * Allocator, exposed for enabling untracked allocations of temporary data structures.
- */
- public final MemoryAllocator allocator;
-
- /**
- * Tracks whether memory will be allocated on the JVM heap or off-heap using sun.misc.Unsafe.
- */
- final boolean inHeap;
-
- @GuardedBy("this")
- private final Map<Long, LinkedList<WeakReference<MemoryBlock>>> bufferPoolsBySize =
- new HashMap<Long, LinkedList<WeakReference<MemoryBlock>>>();
-
- private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024;
-
- /**
- * Construct a new ExecutorMemoryManager.
- *
- * @param allocator the allocator that will be used
- */
- public ExecutorMemoryManager(MemoryAllocator allocator) {
- this.inHeap = allocator instanceof HeapMemoryAllocator;
- this.allocator = allocator;
- }
-
- /**
- * Returns true if allocations of the given size should go through the pooling mechanism and
- * false otherwise.
- */
- private boolean shouldPool(long size) {
- // Very small allocations are less likely to benefit from pooling.
- // At some point, we should explore supporting pooling for off-heap memory, but for now we'll
- // ignore that case in the interest of simplicity.
- return size >= POOLING_THRESHOLD_BYTES && allocator instanceof HeapMemoryAllocator;
- }
-
- /**
- * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
- * to be zeroed out (call `zero()` on the result if this is necessary).
- */
- MemoryBlock allocate(long size) throws OutOfMemoryError {
- if (shouldPool(size)) {
- synchronized (this) {
- final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
- if (pool != null) {
- while (!pool.isEmpty()) {
- final WeakReference<MemoryBlock> blockReference = pool.pop();
- final MemoryBlock memory = blockReference.get();
- if (memory != null) {
- assert (memory.size() == size);
- return memory;
- }
- }
- bufferPoolsBySize.remove(size);
- }
- }
- return allocator.allocate(size);
- } else {
- return allocator.allocate(size);
- }
- }
-
- void free(MemoryBlock memory) {
- final long size = memory.size();
- if (shouldPool(size)) {
- synchronized (this) {
- LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
- if (pool == null) {
- pool = new LinkedList<WeakReference<MemoryBlock>>();
- bufferPoolsBySize.put(size, pool);
- }
- pool.add(new WeakReference<MemoryBlock>(memory));
- }
- } else {
- allocator.free(memory);
- }
- }
-
-}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index 6722301df1..ebe90d9e63 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -17,22 +17,71 @@
package org.apache.spark.unsafe.memory;
+import javax.annotation.concurrent.GuardedBy;
+import java.lang.ref.WeakReference;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.Map;
+
/**
* A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array.
*/
public class HeapMemoryAllocator implements MemoryAllocator {
+ @GuardedBy("this")
+ private final Map<Long, LinkedList<WeakReference<MemoryBlock>>> bufferPoolsBySize =
+ new HashMap<>();
+
+ private static final int POOLING_THRESHOLD_BYTES = 1024 * 1024;
+
+ /**
+ * Returns true if allocations of the given size should go through the pooling mechanism and
+ * false otherwise.
+ */
+ private boolean shouldPool(long size) {
+ // Very small allocations are less likely to benefit from pooling.
+ return size >= POOLING_THRESHOLD_BYTES;
+ }
+
@Override
public MemoryBlock allocate(long size) throws OutOfMemoryError {
if (size % 8 != 0) {
throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
}
+ if (shouldPool(size)) {
+ synchronized (this) {
+ final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
+ if (pool != null) {
+ while (!pool.isEmpty()) {
+ final WeakReference<MemoryBlock> blockReference = pool.pop();
+ final MemoryBlock memory = blockReference.get();
+ if (memory != null) {
+ assert (memory.size() == size);
+ return memory;
+ }
+ }
+ bufferPoolsBySize.remove(size);
+ }
+ }
+ }
long[] array = new long[(int) (size / 8)];
return MemoryBlock.fromLongArray(array);
}
@Override
public void free(MemoryBlock memory) {
- // Do nothing
+ final long size = memory.size();
+ if (shouldPool(size)) {
+ synchronized (this) {
+ LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
+ if (pool == null) {
+ pool = new LinkedList<>();
+ bufferPoolsBySize.put(size, pool);
+ }
+ pool.add(new WeakReference<>(memory));
+ }
+ } else {
+ // Do nothing
+ }
}
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index dd75820834..e3e7947115 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -30,9 +30,10 @@ public class MemoryBlock extends MemoryLocation {
/**
* Optional page number; used when this MemoryBlock represents a page allocated by a
- * MemoryManager. This is package-private and is modified by MemoryManager.
+ * TaskMemoryManager. This field is public so that it can be modified by the TaskMemoryManager,
+ * which lives in a different package.
*/
- int pageNumber = -1;
+ public int pageNumber = -1;
public MemoryBlock(@Nullable Object obj, long offset, long length) {
super(obj, offset);