-rw-r--r--  core/src/main/java/org/apache/spark/memory/MemoryConsumer.java                                   |  128
-rw-r--r--  core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java                                |  138
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java                      |  210
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java                      |   50
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java                        |    6
-rw-r--r--  core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java                              |  430
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java        |  426
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java        |   60
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java     |    6
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java     |    2
-rw-r--r--  core/src/main/scala/org/apache/spark/memory/MemoryManager.scala                                  |    9
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/Spillable.scala                             |    4
-rw-r--r--  core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java                           |   77
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java                   |   30
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java                 |    6
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java                   |   38
-rw-r--r--  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java                 |  149
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java   |   97
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java   |   20
-rw-r--r--  core/src/test/scala/org/apache/spark/FailureSuite.scala                                          |    4
-rw-r--r--  core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala                             |   60
-rw-r--r--  core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala (renamed from core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala)  |   32
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java            |    7
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java         |    2
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java                 |   19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala  |   22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala          |    6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala  |   54
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java                      |    9
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java                    |    3
30 files changed, 1270 insertions(+), 834 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
new file mode 100644
index 0000000000..008799cc77
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+
+import java.io.IOException;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+
+/**
+ * A memory consumer of TaskMemoryManager which supports spilling.
+ */
+public abstract class MemoryConsumer {
+
+ private final TaskMemoryManager taskMemoryManager;
+ private final long pageSize;
+ private long used;
+
+ protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
+ this.taskMemoryManager = taskMemoryManager;
+ this.pageSize = pageSize;
+ this.used = 0;
+ }
+
+ protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
+ this(taskMemoryManager, taskMemoryManager.pageSizeBytes());
+ }
+
+ /**
+ * Returns the size of used memory in bytes.
+ */
+ long getUsed() {
+ return used;
+ }
+
+ /**
+ * Force a spill during building.
+ *
+ * For testing only.
+ */
+ public void spill() throws IOException {
+ spill(Long.MAX_VALUE, this);
+ }
+
+ /**
+ * Spill some data to disk to release memory. This will be called by TaskMemoryManager
+ * when there is not enough memory for the task.
+ *
+ * This should be implemented by subclasses.
+ *
+ * Note: to avoid possible deadlocks, spill() should not call acquireMemory().
+ *
+ * @param size the amount of memory that should be released
+ * @param trigger the MemoryConsumer that triggered this spilling
+ * @return the amount of released memory in bytes
+ * @throws IOException
+ */
+ public abstract long spill(long size, MemoryConsumer trigger) throws IOException;
+
+ /**
+ * Acquire `size` bytes of memory.
+ *
+ * If there is not enough memory, this throws an OutOfMemoryError.
+ */
+ protected void acquireMemory(long size) {
+ long got = taskMemoryManager.acquireExecutionMemory(size, this);
+ if (got < size) {
+ taskMemoryManager.releaseExecutionMemory(got, this);
+ taskMemoryManager.showMemoryUsage();
+ throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got);
+ }
+ used += got;
+ }
+
+ /**
+ * Release `size` bytes of memory.
+ */
+ protected void releaseMemory(long size) {
+ used -= size;
+ taskMemoryManager.releaseExecutionMemory(size, this);
+ }
+
+ /**
+ * Allocate a memory block with at least `required` bytes.
+ *
+ * Throws an OutOfMemoryError if there is not enough memory.
+ *
+ * @throws OutOfMemoryError
+ */
+ protected MemoryBlock allocatePage(long required) {
+ MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this);
+ if (page == null || page.size() < required) {
+ long got = 0;
+ if (page != null) {
+ got = page.size();
+ freePage(page);
+ }
+ taskMemoryManager.showMemoryUsage();
+ throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
+ }
+ used += page.size();
+ return page;
+ }
+
+ /**
+ * Free a memory block.
+ */
+ protected void freePage(MemoryBlock page) {
+ used -= page.size();
+ taskMemoryManager.freePage(page, this);
+ }
+}
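The file above defines the whole contract a spillable consumer must follow: acquire memory or pages through the protected helpers, and release memory when the manager calls spill(). For illustration only (not part of this patch), a subclass might look like the sketch below; the class ExampleBufferConsumer and its field are hypothetical, while MemoryConsumer, allocatePage() and freePage() are the methods added above.

```java
import java.io.IOException;

import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;

// Hypothetical consumer, shown only to illustrate the new contract.
public final class ExampleBufferConsumer extends MemoryConsumer {

  private MemoryBlock page = null;  // single data page held by this consumer

  public ExampleBufferConsumer(TaskMemoryManager manager) {
    super(manager);  // uses the manager's default page size
  }

  /** Grab one page; allocatePage() may force other consumers (or this one) to spill. */
  public void reserve(long bytes) {
    if (page == null) {
      page = allocatePage(bytes);  // throws OutOfMemoryError if the request cannot be met
    }
  }

  @Override
  public long spill(long size, MemoryConsumer trigger) throws IOException {
    // Called by TaskMemoryManager when some consumer needs memory. A real implementation
    // would persist the page's contents to disk before freeing it.
    if (page == null) {
      return 0L;
    }
    long released = page.size();
    freePage(page);
    page = null;
    return released;
  }
}
```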
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 7b31c90dac..4230575446 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -17,13 +17,18 @@
package org.apache.spark.memory;
-import java.util.*;
+import javax.annotation.concurrent.GuardedBy;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.BitSet;
+import java.util.HashSet;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.util.Utils;
/**
* Manages the memory allocated by an individual task.
@@ -101,29 +106,104 @@ public class TaskMemoryManager {
private final boolean inHeap;
/**
+ * The set of memory consumers that have acquired memory for this task.
+ */
+ @GuardedBy("this")
+ private final HashSet<MemoryConsumer> consumers;
+
+ /**
* Construct a new TaskMemoryManager.
*/
public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap();
this.memoryManager = memoryManager;
this.taskAttemptId = taskAttemptId;
+ this.consumers = new HashSet<>();
}
/**
- * Acquire N bytes of memory for execution, evicting cached blocks if necessary.
+ * Acquire N bytes of memory for a consumer. If there is not enough memory, it will call
+ * spill() on other consumers to release more memory.
+ *
* @return number of bytes successfully granted (<= N).
*/
- public long acquireExecutionMemory(long size) {
- return memoryManager.acquireExecutionMemory(size, taskAttemptId);
+ public long acquireExecutionMemory(long required, MemoryConsumer consumer) {
+ assert(required >= 0);
+ synchronized (this) {
+ long got = memoryManager.acquireExecutionMemory(required, taskAttemptId);
+
+ // Try to release memory from other consumers first, so that we can reduce the frequency of
+ // spilling and avoid having too many spilled files.
+ if (got < required) {
+ // Call spill() on other consumers to release memory
+ for (MemoryConsumer c: consumers) {
+ if (c != null && c != consumer && c.getUsed() > 0) {
+ try {
+ long released = c.spill(required - got, consumer);
+ if (released > 0) {
+ logger.info("Task {} released {} from {} for {}", taskAttemptId,
+ Utils.bytesToString(released), c, consumer);
+ got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId);
+ if (got >= required) {
+ break;
+ }
+ }
+ } catch (IOException e) {
+ logger.error("error while calling spill() on " + c, e);
+ throw new OutOfMemoryError("error while calling spill() on " + c + " : "
+ + e.getMessage());
+ }
+ }
+ }
+ }
+
+ // call spill() on itself
+ if (got < required && consumer != null) {
+ try {
+ long released = consumer.spill(required - got, consumer);
+ if (released > 0) {
+ logger.info("Task {} released {} from itself ({})", taskAttemptId,
+ Utils.bytesToString(released), consumer);
+ got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId);
+ }
+ } catch (IOException e) {
+ logger.error("error while calling spill() on " + consumer, e);
+ throw new OutOfMemoryError("error while calling spill() on " + consumer + " : "
+ + e.getMessage());
+ }
+ }
+
+ consumers.add(consumer);
+ logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
+ return got;
+ }
}
/**
- * Release N bytes of execution memory.
+ * Release N bytes of execution memory for a MemoryConsumer.
*/
- public void releaseExecutionMemory(long size) {
+ public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
+ logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer);
memoryManager.releaseExecutionMemory(size, taskAttemptId);
}
+ /**
+ * Dump the memory usage of all consumers.
+ */
+ public void showMemoryUsage() {
+ logger.info("Memory used in task " + taskAttemptId);
+ synchronized (this) {
+ for (MemoryConsumer c: consumers) {
+ if (c.getUsed() > 0) {
+ logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed()));
+ }
+ }
+ }
+ }
+
+ /**
+ * Return the page size in bytes.
+ */
public long pageSizeBytes() {
return memoryManager.pageSizeBytes();
}
@@ -134,42 +214,40 @@ public class TaskMemoryManager {
*
* Returns `null` if there was not enough memory to allocate the page.
*/
- public MemoryBlock allocatePage(long size) {
+ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) {
if (size > MAXIMUM_PAGE_SIZE_BYTES) {
throw new IllegalArgumentException(
"Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
}
+ long acquired = acquireExecutionMemory(size, consumer);
+ if (acquired <= 0) {
+ return null;
+ }
+
final int pageNumber;
synchronized (this) {
pageNumber = allocatedPages.nextClearBit(0);
if (pageNumber >= PAGE_TABLE_SIZE) {
+ releaseExecutionMemory(acquired, consumer);
throw new IllegalStateException(
"Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
}
allocatedPages.set(pageNumber);
}
- final long acquiredExecutionMemory = acquireExecutionMemory(size);
- if (acquiredExecutionMemory != size) {
- releaseExecutionMemory(acquiredExecutionMemory);
- synchronized (this) {
- allocatedPages.clear(pageNumber);
- }
- return null;
- }
- final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size);
+ final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(acquired);
page.pageNumber = pageNumber;
pageTable[pageNumber] = page;
if (logger.isTraceEnabled()) {
- logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
+ logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired);
}
return page;
}
/**
- * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
+ * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
*/
- public void freePage(MemoryBlock page) {
+ public void freePage(MemoryBlock page, MemoryConsumer consumer) {
assert (page.pageNumber != -1) :
"Called freePage() on memory that wasn't allocated with allocatePage()";
assert(allocatedPages.get(page.pageNumber));
@@ -182,14 +260,14 @@ public class TaskMemoryManager {
}
long pageSize = page.size();
memoryManager.tungstenMemoryAllocator().free(page);
- releaseExecutionMemory(pageSize);
+ releaseExecutionMemory(pageSize, consumer);
}
/**
* Given a memory page and offset within that page, encode this address into a 64-bit long.
* This address will remain valid as long as the corresponding page has not been freed.
*
- * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/
+ * @param page a data page allocated by {@link TaskMemoryManager#allocatePage}
* @param offsetInPage an offset in this page which incorporates the base offset. In other words,
* this should be the value that you would pass as the base offset into an
* UNSAFE call (e.g. page.baseOffset() + something).
@@ -261,17 +339,17 @@ public class TaskMemoryManager {
* value can be used to detect memory leaks.
*/
public long cleanUpAllAllocatedMemory() {
- long freedBytes = 0;
- for (MemoryBlock page : pageTable) {
- if (page != null) {
- freedBytes += page.size();
- freePage(page);
+ synchronized (this) {
+ Arrays.fill(pageTable, null);
+ for (MemoryConsumer c: consumers) {
+ if (c != null && c.getUsed() > 0) {
+ // In case of failed task, it's normal to see leaked memory
+ logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c);
+ }
}
+ consumers.clear();
}
-
- freedBytes += memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
-
- return freedBytes;
+ return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId);
}
/**
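To make the new acquisition path easier to follow, here is a hedged sketch of the call pattern from a consumer's point of view; `tmm`, `a` and `b` are placeholders for a real TaskMemoryManager and two registered consumers, and the comments describe the behavior implemented in acquireExecutionMemory() above.

```java
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;

// Illustrative only: `a` and `b` stand for any two consumers sharing the same manager.
final class AcquisitionFlowSketch {
  static void demo(TaskMemoryManager tmm, MemoryConsumer a, MemoryConsumer b) {
    // 1. `a` grabs a large chunk of the task's execution memory.
    long forA = tmm.acquireExecutionMemory(64L * 1024 * 1024, a);

    // 2. `b` asks for more than what is left. Inside acquireExecutionMemory() the manager
    //    first grants whatever the MemoryManager still has, then calls a.spill(needed, b)
    //    on the other consumers, and finally b.spill(needed, b) on the requester itself.
    long forB = tmm.acquireExecutionMemory(64L * 1024 * 1024, b);

    // 3. Each consumer must release exactly the amount it was granted.
    tmm.releaseExecutionMemory(forB, b);
    tmm.releaseExecutionMemory(forA, a);
  }
}
```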
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index f43236f41a..400d852001 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -31,15 +31,15 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
/**
@@ -58,23 +58,18 @@ import org.apache.spark.util.Utils;
* spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
* specialized merge procedure that avoids extra serialization/deserialization.
*/
-final class ShuffleExternalSorter {
+final class ShuffleExternalSorter extends MemoryConsumer {
private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
@VisibleForTesting
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
- private final int initialSize;
private final int numPartitions;
- private final int pageSizeBytes;
- @VisibleForTesting
- final int maxRecordSizeBytes;
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
- private long numRecordsInsertedSinceLastSpill = 0;
/** Force this sorter to spill when there are this many elements in memory. For testing only */
private final long numElementsForSpillThreshold;
@@ -98,8 +93,7 @@ final class ShuffleExternalSorter {
// These variables are reset after spilling:
@Nullable private ShuffleInMemorySorter inMemSorter;
@Nullable private MemoryBlock currentPage = null;
- private long currentPagePosition = -1;
- private long freeSpaceInCurrentPage = 0;
+ private long pageCursor = -1;
public ShuffleExternalSorter(
TaskMemoryManager memoryManager,
@@ -108,42 +102,21 @@ final class ShuffleExternalSorter {
int initialSize,
int numPartitions,
SparkConf conf,
- ShuffleWriteMetrics writeMetrics) throws IOException {
+ ShuffleWriteMetrics writeMetrics) {
+ super(memoryManager, (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES,
+ memoryManager.pageSizeBytes()));
this.taskMemoryManager = memoryManager;
this.blockManager = blockManager;
this.taskContext = taskContext;
- this.initialSize = initialSize;
- this.peakMemoryUsedBytes = initialSize;
this.numPartitions = numPartitions;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
- this.pageSizeBytes = (int) Math.min(
- PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, taskMemoryManager.pageSizeBytes());
- this.maxRecordSizeBytes = pageSizeBytes - 4;
this.writeMetrics = writeMetrics;
- initializeForWriting();
-
- // preserve first page to ensure that we have at least one page to work with. Otherwise,
- // other operators in the same task may starve this sorter (SPARK-9709).
- acquireNewPageIfNecessary(pageSizeBytes);
- }
-
- /**
- * Allocates new sort data structures. Called when creating the sorter and after each spill.
- */
- private void initializeForWriting() throws IOException {
- // TODO: move this sizing calculation logic into a static method of sorter:
- final long memoryRequested = initialSize * 8L;
- final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryRequested);
- if (memoryAcquired != memoryRequested) {
- taskMemoryManager.releaseExecutionMemory(memoryAcquired);
- throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
- }
-
+ acquireMemory(initialSize * 8L);
this.inMemSorter = new ShuffleInMemorySorter(initialSize);
- numRecordsInsertedSinceLastSpill = 0;
+ this.peakMemoryUsedBytes = getMemoryUsage();
}
/**
@@ -242,6 +215,8 @@ final class ShuffleExternalSorter {
}
}
+ inMemSorter.reset();
+
if (!isLastFile) { // i.e. this is a spill file
// The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
// are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
@@ -266,9 +241,12 @@ final class ShuffleExternalSorter {
/**
* Sort and spill the current records in response to memory pressure.
*/
- @VisibleForTesting
- void spill() throws IOException {
- assert(inMemSorter != null);
+ @Override
+ public long spill(long size, MemoryConsumer trigger) throws IOException {
+ if (trigger != this || inMemSorter == null || inMemSorter.numRecords() == 0) {
+ return 0L;
+ }
+
logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
Utils.bytesToString(getMemoryUsage()),
@@ -276,13 +254,9 @@ final class ShuffleExternalSorter {
spills.size() > 1 ? " times" : " time");
writeSortedFile(false);
- final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
- inMemSorter = null;
- taskMemoryManager.releaseExecutionMemory(inMemSorterMemoryUsage);
final long spillSize = freeMemory();
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
-
- initializeForWriting();
+ return spillSize;
}
private long getMemoryUsage() {
@@ -312,18 +286,12 @@ final class ShuffleExternalSorter {
updatePeakMemoryUsed();
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
- taskMemoryManager.freePage(block);
memoryFreed += block.size();
- }
- if (inMemSorter != null) {
- long sorterMemoryUsage = inMemSorter.getMemoryUsage();
- inMemSorter = null;
- taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
+ freePage(block);
}
allocatedPages.clear();
currentPage = null;
- currentPagePosition = -1;
- freeSpaceInCurrentPage = 0;
+ pageCursor = 0;
return memoryFreed;
}
@@ -332,16 +300,16 @@ final class ShuffleExternalSorter {
*/
public void cleanupResources() {
freeMemory();
+ if (inMemSorter != null) {
+ long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+ inMemSorter = null;
+ releaseMemory(sorterMemoryUsage);
+ }
for (SpillInfo spill : spills) {
if (spill.file.exists() && !spill.file.delete()) {
logger.error("Unable to delete spill file {}", spill.file.getPath());
}
}
- if (inMemSorter != null) {
- long sorterMemoryUsage = inMemSorter.getMemoryUsage();
- inMemSorter = null;
- taskMemoryManager.releaseExecutionMemory(sorterMemoryUsage);
- }
}
/**
@@ -352,16 +320,27 @@ final class ShuffleExternalSorter {
private void growPointerArrayIfNecessary() throws IOException {
assert(inMemSorter != null);
if (!inMemSorter.hasSpaceForAnotherRecord()) {
- logger.debug("Attempting to expand sort pointer array");
- final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
- final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
- final long memoryAcquired = taskMemoryManager.acquireExecutionMemory(memoryToGrowPointerArray);
- if (memoryAcquired < memoryToGrowPointerArray) {
- taskMemoryManager.releaseExecutionMemory(memoryAcquired);
- spill();
+ long used = inMemSorter.getMemoryUsage();
+ long needed = used + inMemSorter.getMemoryToExpand();
+ try {
+ acquireMemory(needed); // could trigger spilling
+ } catch (OutOfMemoryError e) {
+ // spilling should have been triggered
+ assert(inMemSorter.hasSpaceForAnotherRecord());
+ return;
+ }
+ // check if spilling is triggered or not
+ if (inMemSorter.hasSpaceForAnotherRecord()) {
+ releaseMemory(needed);
} else {
- inMemSorter.expandPointerArray();
- taskMemoryManager.releaseExecutionMemory(oldPointerArrayMemoryUsage);
+ try {
+ inMemSorter.expandPointerArray();
+ releaseMemory(used);
+ } catch (OutOfMemoryError oom) {
+ // Just in case the JVM ran out of memory
+ releaseMemory(needed);
+ spill();
+ }
}
}
}
@@ -370,96 +349,46 @@ final class ShuffleExternalSorter {
* Allocates more memory in order to insert an additional record. This will request additional
* memory from the memory manager and spill if the requested memory can not be obtained.
*
- * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * @param required the required space in the data page, in bytes, including space for storing
* the record size. This must be less than or equal to the page size (records
* that exceed the page size are handled via a different code path which uses
* special overflow pages).
*/
- private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
- growPointerArrayIfNecessary();
- if (requiredSpace > freeSpaceInCurrentPage) {
- logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
- freeSpaceInCurrentPage);
- // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
- // without using the free space at the end of the current page. We should also do this for
- // BytesToBytesMap.
- if (requiredSpace > pageSizeBytes) {
- throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
- pageSizeBytes + ")");
- } else {
- currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
- if (currentPage == null) {
- spill();
- currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
- if (currentPage == null) {
- throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
- }
- }
- currentPagePosition = currentPage.getBaseOffset();
- freeSpaceInCurrentPage = pageSizeBytes;
- allocatedPages.add(currentPage);
- }
+ private void acquireNewPageIfNecessary(int required) {
+ if (currentPage == null ||
+ pageCursor + required > currentPage.getBaseOffset() + currentPage.size() ) {
+ // TODO: try to find space in previous pages
+ currentPage = allocatePage(required);
+ pageCursor = currentPage.getBaseOffset();
+ allocatedPages.add(currentPage);
}
}
/**
* Write a record to the shuffle sorter.
*/
- public void insertRecord(
- Object recordBaseObject,
- long recordBaseOffset,
- int lengthInBytes,
- int partitionId) throws IOException {
+ public void insertRecord(Object recordBase, long recordOffset, int length, int partitionId)
+ throws IOException {
- if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
+ // for tests
+ assert(inMemSorter != null);
+ if (inMemSorter.numRecords() > numElementsForSpillThreshold) {
spill();
}
growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
- final int totalSpaceRequired = lengthInBytes + 4;
-
- // --- Figure out where to insert the new record ----------------------------------------------
-
- final MemoryBlock dataPage;
- long dataPagePosition;
- boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
- if (useOverflowPage) {
- long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
- // The record is larger than the page size, so allocate a special overflow page just to hold
- // that record.
- MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
- if (overflowPage == null) {
- spill();
- overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
- if (overflowPage == null) {
- throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
- }
- }
- allocatedPages.add(overflowPage);
- dataPage = overflowPage;
- dataPagePosition = overflowPage.getBaseOffset();
- } else {
- // The record is small enough to fit in a regular data page, but the current page might not
- // have enough space to hold it (or no pages have been allocated yet).
- acquireNewPageIfNecessary(totalSpaceRequired);
- dataPage = currentPage;
- dataPagePosition = currentPagePosition;
- // Update bookkeeping information
- freeSpaceInCurrentPage -= totalSpaceRequired;
- currentPagePosition += totalSpaceRequired;
- }
- final Object dataPageBaseObject = dataPage.getBaseObject();
-
- final long recordAddress =
- taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
- Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
- dataPagePosition += 4;
- Platform.copyMemory(
- recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
- assert(inMemSorter != null);
+ final int required = length + 4;
+ acquireNewPageIfNecessary(required);
+
+ assert(currentPage != null);
+ final Object base = currentPage.getBaseObject();
+ final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+ Platform.putInt(base, pageCursor, length);
+ pageCursor += 4;
+ Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
+ pageCursor += length;
inMemSorter.insertRecord(recordAddress, partitionId);
- numRecordsInsertedSinceLastSpill += 1;
}
/**
@@ -475,6 +404,9 @@ final class ShuffleExternalSorter {
// Do not count the final file towards the spill count.
writeSortedFile(true);
freeMemory();
+ long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+ inMemSorter = null;
+ releaseMemory(sorterMemoryUsage);
}
return spills.toArray(new SpillInfo[spills.size()]);
} catch (IOException e) {
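The rewritten insertRecord() above stores each record at pageCursor as a 4-byte length prefix followed by the record bytes, then registers the encoded address with the in-memory sorter. A small standalone sketch of that page layout, using the real Platform helpers but a plain byte[] in place of a MemoryBlock (illustrative only, not part of the patch):

```java
import org.apache.spark.unsafe.Platform;

// Illustrative only: a plain byte[] stands in for the MemoryBlock used by the sorter.
final class PageLayoutSketch {
  public static void main(String[] args) {
    byte[] page = new byte[64];
    long base = Platform.BYTE_ARRAY_OFFSET;
    long cursor = base;

    byte[] record = "hello".getBytes();
    // Same layout as insertRecord(): a 4-byte length prefix followed by the record bytes.
    Platform.putInt(page, cursor, record.length);
    cursor += 4;
    Platform.copyMemory(record, Platform.BYTE_ARRAY_OFFSET, page, cursor, record.length);
    cursor += record.length;

    // Reading it back, as writeSortedFile() does when draining the sorted pointers.
    int length = Platform.getInt(page, base);
    System.out.println("stored record of " + length + " bytes");  // prints 5
  }
}
```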
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index a8dee6c610..e630575d1a 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -37,33 +37,51 @@ final class ShuffleInMemorySorter {
* {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
* records.
*/
- private long[] pointerArray;
+ private long[] array;
/**
* The position in the pointer array where new records can be inserted.
*/
- private int pointerArrayInsertPosition = 0;
+ private int pos = 0;
public ShuffleInMemorySorter(int initialSize) {
assert (initialSize > 0);
- this.pointerArray = new long[initialSize];
- this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
+ this.array = new long[initialSize];
+ this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
}
- public void expandPointerArray() {
- final long[] oldArray = pointerArray;
+ public int numRecords() {
+ return pos;
+ }
+
+ public void reset() {
+ pos = 0;
+ }
+
+ private int newLength() {
// Guard against overflow:
- final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
- pointerArray = new long[newLength];
- System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
+ }
+
+ /**
+ * Returns the memory needed to expand the pointer array, in bytes.
+ */
+ public long getMemoryToExpand() {
+ return ((long) (newLength() - array.length)) * 8;
+ }
+
+ public void expandPointerArray() {
+ final long[] oldArray = array;
+ array = new long[newLength()];
+ System.arraycopy(oldArray, 0, array, 0, oldArray.length);
}
public boolean hasSpaceForAnotherRecord() {
- return pointerArrayInsertPosition + 1 < pointerArray.length;
+ return pos < array.length;
}
public long getMemoryUsage() {
- return pointerArray.length * 8L;
+ return array.length * 8L;
}
/**
@@ -78,15 +96,15 @@ final class ShuffleInMemorySorter {
*/
public void insertRecord(long recordPointer, int partitionId) {
if (!hasSpaceForAnotherRecord()) {
- if (pointerArray.length == Integer.MAX_VALUE) {
+ if (array.length == Integer.MAX_VALUE) {
throw new IllegalStateException("Sort pointer array has reached maximum size");
} else {
expandPointerArray();
}
}
- pointerArray[pointerArrayInsertPosition] =
+ array[pos] =
PackedRecordPointer.packPointer(recordPointer, partitionId);
- pointerArrayInsertPosition++;
+ pos++;
}
/**
@@ -118,7 +136,7 @@ final class ShuffleInMemorySorter {
* Return an iterator over record pointers in sorted order.
*/
public ShuffleSorterIterator getSortedIterator() {
- sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
- return new ShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+ sorter.sort(array, 0, pos, SORT_COMPARATOR);
+ return new ShuffleSorterIterator(pos, array);
}
}
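The new newLength()/getMemoryToExpand() pair replaces the old overflow guard: the pointer array doubles until it would exceed Integer.MAX_VALUE, and the memory needed up front is 8 bytes per new slot, which callers must acquire before expandPointerArray(). A standalone sketch of the same arithmetic (illustrative only, not part of the patch):

```java
// Hypothetical helper mirroring ShuffleInMemorySorter's growth calculation.
final class GrowthSketch {
  // Mirrors newLength(): double the array, but never overflow past Integer.MAX_VALUE.
  static int newLength(int currentLength) {
    return currentLength <= Integer.MAX_VALUE / 2 ? currentLength * 2 : Integer.MAX_VALUE;
  }

  // Mirrors getMemoryToExpand(): extra bytes needed for the larger long[] (8 bytes per slot).
  static long memoryToExpand(int currentLength) {
    return ((long) (newLength(currentLength) - currentLength)) * 8;
  }

  public static void main(String[] args) {
    System.out.println(newLength(4096));               // 8192
    System.out.println(memoryToExpand(4096));          // 32768 bytes
    System.out.println(newLength(Integer.MAX_VALUE));  // clamps at Integer.MAX_VALUE
  }
}
```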
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f6c5c944bd..e19b378642 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -127,12 +127,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
open();
}
- @VisibleForTesting
- public int maxRecordSizeBytes() {
- assert(sorter != null);
- return sorter.maxRecordSizeBytes;
- }
-
private void updatePeakMemoryUsed() {
// sorter can be null if this writer is closed
if (sorter != null) {
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index f035bdac81..e36709c6fc 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -18,14 +18,20 @@
package org.apache.spark.unsafe.map;
import javax.annotation.Nullable;
+import java.io.File;
+import java.io.IOException;
import java.util.Iterator;
import java.util.LinkedList;
-import java.util.List;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.array.LongArray;
@@ -33,7 +39,8 @@ import org.apache.spark.unsafe.bitset.BitSet;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter;
/**
* An append-only hash map where keys and values are contiguous regions of bytes.
@@ -54,7 +61,7 @@ import org.apache.spark.memory.TaskMemoryManager;
* is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter},
* so we can pass records from this map directly into the sorter to sort records in place.
*/
-public final class BytesToBytesMap {
+public final class BytesToBytesMap extends MemoryConsumer {
private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
@@ -62,27 +69,22 @@ public final class BytesToBytesMap {
private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
- /**
- * Special record length that is placed after the last record in a data page.
- */
- private static final int END_OF_PAGE_MARKER = -1;
-
private final TaskMemoryManager taskMemoryManager;
/**
* A linked list for tracking all allocated data pages so that we can free all of our memory.
*/
- private final List<MemoryBlock> dataPages = new LinkedList<MemoryBlock>();
+ private final LinkedList<MemoryBlock> dataPages = new LinkedList<>();
/**
* The data page that will be used to store keys and values for new hashtable entries. When this
* page becomes full, a new page will be allocated and this pointer will change to point to that
* new page.
*/
- private MemoryBlock currentDataPage = null;
+ private MemoryBlock currentPage = null;
/**
- * Offset into `currentDataPage` that points to the location where new data can be inserted into
+ * Offset into `currentPage` that points to the location where new data can be inserted into
* the page. This does not incorporate the page's base offset.
*/
private long pageCursor = 0;
@@ -117,6 +119,11 @@ public final class BytesToBytesMap {
// absolute memory addresses.
/**
+ * Whether or not the longArray can grow. We will not insert more elements if it's false.
+ */
+ private boolean canGrowArray = true;
+
+ /**
* A {@link BitSet} used to track location of the map where the key is set.
* Size of the bitset should be half of the size of the long array.
*/
@@ -164,13 +171,20 @@ public final class BytesToBytesMap {
private long peakMemoryUsedBytes = 0L;
+ private final BlockManager blockManager;
+ private volatile MapIterator destructiveIterator = null;
+ private LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
+
public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
+ BlockManager blockManager,
int initialCapacity,
double loadFactor,
long pageSizeBytes,
boolean enablePerfMetrics) {
+ super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
+ this.blockManager = blockManager;
this.loadFactor = loadFactor;
this.loc = new Location();
this.pageSizeBytes = pageSizeBytes;
@@ -187,18 +201,13 @@ public final class BytesToBytesMap {
TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
}
allocate(initialCapacity);
-
- // Acquire a new page as soon as we construct the map to ensure that we have at least
- // one page to work with. Otherwise, other operators in the same task may starve this
- // map (SPARK-9747).
- acquireNewPage();
}
public BytesToBytesMap(
TaskMemoryManager taskMemoryManager,
int initialCapacity,
long pageSizeBytes) {
- this(taskMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+ this(taskMemoryManager, initialCapacity, pageSizeBytes, false);
}
public BytesToBytesMap(
@@ -208,6 +217,7 @@ public final class BytesToBytesMap {
boolean enablePerfMetrics) {
this(
taskMemoryManager,
+ SparkEnv.get() != null ? SparkEnv.get().blockManager() : null,
initialCapacity,
0.70,
pageSizeBytes,
@@ -219,61 +229,153 @@ public final class BytesToBytesMap {
*/
public int numElements() { return numElements; }
- public static final class BytesToBytesMapIterator implements Iterator<Location> {
+ public final class MapIterator implements Iterator<Location> {
- private final int numRecords;
- private final Iterator<MemoryBlock> dataPagesIterator;
+ private int numRecords;
private final Location loc;
private MemoryBlock currentPage = null;
- private int currentRecordNumber = 0;
+ private int recordsInPage = 0;
private Object pageBaseObject;
private long offsetInPage;
// If this iterator destructive or not. When it is true, it frees each page as it moves onto
// next one.
private boolean destructive = false;
- private BytesToBytesMap bmap;
+ private UnsafeSorterSpillReader reader = null;
- private BytesToBytesMapIterator(
- int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc,
- boolean destructive, BytesToBytesMap bmap) {
+ private MapIterator(int numRecords, Location loc, boolean destructive) {
this.numRecords = numRecords;
- this.dataPagesIterator = dataPagesIterator;
this.loc = loc;
this.destructive = destructive;
- this.bmap = bmap;
- if (dataPagesIterator.hasNext()) {
- advanceToNextPage();
+ if (destructive) {
+ destructiveIterator = this;
}
}
private void advanceToNextPage() {
- if (destructive && currentPage != null) {
- dataPagesIterator.remove();
- this.bmap.taskMemoryManager.freePage(currentPage);
+ synchronized (this) {
+ int nextIdx = dataPages.indexOf(currentPage) + 1;
+ if (destructive && currentPage != null) {
+ dataPages.remove(currentPage);
+ freePage(currentPage);
+ nextIdx --;
+ }
+ if (dataPages.size() > nextIdx) {
+ currentPage = dataPages.get(nextIdx);
+ pageBaseObject = currentPage.getBaseObject();
+ offsetInPage = currentPage.getBaseOffset();
+ recordsInPage = Platform.getInt(pageBaseObject, offsetInPage);
+ offsetInPage += 4;
+ } else {
+ currentPage = null;
+ if (reader != null) {
+ // remove the spill file from disk
+ File file = spillWriters.removeFirst().getFile();
+ if (file != null && file.exists()) {
+ if (!file.delete()) {
+ logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+ }
+ }
+ }
+ try {
+ reader = spillWriters.getFirst().getReader(blockManager);
+ recordsInPage = -1;
+ } catch (IOException e) {
+ // The Scala iterator interface does not allow checked exceptions
+ Platform.throwException(e);
+ }
+ }
}
- currentPage = dataPagesIterator.next();
- pageBaseObject = currentPage.getBaseObject();
- offsetInPage = currentPage.getBaseOffset();
}
@Override
public boolean hasNext() {
- return currentRecordNumber != numRecords;
+ if (numRecords == 0) {
+ if (reader != null) {
+ // remove the spill file from disk
+ File file = spillWriters.removeFirst().getFile();
+ if (file != null && file.exists()) {
+ if (!file.delete()) {
+ logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+ }
+ }
+ }
+ }
+ return numRecords > 0;
}
@Override
public Location next() {
- int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
- if (totalLength == END_OF_PAGE_MARKER) {
+ if (recordsInPage == 0) {
advanceToNextPage();
- totalLength = Platform.getInt(pageBaseObject, offsetInPage);
}
- loc.with(currentPage, offsetInPage);
- offsetInPage += 4 + totalLength;
- currentRecordNumber++;
- return loc;
+ numRecords--;
+ if (currentPage != null) {
+ int totalLength = Platform.getInt(pageBaseObject, offsetInPage);
+ loc.with(currentPage, offsetInPage);
+ offsetInPage += 4 + totalLength;
+ recordsInPage --;
+ return loc;
+ } else {
+ assert(reader != null);
+ if (!reader.hasNext()) {
+ advanceToNextPage();
+ }
+ try {
+ reader.loadNext();
+ } catch (IOException e) {
+ // The Scala iterator interface does not allow checked exceptions
+ Platform.throwException(e);
+ }
+ loc.with(reader.getBaseObject(), reader.getBaseOffset(), reader.getRecordLength());
+ return loc;
+ }
+ }
+
+ public long spill(long numBytes) throws IOException {
+ synchronized (this) {
+ if (!destructive || dataPages.size() == 1) {
+ return 0L;
+ }
+
+ // TODO: use existing ShuffleWriteMetrics
+ ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
+
+ long released = 0L;
+ while (dataPages.size() > 0) {
+ MemoryBlock block = dataPages.getLast();
+ // The currentPage is still in use and cannot be released
+ if (block == currentPage) {
+ break;
+ }
+
+ Object base = block.getBaseObject();
+ long offset = block.getBaseOffset();
+ int numRecords = Platform.getInt(base, offset);
+ offset += 4;
+ final UnsafeSorterSpillWriter writer =
+ new UnsafeSorterSpillWriter(blockManager, 32 * 1024, writeMetrics, numRecords);
+ while (numRecords > 0) {
+ int length = Platform.getInt(base, offset);
+ writer.write(base, offset + 4, length, 0);
+ offset += 4 + length;
+ numRecords--;
+ }
+ writer.close();
+ spillWriters.add(writer);
+
+ dataPages.removeLast();
+ released += block.size();
+ freePage(block);
+
+ if (released >= numBytes) {
+ break;
+ }
+ }
+
+ return released;
+ }
}
@Override
@@ -290,8 +392,8 @@ public final class BytesToBytesMap {
* If any other lookups or operations are performed on this map while iterating over it, including
* `lookup()`, the behavior of the returned iterator is undefined.
*/
- public BytesToBytesMapIterator iterator() {
- return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, false, this);
+ public MapIterator iterator() {
+ return new MapIterator(numElements, loc, false);
}
/**
@@ -304,8 +406,8 @@ public final class BytesToBytesMap {
* If any other lookups or operations are performed on this map while iterating over it, including
* `lookup()`, the behavior of the returned iterator is undefined.
*/
- public BytesToBytesMapIterator destructiveIterator() {
- return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc, true, this);
+ public MapIterator destructiveIterator() {
+ return new MapIterator(numElements, loc, true);
}
/**
@@ -314,11 +416,8 @@ public final class BytesToBytesMap {
*
* This function always return the same {@link Location} instance to avoid object allocation.
*/
- public Location lookup(
- Object keyBaseObject,
- long keyBaseOffset,
- int keyRowLengthBytes) {
- safeLookup(keyBaseObject, keyBaseOffset, keyRowLengthBytes, loc);
+ public Location lookup(Object keyBase, long keyOffset, int keyLength) {
+ safeLookup(keyBase, keyOffset, keyLength, loc);
return loc;
}
@@ -327,18 +426,14 @@ public final class BytesToBytesMap {
*
* This is a thread-safe version of `lookup`, could be used by multiple threads.
*/
- public void safeLookup(
- Object keyBaseObject,
- long keyBaseOffset,
- int keyRowLengthBytes,
- Location loc) {
+ public void safeLookup(Object keyBase, long keyOffset, int keyLength, Location loc) {
assert(bitset != null);
assert(longArray != null);
if (enablePerfMetrics) {
numKeyLookups++;
}
- final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes);
+ final int hashcode = HASHER.hashUnsafeWords(keyBase, keyOffset, keyLength);
int pos = hashcode & mask;
int step = 1;
while (true) {
@@ -354,16 +449,16 @@ public final class BytesToBytesMap {
if ((int) (stored) == hashcode) {
// Full hash code matches. Let's compare the keys for equality.
loc.with(pos, hashcode, true);
- if (loc.getKeyLength() == keyRowLengthBytes) {
+ if (loc.getKeyLength() == keyLength) {
final MemoryLocation keyAddress = loc.getKeyAddress();
- final Object storedKeyBaseObject = keyAddress.getBaseObject();
- final long storedKeyBaseOffset = keyAddress.getBaseOffset();
+ final Object storedkeyBase = keyAddress.getBaseObject();
+ final long storedkeyOffset = keyAddress.getBaseOffset();
final boolean areEqual = ByteArrayMethods.arrayEquals(
- keyBaseObject,
- keyBaseOffset,
- storedKeyBaseObject,
- storedKeyBaseOffset,
- keyRowLengthBytes
+ keyBase,
+ keyOffset,
+ storedkeyBase,
+ storedkeyOffset,
+ keyLength
);
if (areEqual) {
return;
@@ -410,18 +505,18 @@ public final class BytesToBytesMap {
taskMemoryManager.getOffsetInPage(fullKeyAddress));
}
- private void updateAddressesAndSizes(final Object page, final long offsetInPage) {
- long position = offsetInPage;
- final int totalLength = Platform.getInt(page, position);
+ private void updateAddressesAndSizes(final Object base, final long offset) {
+ long position = offset;
+ final int totalLength = Platform.getInt(base, position);
position += 4;
- keyLength = Platform.getInt(page, position);
+ keyLength = Platform.getInt(base, position);
position += 4;
valueLength = totalLength - keyLength - 4;
- keyMemoryLocation.setObjAndOffset(page, position);
+ keyMemoryLocation.setObjAndOffset(base, position);
position += keyLength;
- valueMemoryLocation.setObjAndOffset(page, position);
+ valueMemoryLocation.setObjAndOffset(base, position);
}
private Location with(int pos, int keyHashcode, boolean isDefined) {
@@ -444,6 +539,19 @@ public final class BytesToBytesMap {
}
/**
+ * This is only used for spilling
+ */
+ private Location with(Object base, long offset, int length) {
+ this.isDefined = true;
+ this.memoryPage = null;
+ keyLength = Platform.getInt(base, offset);
+ valueLength = length - 4 - keyLength;
+ keyMemoryLocation.setObjAndOffset(base, offset + 4);
+ valueMemoryLocation.setObjAndOffset(base, offset + 4 + keyLength);
+ return this;
+ }
+
+ /**
* Returns the memory page that contains the current record.
* This is only valid if this is returned by {@link BytesToBytesMap#iterator()}.
*/
@@ -517,9 +625,9 @@ public final class BytesToBytesMap {
* As an example usage, here's the proper way to store a new key:
* </p>
* <pre>
- * Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+ * Location loc = map.lookup(keyBase, keyOffset, keyLength);
* if (!loc.isDefined()) {
- * if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+ * if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
* // handle failure to grow map (by spilling, for example)
* }
* }
@@ -531,113 +639,59 @@ public final class BytesToBytesMap {
* @return true if the put() was successful and false if the put() failed because memory could
* not be acquired.
*/
- public boolean putNewKey(
- Object keyBaseObject,
- long keyBaseOffset,
- int keyLengthBytes,
- Object valueBaseObject,
- long valueBaseOffset,
- int valueLengthBytes) {
+ public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
+ Object valueBase, long valueOffset, int valueLength) {
assert (!isDefined) : "Can only set value once for a key";
- assert (keyLengthBytes % 8 == 0);
- assert (valueLengthBytes % 8 == 0);
+ assert (keyLength % 8 == 0);
+ assert (valueLength % 8 == 0);
assert(bitset != null);
assert(longArray != null);
- if (numElements == MAX_CAPACITY) {
- throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
+ if (numElements == MAX_CAPACITY || !canGrowArray) {
+ return false;
}
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
// (8 byte key length) (key) (value)
- final long requiredSize = 8 + keyLengthBytes + valueLengthBytes;
-
- // --- Figure out where to insert the new record ---------------------------------------------
-
- final MemoryBlock dataPage;
- final Object dataPageBaseObject;
- final long dataPageInsertOffset;
- boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
- if (useOverflowPage) {
- // The record is larger than the page size, so allocate a special overflow page just to hold
- // that record.
- final long overflowPageSize = requiredSize + 8;
- MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
- if (overflowPage == null) {
- logger.debug("Failed to acquire {} bytes of memory", overflowPageSize);
+ final long recordLength = 8 + keyLength + valueLength;
+ if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
+ if (!acquireNewPage(recordLength + 4L)) {
return false;
}
- dataPages.add(overflowPage);
- dataPage = overflowPage;
- dataPageBaseObject = overflowPage.getBaseObject();
- dataPageInsertOffset = overflowPage.getBaseOffset();
- } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
- // The record can fit in a data page, but either we have not allocated any pages yet or
- // the current page does not have enough space.
- if (currentDataPage != null) {
- // There wasn't enough space in the current page, so write an end-of-page marker:
- final Object pageBaseObject = currentDataPage.getBaseObject();
- final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
- Platform.putInt(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
- }
- if (!acquireNewPage()) {
- return false;
- }
- dataPage = currentDataPage;
- dataPageBaseObject = currentDataPage.getBaseObject();
- dataPageInsertOffset = currentDataPage.getBaseOffset();
- } else {
- // There is enough space in the current data page.
- dataPage = currentDataPage;
- dataPageBaseObject = currentDataPage.getBaseObject();
- dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
}
// --- Append the key and value data to the current data page --------------------------------
-
- long insertCursor = dataPageInsertOffset;
-
- // Compute all of our offsets up-front:
- final long recordOffset = insertCursor;
- insertCursor += 4;
- final long keyLengthOffset = insertCursor;
- insertCursor += 4;
- final long keyDataOffsetInPage = insertCursor;
- insertCursor += keyLengthBytes;
- final long valueDataOffsetInPage = insertCursor;
- insertCursor += valueLengthBytes; // word used to store the value size
-
- Platform.putInt(dataPageBaseObject, recordOffset,
- keyLengthBytes + valueLengthBytes + 4);
- Platform.putInt(dataPageBaseObject, keyLengthOffset, keyLengthBytes);
- // Copy the key
- Platform.copyMemory(
- keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
- // Copy the value
- Platform.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
- valueDataOffsetInPage, valueLengthBytes);
-
- // --- Update bookeeping data structures -----------------------------------------------------
-
- if (useOverflowPage) {
- // Store the end-of-page marker at the end of the data page
- Platform.putInt(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
- } else {
- pageCursor += requiredSize;
- }
-
+ final Object base = currentPage.getBaseObject();
+ long offset = currentPage.getBaseOffset() + pageCursor;
+ final long recordOffset = offset;
+ Platform.putInt(base, offset, keyLength + valueLength + 4);
+ Platform.putInt(base, offset + 4, keyLength);
+ offset += 8;
+ Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
+ offset += keyLength;
+ Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
+
+ // --- Update bookkeeping data structures -----------------------------------------------------
+ offset = currentPage.getBaseOffset();
+ Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
+ pageCursor += recordLength;
numElements++;
bitset.set(pos);
final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
- dataPage, recordOffset);
+ currentPage, recordOffset);
longArray.set(pos * 2, storedKeyAddress);
longArray.set(pos * 2 + 1, keyHashcode);
updateAddressesAndSizes(storedKeyAddress);
isDefined = true;
+
if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
- growAndRehash();
+ try {
+ growAndRehash();
+ } catch (OutOfMemoryError oom) {
+ canGrowArray = false;
+ }
}
return true;
}
@@ -647,18 +701,26 @@ public final class BytesToBytesMap {
* Acquire a new page from the memory manager.
* @return whether there is enough space to allocate the new page.
*/
- private boolean acquireNewPage() {
- MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
- if (newPage == null) {
- logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+ private boolean acquireNewPage(long required) {
+ try {
+ currentPage = allocatePage(required);
+ } catch (OutOfMemoryError e) {
return false;
}
- dataPages.add(newPage);
- pageCursor = 0;
- currentDataPage = newPage;
+ dataPages.add(currentPage);
+ Platform.putInt(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0);
+ pageCursor = 4;
return true;
}
+ @Override
+ public long spill(long size, MemoryConsumer trigger) throws IOException {
+ if (trigger != this && destructiveIterator != null) {
+ return destructiveIterator.spill(size);
+ }
+ return 0L;
+ }
+
/**
* Allocate new data structures for this map. When calling this outside of the constructor,
* make sure to keep references to the old data structures so that you can free them.
@@ -670,6 +732,7 @@ public final class BytesToBytesMap {
// The capacity needs to be divisible by 64 so that our bit set can be sized properly
capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
assert (capacity <= MAX_CAPACITY);
+ acquireMemory(capacity * 16);
longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
@@ -678,22 +741,42 @@ public final class BytesToBytesMap {
}
/**
+ * Free the memory used by longArray.
+ */
+ public void freeArray() {
+ updatePeakMemoryUsed();
+ if (longArray != null) {
+ long used = longArray.memoryBlock().size();
+ longArray = null;
+ releaseMemory(used);
+ bitset = null;
+ }
+ }
+
+ /**
* Free all allocated memory associated with this map, including the storage for keys and values
* as well as the hash map array itself.
*
* This method is idempotent and can be called multiple times.
*/
public void free() {
- updatePeakMemoryUsed();
- longArray = null;
- bitset = null;
+ freeArray();
Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
while (dataPagesIterator.hasNext()) {
MemoryBlock dataPage = dataPagesIterator.next();
dataPagesIterator.remove();
- taskMemoryManager.freePage(dataPage);
+ freePage(dataPage);
}
assert(dataPages.isEmpty());
+
+ while (!spillWriters.isEmpty()) {
+ File file = spillWriters.removeFirst().getFile();
+ if (file != null && file.exists()) {
+ if (!file.delete()) {
+ logger.error("Was unable to delete spill file {}", file.getAbsolutePath());
+ }
+ }
+ }
}
public TaskMemoryManager getTaskMemoryManager() {
@@ -782,7 +865,13 @@ public final class BytesToBytesMap {
final int oldCapacity = (int) oldBitSet.capacity();
// Allocate the new data structures
- allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
+ try {
+ allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
+ } catch (OutOfMemoryError oom) {
+ longArray = oldLongArray;
+ bitset = oldBitSet;
+ throw oom;
+ }
// Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
@@ -806,6 +895,7 @@ public final class BytesToBytesMap {
}
}
}
+ releaseMemory(oldLongArray.memoryBlock().size());
if (enablePerfMetrics) {
timeSpentResizingNs += System.nanoTime() - resizeStartTime;
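The class javadoc above documents the lookup()/putNewKey() idiom, and putNewKey() returns false when memory cannot be acquired, which is the caller's signal to fall back (for example by spilling to an external sorter). Below is a hedged sketch of that pattern; the helper class and method names are hypothetical, the TaskMemoryManager is assumed to come from the running task, and key/value lengths must be multiples of 8 bytes.

```java
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.map.BytesToBytesMap;

// Hypothetical usage sketch; sizes and names are illustrative.
final class MapUsageSketch {
  static void putIfAbsent(TaskMemoryManager manager, byte[] key, byte[] value) {
    BytesToBytesMap map = new BytesToBytesMap(manager, 64, 4 * 1024 * 1024);
    BytesToBytesMap.Location loc =
        map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
    if (!loc.isDefined()) {
      boolean ok = loc.putNewKey(
          key, Platform.BYTE_ARRAY_OFFSET, key.length,
          value, Platform.BYTE_ARRAY_OFFSET, value.length);
      if (!ok) {
        // handle failure to grow the map, e.g. by spilling to an external sorter
      }
    }
    map.free();
  }
}
```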
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e317ea391c..49a5a4b13b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -17,39 +17,34 @@
package org.apache.spark.util.collection.unsafe.sort;
+import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
-import javax.annotation.Nullable;
-
-import scala.runtime.AbstractFunction0;
-import scala.runtime.BoxedUnit;
-
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.storage.BlockManager;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.TaskCompletionListener;
import org.apache.spark.util.Utils;
/**
* External sorter based on {@link UnsafeInMemorySorter}.
*/
-public final class UnsafeExternalSorter {
+public final class UnsafeExternalSorter extends MemoryConsumer {
private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
- private final long pageSizeBytes;
private final PrefixComparator prefixComparator;
private final RecordComparator recordComparator;
- private final int initialSize;
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
private final TaskContext taskContext;
@@ -69,14 +64,12 @@ public final class UnsafeExternalSorter {
private final LinkedList<UnsafeSorterSpillWriter> spillWriters = new LinkedList<>();
// These variables are reset after spilling:
- @Nullable private UnsafeInMemorySorter inMemSorter;
- // Whether the in-mem sorter is created internally, or passed in from outside.
- // If it is passed in from outside, we shouldn't release the in-mem sorter's memory.
- private boolean isInMemSorterExternal = false;
+ @Nullable private volatile UnsafeInMemorySorter inMemSorter;
+
private MemoryBlock currentPage = null;
- private long currentPagePosition = -1;
- private long freeSpaceInCurrentPage = 0;
+ private long pageCursor = -1;
private long peakMemoryUsedBytes = 0;
+ private volatile SpillableIterator readingIterator = null;
public static UnsafeExternalSorter createWithExistingInMemorySorter(
TaskMemoryManager taskMemoryManager,
@@ -86,7 +79,7 @@ public final class UnsafeExternalSorter {
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
- UnsafeInMemorySorter inMemorySorter) throws IOException {
+ UnsafeInMemorySorter inMemorySorter) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, inMemorySorter);
}
@@ -98,7 +91,7 @@ public final class UnsafeExternalSorter {
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
- long pageSizeBytes) throws IOException {
+ long pageSizeBytes) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
}
@@ -111,60 +104,41 @@ public final class UnsafeExternalSorter {
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
- @Nullable UnsafeInMemorySorter existingInMemorySorter) throws IOException {
+ @Nullable UnsafeInMemorySorter existingInMemorySorter) {
+ super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
this.taskContext = taskContext;
this.recordComparator = recordComparator;
this.prefixComparator = prefixComparator;
- this.initialSize = initialSize;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
// this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.fileBufferSizeBytes = 32 * 1024;
- this.pageSizeBytes = pageSizeBytes;
+ // TODO: metrics tracking + integration with shuffle write metrics
+ // need to connect the write metrics to task metrics so we count the spill IO somewhere.
this.writeMetrics = new ShuffleWriteMetrics();
if (existingInMemorySorter == null) {
- initializeForWriting();
- // Acquire a new page as soon as we construct the sorter to ensure that we have at
- // least one page to work with. Otherwise, other operators in the same task may starve
- // this sorter (SPARK-9709). We don't need to do this if we already have an existing sorter.
- acquireNewPage();
+ this.inMemSorter =
+ new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
+ acquireMemory(inMemSorter.getMemoryUsage());
} else {
- this.isInMemSorterExternal = true;
this.inMemSorter = existingInMemorySorter;
+ // memory for it will be acquired after the map that owns it is freed
}
+ this.peakMemoryUsedBytes = getMemoryUsage();
// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
// the end of the task. This is necessary to avoid memory leaks when the downstream operator
// does not fully consume the sorter's output (e.g. sort followed by limit).
- taskContext.addOnCompleteCallback(new AbstractFunction0<BoxedUnit>() {
- @Override
- public BoxedUnit apply() {
- cleanupResources();
- return null;
+ taskContext.addTaskCompletionListener(
+ new TaskCompletionListener() {
+ @Override
+ public void onTaskCompletion(TaskContext context) {
+ cleanupResources();
+ }
}
- });
- }
-
- // TODO: metrics tracking + integration with shuffle write metrics
- // need to connect the write metrics to task metrics so we count the spill IO somewhere.
-
- /**
- * Allocates new sort data structures. Called when creating the sorter and after each spill.
- */
- private void initializeForWriting() throws IOException {
- // Note: Do not track memory for the pointer array for now because of SPARK-10474.
- // In more detail, in TungstenAggregate we only reserve a page, but when we fall back to
- // sort-based aggregation we try to acquire a page AND a pointer array, which inevitably
- // fails if all other memory is already occupied. It should be safe to not track the array
- // because its memory footprint is frequently much smaller than that of a page. This is a
- // temporary hack that we should address in 1.6.0.
- // TODO: track the pointer array memory!
- this.writeMetrics = new ShuffleWriteMetrics();
- this.inMemSorter =
- new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
- this.isInMemSorterExternal = false;
+ );
}
/**
@@ -173,14 +147,27 @@ public final class UnsafeExternalSorter {
*/
@VisibleForTesting
public void closeCurrentPage() {
- freeSpaceInCurrentPage = 0;
+ if (currentPage != null) {
+ pageCursor = currentPage.getBaseOffset() + currentPage.size();
+ }
}
/**
* Sort and spill the current records in response to memory pressure.
*/
- public void spill() throws IOException {
- assert(inMemSorter != null);
+ @Override
+ public long spill(long size, MemoryConsumer trigger) throws IOException {
+ if (trigger != this) {
+ if (readingIterator != null) {
+ return readingIterator.spill();
+ }
+ return 0L;
+ }
+
+ if (inMemSorter == null || inMemSorter.numRecords() <= 0) {
+ return 0L;
+ }
+
logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
Thread.currentThread().getId(),
Utils.bytesToString(getMemoryUsage()),
@@ -202,6 +189,8 @@ public final class UnsafeExternalSorter {
spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
}
spillWriter.close();
+
+ inMemSorter.reset();
}
final long spillSize = freeMemory();
@@ -210,7 +199,7 @@ public final class UnsafeExternalSorter {
// written to disk. This also counts the space needed to store the sorter's pointer array.
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
- initializeForWriting();
+ return spillSize;
}
/**
@@ -246,7 +235,7 @@ public final class UnsafeExternalSorter {
}
/**
- * Free this sorter's in-memory data structures, including its data pages and pointer array.
+ * Free this sorter's data pages.
*
* @return the number of bytes freed.
*/
@@ -254,14 +243,12 @@ public final class UnsafeExternalSorter {
updatePeakMemoryUsed();
long memoryFreed = 0;
for (MemoryBlock block : allocatedPages) {
- taskMemoryManager.freePage(block);
memoryFreed += block.size();
+ freePage(block);
}
- // TODO: track in-memory sorter memory usage (SPARK-10474)
allocatedPages.clear();
currentPage = null;
- currentPagePosition = -1;
- freeSpaceInCurrentPage = 0;
+ pageCursor = 0;
return memoryFreed;
}
@@ -283,8 +270,15 @@ public final class UnsafeExternalSorter {
* Frees this sorter's in-memory data structures and cleans up its spill files.
*/
public void cleanupResources() {
- deleteSpillFiles();
- freeMemory();
+ synchronized (this) {
+ deleteSpillFiles();
+ freeMemory();
+ if (inMemSorter != null) {
+ long used = inMemSorter.getMemoryUsage();
+ inMemSorter = null;
+ releaseMemory(used);
+ }
+ }
}
/**
@@ -295,8 +289,28 @@ public final class UnsafeExternalSorter {
private void growPointerArrayIfNecessary() throws IOException {
assert(inMemSorter != null);
if (!inMemSorter.hasSpaceForAnotherRecord()) {
- // TODO: track the pointer array memory! (SPARK-10474)
- inMemSorter.expandPointerArray();
+ long used = inMemSorter.getMemoryUsage();
+ long needed = used + inMemSorter.getMemoryToExpand();
+ try {
+ acquireMemory(needed); // could trigger spilling
+ } catch (OutOfMemoryError e) {
+ // spilling should already have been triggered, freeing space in the pointer array
+ assert(inMemSorter.hasSpaceForAnotherRecord());
+ return;
+ }
+ // check whether spilling was triggered by the acquire above
+ if (inMemSorter.hasSpaceForAnotherRecord()) {
+ releaseMemory(needed);
+ } else {
+ try {
+ inMemSorter.expandPointerArray();
+ releaseMemory(used);
+ } catch (OutOfMemoryError oom) {
+ // Just in case the JVM has actually run out of memory
+ releaseMemory(needed);
+ spill();
+ }
+ }
}
}
@@ -304,101 +318,38 @@ public final class UnsafeExternalSorter {
* Allocates more memory in order to insert an additional record. This will request additional
* memory from the memory manager and spill if the requested memory can not be obtained.
*
- * @param requiredSpace the required space in the data page, in bytes, including space for storing
+ * @param required the required space in the data page, in bytes, including space for storing
* the record size. This must be less than or equal to the page size (records
* that exceed the page size are handled via a different code path which uses
* special overflow pages).
*/
- private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
- assert (requiredSpace <= pageSizeBytes);
- if (requiredSpace > freeSpaceInCurrentPage) {
- logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
- freeSpaceInCurrentPage);
- // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
- // without using the free space at the end of the current page. We should also do this for
- // BytesToBytesMap.
- if (requiredSpace > pageSizeBytes) {
- throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
- pageSizeBytes + ")");
- } else {
- acquireNewPage();
- }
+ private void acquireNewPageIfNecessary(int required) {
+ if (currentPage == null ||
+ pageCursor + required > currentPage.getBaseOffset() + currentPage.size()) {
+ // TODO: try to find space on previous pages
+ currentPage = allocatePage(required);
+ pageCursor = currentPage.getBaseOffset();
+ allocatedPages.add(currentPage);
}
}
/**
- * Acquire a new page from the memory manager.
- *
- * If there is not enough space to allocate the new page, spill all existing ones
- * and try again. If there is still not enough space, report error to the caller.
- */
- private void acquireNewPage() throws IOException {
- currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
- if (currentPage == null) {
- spill();
- currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
- if (currentPage == null) {
- throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
- }
- }
- currentPagePosition = currentPage.getBaseOffset();
- freeSpaceInCurrentPage = pageSizeBytes;
- allocatedPages.add(currentPage);
- }
-
- /**
* Write a record to the sorter.
*/
- public void insertRecord(
- Object recordBaseObject,
- long recordBaseOffset,
- int lengthInBytes,
- long prefix) throws IOException {
+ public void insertRecord(Object recordBase, long recordOffset, int length, long prefix)
+ throws IOException {
growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
- final int totalSpaceRequired = lengthInBytes + 4;
-
- // --- Figure out where to insert the new record ----------------------------------------------
-
- final MemoryBlock dataPage;
- long dataPagePosition;
- boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
- if (useOverflowPage) {
- long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
- // The record is larger than the page size, so allocate a special overflow page just to hold
- // that record.
- MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
- if (overflowPage == null) {
- spill();
- overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
- if (overflowPage == null) {
- throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
- }
- }
- allocatedPages.add(overflowPage);
- dataPage = overflowPage;
- dataPagePosition = overflowPage.getBaseOffset();
- } else {
- // The record is small enough to fit in a regular data page, but the current page might not
- // have enough space to hold it (or no pages have been allocated yet).
- acquireNewPageIfNecessary(totalSpaceRequired);
- dataPage = currentPage;
- dataPagePosition = currentPagePosition;
- // Update bookkeeping information
- freeSpaceInCurrentPage -= totalSpaceRequired;
- currentPagePosition += totalSpaceRequired;
- }
- final Object dataPageBaseObject = dataPage.getBaseObject();
-
- // --- Insert the record ----------------------------------------------------------------------
-
- final long recordAddress =
- taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
- Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
- dataPagePosition += 4;
- Platform.copyMemory(
- recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
+ final int required = length + 4;
+ acquireNewPageIfNecessary(required);
+
+ final Object base = currentPage.getBaseObject();
+ final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+ Platform.putInt(base, pageCursor, length);
+ pageCursor += 4;
+ Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
+ pageCursor += length;
assert(inMemSorter != null);
inMemSorter.insertRecord(recordAddress, prefix);
}
@@ -411,59 +362,24 @@ public final class UnsafeExternalSorter {
*
* record length = key length + value length + 4
*/
- public void insertKVRecord(
- Object keyBaseObj, long keyOffset, int keyLen,
- Object valueBaseObj, long valueOffset, int valueLen, long prefix) throws IOException {
+ public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
+ Object valueBase, long valueOffset, int valueLen, long prefix)
+ throws IOException {
growPointerArrayIfNecessary();
- final int totalSpaceRequired = keyLen + valueLen + 4 + 4;
-
- // --- Figure out where to insert the new record ----------------------------------------------
-
- final MemoryBlock dataPage;
- long dataPagePosition;
- boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
- if (useOverflowPage) {
- long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
- // The record is larger than the page size, so allocate a special overflow page just to hold
- // that record.
- MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
- if (overflowPage == null) {
- spill();
- overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
- if (overflowPage == null) {
- throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
- }
- }
- allocatedPages.add(overflowPage);
- dataPage = overflowPage;
- dataPagePosition = overflowPage.getBaseOffset();
- } else {
- // The record is small enough to fit in a regular data page, but the current page might not
- // have enough space to hold it (or no pages have been allocated yet).
- acquireNewPageIfNecessary(totalSpaceRequired);
- dataPage = currentPage;
- dataPagePosition = currentPagePosition;
- // Update bookkeeping information
- freeSpaceInCurrentPage -= totalSpaceRequired;
- currentPagePosition += totalSpaceRequired;
- }
- final Object dataPageBaseObject = dataPage.getBaseObject();
-
- // --- Insert the record ----------------------------------------------------------------------
-
- final long recordAddress =
- taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
- Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen + valueLen + 4);
- dataPagePosition += 4;
-
- Platform.putInt(dataPageBaseObject, dataPagePosition, keyLen);
- dataPagePosition += 4;
-
- Platform.copyMemory(keyBaseObj, keyOffset, dataPageBaseObject, dataPagePosition, keyLen);
- dataPagePosition += keyLen;
-
- Platform.copyMemory(valueBaseObj, valueOffset, dataPageBaseObject, dataPagePosition, valueLen);
+ final int required = keyLen + valueLen + 4 + 4;
+ acquireNewPageIfNecessary(required);
+
+ final Object base = currentPage.getBaseObject();
+ final long recordAddress = taskMemoryManager.encodePageNumberAndOffset(currentPage, pageCursor);
+ Platform.putInt(base, pageCursor, keyLen + valueLen + 4);
+ pageCursor += 4;
+ Platform.putInt(base, pageCursor, keyLen);
+ pageCursor += 4;
+ Platform.copyMemory(keyBase, keyOffset, base, pageCursor, keyLen);
+ pageCursor += keyLen;
+ Platform.copyMemory(valueBase, valueOffset, base, pageCursor, valueLen);
+ pageCursor += valueLen;
assert(inMemSorter != null);
inMemSorter.insertRecord(recordAddress, prefix);
@@ -475,10 +391,10 @@ public final class UnsafeExternalSorter {
*/
public UnsafeSorterIterator getSortedIterator() throws IOException {
assert(inMemSorter != null);
- final UnsafeInMemorySorter.SortedIterator inMemoryIterator = inMemSorter.getSortedIterator();
- int numIteratorsToMerge = spillWriters.size() + (inMemoryIterator.hasNext() ? 1 : 0);
+ readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
+ int numIteratorsToMerge = spillWriters.size() + (readingIterator.hasNext() ? 1 : 0);
if (spillWriters.isEmpty()) {
- return inMemoryIterator;
+ return readingIterator;
} else {
final UnsafeSorterSpillMerger spillMerger =
new UnsafeSorterSpillMerger(recordComparator, prefixComparator, numIteratorsToMerge);
@@ -486,9 +402,113 @@ public final class UnsafeExternalSorter {
spillMerger.addSpillIfNotEmpty(spillWriter.getReader(blockManager));
}
spillWriters.clear();
- spillMerger.addSpillIfNotEmpty(inMemoryIterator);
+ spillMerger.addSpillIfNotEmpty(readingIterator);
return spillMerger.getSortedIterator();
}
}
+
+ /**
+ * An UnsafeSorterIterator that supports spilling.
+ */
+ class SpillableIterator extends UnsafeSorterIterator {
+ private UnsafeSorterIterator upstream;
+ private UnsafeSorterIterator nextUpstream = null;
+ private MemoryBlock lastPage = null;
+ private boolean loaded = false;
+ private int numRecords = 0;
+
+ public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+ this.upstream = inMemIterator;
+ this.numRecords = inMemIterator.numRecordsLeft();
+ }
+
+ public long spill() throws IOException {
+ synchronized (this) {
+ if (!(upstream instanceof UnsafeInMemorySorter.SortedIterator && nextUpstream == null
+ && numRecords > 0)) {
+ return 0L;
+ }
+
+ UnsafeInMemorySorter.SortedIterator inMemIterator =
+ ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
+
+ final UnsafeSorterSpillWriter spillWriter =
+ new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
+ while (inMemIterator.hasNext()) {
+ inMemIterator.loadNext();
+ final Object baseObject = inMemIterator.getBaseObject();
+ final long baseOffset = inMemIterator.getBaseOffset();
+ final int recordLength = inMemIterator.getRecordLength();
+ spillWriter.write(baseObject, baseOffset, recordLength, inMemIterator.getKeyPrefix());
+ }
+ spillWriter.close();
+ spillWriters.add(spillWriter);
+ nextUpstream = spillWriter.getReader(blockManager);
+
+ long released = 0L;
+ synchronized (UnsafeExternalSorter.this) {
+ // release all pages except the one currently being read
+ for (MemoryBlock page : allocatedPages) {
+ if (!loaded || page.getBaseObject() != inMemIterator.getBaseObject()) {
+ released += page.size();
+ freePage(page);
+ } else {
+ lastPage = page;
+ }
+ }
+ allocatedPages.clear();
+ }
+ return released;
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ return numRecords > 0;
+ }
+
+ @Override
+ public void loadNext() throws IOException {
+ synchronized (this) {
+ loaded = true;
+ if (nextUpstream != null) {
+ // Just consumed the last record from the in-memory iterator
+ if (lastPage != null) {
+ freePage(lastPage);
+ lastPage = null;
+ }
+ upstream = nextUpstream;
+ nextUpstream = null;
+
+ assert(inMemSorter != null);
+ long used = inMemSorter.getMemoryUsage();
+ inMemSorter = null;
+ releaseMemory(used);
+ }
+ numRecords--;
+ upstream.loadNext();
+ }
+ }
+
+ @Override
+ public Object getBaseObject() {
+ return upstream.getBaseObject();
+ }
+
+ @Override
+ public long getBaseOffset() {
+ return upstream.getBaseOffset();
+ }
+
+ @Override
+ public int getRecordLength() {
+ return upstream.getRecordLength();
+ }
+
+ @Override
+ public long getKeyPrefix() {
+ return upstream.getKeyPrefix();
+ }
+ }
}
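For orientation, a hedged usage sketch of the sorter after this change: records are appended with insertRecord(), which writes a 4-byte length header followed by the record bytes into a page and registers (address, prefix) with the in-memory sorter, and getSortedIterator() hands back either the SpillableIterator or a merge over spill files. The comparator implementations and sizes below are placeholders, and the create(...) factory name is assumed from the factory whose parameter list appears above.

    import java.io.IOException;

    import org.apache.spark.TaskContext;
    import org.apache.spark.memory.TaskMemoryManager;
    import org.apache.spark.storage.BlockManager;
    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.util.collection.unsafe.sort.*;

    final class SorterUsageSketch {
      static void sortSomeLongs(TaskMemoryManager tmm, BlockManager bm, TaskContext tc,
          RecordComparator recordCmp, PrefixComparator prefixCmp) throws IOException {
        UnsafeExternalSorter sorter = UnsafeExternalSorter.create(
            tmm, bm, tc, recordCmp, prefixCmp,
            /* initialSize */ 4096, /* pageSizeBytes */ 4L * 1024 * 1024);
        try {
          long[] record = new long[] { 42L };
          sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, 8, /* prefix */ 42L);

          UnsafeSorterIterator it = sorter.getSortedIterator();
          while (it.hasNext()) {
            it.loadNext();  // a SpillableIterator may switch to a spill-file reader here
            long value = Platform.getLong(it.getBaseObject(), it.getBaseOffset());
          }
        } finally {
          sorter.cleanupResources();  // frees pages, the pointer array, and spill files
        }
      }
    }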
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 5aad72c374..1480f0681e 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -70,12 +70,12 @@ public final class UnsafeInMemorySorter {
* Within this buffer, position {@code 2 * i} holds a pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*/
- private long[] pointerArray;
+ private long[] array;
/**
* The position in the sort buffer where new records can be inserted.
*/
- private int pointerArrayInsertPosition = 0;
+ private int pos = 0;
public UnsafeInMemorySorter(
final TaskMemoryManager memoryManager,
@@ -83,37 +83,43 @@ public final class UnsafeInMemorySorter {
final PrefixComparator prefixComparator,
int initialSize) {
assert (initialSize > 0);
- this.pointerArray = new long[initialSize * 2];
+ this.array = new long[initialSize * 2];
this.memoryManager = memoryManager;
this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
}
+ public void reset() {
+ pos = 0;
+ }
+
/**
* @return the number of records that have been inserted into this sorter.
*/
public int numRecords() {
- return pointerArrayInsertPosition / 2;
+ return pos / 2;
}
- public long getMemoryUsage() {
- return pointerArray.length * 8L;
+ private int newLength() {
+ return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
+ }
+
+ public long getMemoryToExpand() {
+ return (long) (newLength() - array.length) * 8L;
}
- static long getMemoryRequirementsForPointerArray(long numEntries) {
- return numEntries * 2L * 8L;
+ public long getMemoryUsage() {
+ return array.length * 8L;
}
public boolean hasSpaceForAnotherRecord() {
- return pointerArrayInsertPosition + 2 < pointerArray.length;
+ return pos + 2 <= array.length;
}
public void expandPointerArray() {
- final long[] oldArray = pointerArray;
- // Guard against overflow:
- final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
- pointerArray = new long[newLength];
- System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
+ final long[] oldArray = array;
+ array = new long[newLength()];
+ System.arraycopy(oldArray, 0, array, 0, oldArray.length);
}
/**
@@ -127,10 +133,10 @@ public final class UnsafeInMemorySorter {
if (!hasSpaceForAnotherRecord()) {
expandPointerArray();
}
- pointerArray[pointerArrayInsertPosition] = recordPointer;
- pointerArrayInsertPosition++;
- pointerArray[pointerArrayInsertPosition] = keyPrefix;
- pointerArrayInsertPosition++;
+ array[pos] = recordPointer;
+ pos++;
+ array[pos] = keyPrefix;
+ pos++;
}
public static final class SortedIterator extends UnsafeSorterIterator {
@@ -153,11 +159,25 @@ public final class UnsafeInMemorySorter {
this.sortBuffer = sortBuffer;
}
+ public SortedIterator clone() {
+ SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer);
+ iter.position = position;
+ iter.baseObject = baseObject;
+ iter.baseOffset = baseOffset;
+ iter.keyPrefix = keyPrefix;
+ iter.recordLength = recordLength;
+ return iter;
+ }
+
@Override
public boolean hasNext() {
return position < sortBufferInsertPosition;
}
+ public int numRecordsLeft() {
+ return (sortBufferInsertPosition - position) / 2;
+ }
+
@Override
public void loadNext() {
// This pointer points to a 4-byte record length, followed by the record's bytes
@@ -187,7 +207,7 @@ public final class UnsafeInMemorySorter {
* {@code next()} will return the same mutable object.
*/
public SortedIterator getSortedIterator() {
- sorter.sort(pointerArray, 0, pointerArrayInsertPosition / 2, sortComparator);
- return new SortedIterator(memoryManager, pointerArrayInsertPosition, pointerArray);
+ sorter.sort(array, 0, pos / 2, sortComparator);
+ return new SortedIterator(memoryManager, pos, array);
}
}
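A worked example of the arithmetic behind the new getMemoryUsage()/getMemoryToExpand() pair, with illustrative numbers (each inserted record occupies two longs: the record pointer at position 2 * i and the key prefix at position 2 * i + 1):

    // Numbers are illustrative, not taken from the patch.
    final class PointerArrayMath {
      public static void main(String[] args) {
        int initialSize = 4096;                       // records
        long[] array = new long[initialSize * 2];     // 8192 longs: [pointer, prefix] pairs
        long used = array.length * 8L;                // getMemoryUsage()    = 65536 bytes
        int newLength = array.length < Integer.MAX_VALUE / 2 ? array.length * 2 : Integer.MAX_VALUE;
        long toExpand = (long) (newLength - array.length) * 8L;  // getMemoryToExpand() = 65536 bytes
        // growPointerArrayIfNecessary() in UnsafeExternalSorter acquires (used + toExpand)
        // before expandPointerArray() copies the old array, then releases `used`, so the
        // accounted footprint covers the window in which both arrays are alive.
        System.out.println((used + toExpand) + " bytes held while both arrays are alive");
      }
    }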
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index 501dfe77d1..039e940a35 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -20,18 +20,18 @@ package org.apache.spark.util.collection.unsafe.sort;
import java.io.*;
import com.google.common.io.ByteStreams;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
/**
* Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
* of the file format).
*/
-final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
+public final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
private static final Logger logger = LoggerFactory.getLogger(UnsafeSorterSpillReader.class);
private final File file;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index e59a84ff8d..234e21140a 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -35,7 +35,7 @@ import org.apache.spark.unsafe.Platform;
*
* [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
*/
-final class UnsafeSorterSpillWriter {
+public final class UnsafeSorterSpillWriter {
static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
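The on-disk layout documented above ([# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]) can be read back with plain DataInputStream calls. A minimal sketch, assuming an uncompressed file on local disk; the real UnsafeSorterSpillReader goes through the BlockManager, so the stream may additionally be compression-wrapped.

    import java.io.DataInputStream;
    import java.io.FileInputStream;
    import java.io.IOException;

    final class SpillFileDumpSketch {
      static void dump(String path) throws IOException {
        try (DataInputStream in = new DataInputStream(new FileInputStream(path))) {
          int numRecords = in.readInt();      // header: total record count
          for (int i = 0; i < numRecords; i++) {
            int length = in.readInt();        // record length in bytes
            long prefix = in.readLong();      // 8-byte sort prefix
            byte[] data = new byte[length];
            in.readFully(data);               // the record bytes themselves
          }
        }
      }
    }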
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
index 6c9a71c385..b0cf2696a3 100644
--- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import com.google.common.annotations.VisibleForTesting
+import org.apache.spark.util.Utils
import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging}
import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore}
import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -215,8 +216,12 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte
final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized {
val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L)
if (curMem < numBytes) {
- throw new SparkException(
- s"Internal error: release called on $numBytes bytes but task only has $curMem")
+ if (Utils.isTesting) {
+ throw new SparkException(
+ s"Internal error: release called on $numBytes bytes but task only has $curMem")
+ } else {
+ logWarning(s"Internal error: release called on $numBytes bytes but task only has $curMem")
+ }
}
if (executionMemoryForTask.contains(taskAttemptId)) {
executionMemoryForTask(taskAttemptId) -= numBytes
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index a76891acf0..9e002621a6 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -78,7 +78,7 @@ private[spark] trait Spillable[C] extends Logging {
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
- val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest)
+ val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest, null)
myMemoryThreshold += granted
// If we were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold), spill the current collection
@@ -107,7 +107,7 @@ private[spark] trait Spillable[C] extends Logging {
*/
def releaseMemory(): Unit = {
// The amount we requested does not include the initial memory tracking threshold
- taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold)
+ taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold, null)
myMemoryThreshold = initialMemoryThreshold
}
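The only change in Spillable is the extra null consumer argument (Spillable is not itself a MemoryConsumer in this patch, so it cannot be asked to spill cooperatively); the doubling arithmetic is unchanged. A worked example with made-up numbers:

    // Illustrative values only.
    final class SpillableRequestMath {
      public static void main(String[] args) {
        long myMemoryThreshold = 5L * 1024 * 1024;   // memory claimed so far
        long currentMemory = 7L * 1024 * 1024;       // estimated collection size
        long amountToRequest = 2 * currentMemory - myMemoryThreshold;  // 9 MiB, aiming at ~2x current
        long granted = 4L * 1024 * 1024;             // suppose the pool grants only 4 MiB
        myMemoryThreshold += granted;                // now 9 MiB
        boolean shouldSpill = currentMemory >= myMemoryThreshold;      // 7 MiB < 9 MiB: no spill yet
        System.out.println("request=" + amountToRequest + ", shouldSpill=" + shouldSpill);
      }
    }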
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index f381db0c62..dab7b0592c 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -17,6 +17,8 @@
package org.apache.spark.memory;
+import java.io.IOException;
+
import org.junit.Assert;
import org.junit.Test;
@@ -25,19 +27,40 @@ import org.apache.spark.unsafe.memory.MemoryBlock;
public class TaskMemoryManagerSuite {
+ class TestMemoryConsumer extends MemoryConsumer {
+ TestMemoryConsumer(TaskMemoryManager memoryManager) {
+ super(memoryManager);
+ }
+
+ @Override
+ public long spill(long size, MemoryConsumer trigger) throws IOException {
+ long used = getUsed();
+ releaseMemory(used);
+ return used;
+ }
+
+ void use(long size) {
+ acquireMemory(size);
+ }
+
+ void free(long size) {
+ releaseMemory(size);
+ }
+ }
+
@Test
public void leakedPageMemoryIsDetected() {
final TaskMemoryManager manager = new TaskMemoryManager(
- new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
- manager.allocatePage(4096); // leak memory
+ new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+ manager.allocatePage(4096, null); // leak memory
Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
}
@Test
public void encodePageNumberAndOffsetOffHeap() {
final TaskMemoryManager manager = new TaskMemoryManager(
- new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
- final MemoryBlock dataPage = manager.allocatePage(256);
+ new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "true")), 0);
+ final MemoryBlock dataPage = manager.allocatePage(256, null);
// In off-heap mode, an offset is an absolute address that may require more than 51 bits to
// encode. This test exercises that corner-case:
final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
@@ -49,11 +72,53 @@ public class TaskMemoryManagerSuite {
@Test
public void encodePageNumberAndOffsetOnHeap() {
final TaskMemoryManager manager = new TaskMemoryManager(
- new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
- final MemoryBlock dataPage = manager.allocatePage(256);
+ new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+ final MemoryBlock dataPage = manager.allocatePage(256, null);
final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
}
+ @Test
+ public void cooperativeSpilling() {
+ final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf());
+ memoryManager.limit(100);
+ final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0);
+
+ TestMemoryConsumer c1 = new TestMemoryConsumer(manager);
+ TestMemoryConsumer c2 = new TestMemoryConsumer(manager);
+ c1.use(100);
+ assert(c1.getUsed() == 100);
+ c2.use(100);
+ assert(c2.getUsed() == 100);
+ assert(c1.getUsed() == 0); // spilled
+ c1.use(100);
+ assert(c1.getUsed() == 100);
+ assert(c2.getUsed() == 0); // spilled
+
+ c1.use(50);
+ assert(c1.getUsed() == 50); // spilled
+ assert(c2.getUsed() == 0);
+ c2.use(50);
+ assert(c1.getUsed() == 50);
+ assert(c2.getUsed() == 50);
+
+ c1.use(100);
+ assert(c1.getUsed() == 100);
+ assert(c2.getUsed() == 0); // spilled
+
+ c1.free(20);
+ assert(c1.getUsed() == 80);
+ c2.use(10);
+ assert(c1.getUsed() == 80);
+ assert(c2.getUsed() == 10);
+ c2.use(100);
+ assert(c2.getUsed() == 100);
+ assert(c1.getUsed() == 0); // spilled
+
+ c1.free(0);
+ c2.free(100);
+ assert(manager.cleanUpAllAllocatedMemory() == 0);
+ }
+
}
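The encode/decode tests above rely on TaskMemoryManager packing a page number and an in-page offset into a single long: the page number sits above OFFSET_BITS (51, per the constant the off-heap test references, leaving 13 bits for the page number), and in off-heap mode the offset is stored relative to the page base so it fits. A sketch of that packing, with the bit width treated as an assumption:

    // The bit width below is inferred from TaskMemoryManager.OFFSET_BITS as used above.
    final class PagedAddressSketch {
      static final int OFFSET_BITS = 51;
      static final long OFFSET_MASK = (1L << OFFSET_BITS) - 1;

      static long encode(int pageNumber, long offsetInPage) {
        return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & OFFSET_MASK);
      }

      static int pageNumber(long address) {
        return (int) (address >>> OFFSET_BITS);
      }

      static long offsetInPage(long address) {
        return address & OFFSET_MASK;
      }

      public static void main(String[] args) {
        long addr = encode(3, 64);
        System.out.println(pageNumber(addr) + " / " + offsetInPage(addr));  // prints "3 / 64"
      }
    }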
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
index 7fb2f92ca8..9a43f1f3a9 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -17,25 +17,29 @@
package org.apache.spark.shuffle.sort;
-import org.apache.spark.shuffle.sort.PackedRecordPointer;
+import java.io.IOException;
+
import org.junit.Test;
-import static org.junit.Assert.*;
import org.apache.spark.SparkConf;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
-import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PARTITION_ID;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
public class PackedRecordPointerSuite {
@Test
- public void heap() {
+ public void heap() throws IOException {
final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
final TaskMemoryManager memoryManager =
- new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
- final MemoryBlock page0 = memoryManager.allocatePage(128);
- final MemoryBlock page1 = memoryManager.allocatePage(128);
+ new TaskMemoryManager(new TestMemoryManager(conf), 0);
+ final MemoryBlock page0 = memoryManager.allocatePage(128, null);
+ final MemoryBlock page1 = memoryManager.allocatePage(128, null);
final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
page1.getBaseOffset() + 42);
PackedRecordPointer packedPointer = new PackedRecordPointer();
@@ -49,12 +53,12 @@ public class PackedRecordPointerSuite {
}
@Test
- public void offHeap() {
+ public void offHeap() throws IOException {
final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "true");
final TaskMemoryManager memoryManager =
- new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
- final MemoryBlock page0 = memoryManager.allocatePage(128);
- final MemoryBlock page1 = memoryManager.allocatePage(128);
+ new TaskMemoryManager(new TestMemoryManager(conf), 0);
+ final MemoryBlock page0 = memoryManager.allocatePage(128, null);
+ final MemoryBlock page1 = memoryManager.allocatePage(128, null);
final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
page1.getBaseOffset() + 42);
PackedRecordPointer packedPointer = new PackedRecordPointer();
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 5049a5306f..2293b1bbc1 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -26,7 +26,7 @@ import org.junit.Test;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.memory.TaskMemoryManager;
@@ -60,8 +60,8 @@ public class ShuffleInMemorySorterSuite {
};
final SparkConf conf = new SparkConf().set("spark.unsafe.offHeap", "false");
final TaskMemoryManager memoryManager =
- new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
- final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+ new TaskMemoryManager(new TestMemoryManager(conf), 0);
+ final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
final Object baseObject = dataPage.getBaseObject();
final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
final HashPartitioner hashPartitioner = new HashPartitioner(4);
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index d65926949c..4763395d7d 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -54,13 +54,14 @@ import org.apache.spark.serializer.*;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.storage.*;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
public class UnsafeShuffleWriterSuite {
static final int NUM_PARTITITONS = 4;
+ TestMemoryManager memoryManager;
TaskMemoryManager taskMemoryManager;
final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
File mergedOutputFile;
@@ -106,10 +107,11 @@ public class UnsafeShuffleWriterSuite {
partitionSizesInMergedFile = null;
spillFilesCreated.clear();
conf = new SparkConf()
- .set("spark.buffer.pageSize", "128m")
+ .set("spark.buffer.pageSize", "1m")
.set("spark.unsafe.offHeap", "false");
taskMetrics = new TaskMetrics();
- taskMemoryManager = new TaskMemoryManager(new GrantEverythingMemoryManager(conf), 0);
+ memoryManager = new TestMemoryManager(conf);
+ taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
when(blockManager.getDiskWriter(
@@ -344,9 +346,7 @@ public class UnsafeShuffleWriterSuite {
}
assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
- assertEquals(
- HashMultiset.create(dataToWrite),
- HashMultiset.create(readRecordsFromFile()));
+ assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
assertSpillFilesWereCleanedUp();
ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
@@ -398,20 +398,14 @@ public class UnsafeShuffleWriterSuite {
@Test
public void writeEnoughDataToTriggerSpill() throws Exception {
- taskMemoryManager = spy(taskMemoryManager);
- doCallRealMethod() // initialize sort buffer
- .doCallRealMethod() // allocate initial data page
- .doReturn(0L) // deny request to allocate new page
- .doCallRealMethod() // grant new sort buffer and data page
- .when(taskMemoryManager).acquireExecutionMemory(anyLong());
+ memoryManager.limit(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES);
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
- final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
- for (int i = 0; i < 128 + 1; i++) {
+ final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10];
+ for (int i = 0; i < 10 + 1; i++) {
dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
}
writer.write(dataToWrite.iterator());
- verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
assertEquals(2, spillFilesCreated.size());
writer.stop(true);
readRecordsFromFile();
@@ -426,19 +420,13 @@ public class UnsafeShuffleWriterSuite {
@Test
public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
- taskMemoryManager = spy(taskMemoryManager);
- doCallRealMethod() // initialize sort buffer
- .doCallRealMethod() // allocate initial data page
- .doReturn(0L) // deny request to allocate new page
- .doCallRealMethod() // grant new sort buffer and data page
- .when(taskMemoryManager).acquireExecutionMemory(anyLong());
+ memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16);
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
dataToWrite.add(new Tuple2<Object, Object>(i, i));
}
writer.write(dataToWrite.iterator());
- verify(taskMemoryManager, times(5)).acquireExecutionMemory(anyLong());
assertEquals(2, spillFilesCreated.size());
writer.stop(true);
readRecordsFromFile();
@@ -473,11 +461,11 @@ public class UnsafeShuffleWriterSuite {
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
// We should be able to write a record that's right _at_ the max record size
- final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
+ final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4];
new Random(42).nextBytes(atMaxRecordSize);
dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
// Inserting a record that's larger than the max record size
- final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
+ final byte[] exceedsMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes()];
new Random(42).nextBytes(exceedsMaxRecordSize);
dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
writer.write(dataToWrite.iterator());
@@ -524,7 +512,7 @@ public class UnsafeShuffleWriterSuite {
for (int i = 0; i < numRecordsPerPage * 10; i++) {
writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
newPeakMemory = writer.getPeakMemoryUsedBytes();
- if (i % numRecordsPerPage == 0 && i != 0) {
+ if (i % numRecordsPerPage == 0) {
// A new page is allocated after every numRecordsPerPage records, including the
// first record, so peak memory should change here.
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 6e52496cf9..92bd45e5fa 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -17,40 +17,117 @@
package org.apache.spark.unsafe.map;
-import java.lang.Exception;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.*;
-import org.apache.spark.memory.TaskMemoryManager;
-import org.junit.*;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.junit.Assert.*;
+import scala.Tuple2;
+import scala.Tuple2$;
+import scala.runtime.AbstractFunction1;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
import org.apache.spark.SparkConf;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.memory.*;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.util.Utils;
+
+import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.when;
public abstract class AbstractBytesToBytesMapSuite {
private final Random rand = new Random(42);
- private GrantEverythingMemoryManager memoryManager;
+ private TestMemoryManager memoryManager;
private TaskMemoryManager taskMemoryManager;
private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
+ final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+ File tempDir;
+
+ @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+ @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+
+ private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+ @Override
+ public OutputStream apply(OutputStream stream) {
+ return stream;
+ }
+ }
+
@Before
public void setup() {
memoryManager =
- new GrantEverythingMemoryManager(
+ new TestMemoryManager(
new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()));
taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
+
+ tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test");
+ spillFilesCreated.clear();
+ MockitoAnnotations.initMocks(this);
+ when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+ when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
+ @Override
+ public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
+ TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+ File file = File.createTempFile("spillFile", ".spill", tempDir);
+ spillFilesCreated.add(file);
+ return Tuple2$.MODULE$.apply(blockId, file);
+ }
+ });
+ when(blockManager.getDiskWriter(
+ any(BlockId.class),
+ any(File.class),
+ any(SerializerInstance.class),
+ anyInt(),
+ any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+ @Override
+ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+ Object[] args = invocationOnMock.getArguments();
+
+ return new DiskBlockObjectWriter(
+ (File) args[1],
+ (SerializerInstance) args[2],
+ (Integer) args[3],
+ new CompressStream(),
+ false,
+ (ShuffleWriteMetrics) args[4]
+ );
+ }
+ });
+ when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+ .then(returnsSecondArg());
}
@After
public void tearDown() {
+ Utils.deleteRecursively(tempDir);
+ tempDir = null;
+
Assert.assertEquals(0L, taskMemoryManager.cleanUpAllAllocatedMemory());
if (taskMemoryManager != null) {
long leakedMemory = taskMemoryManager.getMemoryConsumptionForThisTask();
@@ -415,9 +492,8 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void failureToAllocateFirstPage() {
- memoryManager.markExecutionAsOutOfMemory();
+ memoryManager.limit(1024); // only enough for the initial longArray, not for a data page
BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
- memoryManager.markExecutionAsOutOfMemory();
try {
final long[] emptyArray = new long[0];
final BytesToBytesMap.Location loc =
@@ -439,7 +515,7 @@ public abstract class AbstractBytesToBytesMapSuite {
int i;
for (i = 0; i < 127; i++) {
if (i > 0) {
- memoryManager.markExecutionAsOutOfMemory();
+ memoryManager.limit(0);
}
final long[] arr = new long[]{i};
final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
@@ -457,6 +533,44 @@ public abstract class AbstractBytesToBytesMapSuite {
}
@Test
+ public void spillInIterator() throws IOException {
+ BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, blockManager, 1, 0.75, 1024, false);
+ try {
+ int i;
+ for (i = 0; i < 1024; i++) {
+ final long[] arr = new long[]{i};
+ final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
+ loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+ }
+ BytesToBytesMap.MapIterator iter = map.iterator();
+ for (i = 0; i < 100; i++) {
+ iter.next();
+ }
+ // Non-destructive iterator is not spillable
+ Assert.assertEquals(0, iter.spill(1024L * 10));
+ for (i = 100; i < 1024; i++) {
+ iter.next();
+ }
+
+ BytesToBytesMap.MapIterator iter2 = map.destructiveIterator();
+ for (i = 0; i < 100; i++) {
+ iter2.next();
+ }
+ Assert.assertTrue(iter2.spill(1024) >= 1024);
+ for (i = 100; i < 1024; i++) {
+ iter2.next();
+ }
+ assertFalse(iter2.hasNext());
+ } finally {
+ map.free();
+ for (File spillFile : spillFilesCreated) {
+ assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+ spillFile.exists());
+ }
+ }
+ }
+
+ @Test
public void initialCapacityBoundsChecking() {
try {
new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES);
@@ -500,7 +614,7 @@ public abstract class AbstractBytesToBytesMapSuite {
Platform.LONG_ARRAY_OFFSET,
8);
newPeakMemory = map.getPeakMemoryUsedBytes();
- if (i % numRecordsPerPage == 0 && i > 0) {
+ if (i % numRecordsPerPage == 0) {
// We allocated a new page for this record, so peak memory should change
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
} else {
@@ -519,11 +633,4 @@ public abstract class AbstractBytesToBytesMapSuite {
}
}
- @Test
- public void testAcquirePageInConstructor() {
- final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1, PAGE_SIZE_BYTES);
- assertEquals(1, map.getNumDataPages());
- map.free();
- }
-
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 94d50b94fd..cfead0e592 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -36,28 +36,29 @@ import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
-import static org.hamcrest.Matchers.greaterThanOrEqualTo;
-import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsSecondArg;
-import static org.mockito.Answers.RETURNS_SMART_NULLS;
-import static org.mockito.Mockito.*;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
public class UnsafeExternalSorterSuite {
final LinkedList<File> spillFilesCreated = new LinkedList<File>();
- final GrantEverythingMemoryManager memoryManager =
- new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
+ final TestMemoryManager memoryManager =
+ new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
final PrefixComparator prefixComparator = new PrefixComparator() {
@@ -86,7 +87,7 @@ public class UnsafeExternalSorterSuite {
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
- private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "64m");
+ private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
@Override
@@ -233,7 +234,7 @@ public class UnsafeExternalSorterSuite {
insertNumber(sorter, numRecords - i);
}
assertEquals(1, sorter.getNumberOfAllocatedPages());
- memoryManager.markExecutionAsOutOfMemory();
+ memoryManager.markExecutionAsOutOfMemoryOnce();
// The insertion of this record should trigger a spill:
insertNumber(sorter, 0);
// Ensure that spill files were created
@@ -312,6 +313,62 @@ public class UnsafeExternalSorterSuite {
}
@Test
+ public void forcedSpillingWithReadIterator() throws Exception {
+ final UnsafeExternalSorter sorter = newSorter();
+ long[] record = new long[100];
+ int recordSize = record.length * 8;
+ int n = (int) pageSizeBytes / recordSize * 3;
+ for (int i = 0; i < n; i++) {
+ record[0] = (long) i;
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ }
+ assert(sorter.getNumberOfAllocatedPages() >= 2);
+ UnsafeExternalSorter.SpillableIterator iter =
+ (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
+ int lastv = 0;
+ for (int i = 0; i < n / 3; i++) {
+ iter.hasNext();
+ iter.loadNext();
+ assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
+ lastv = i;
+ }
+ assert(iter.spill() > 0);
+ assert(iter.spill() == 0);
+ assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == lastv);
+ for (int i = n / 3; i < n; i++) {
+ iter.hasNext();
+ iter.loadNext();
+ assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
+ }
+ sorter.cleanupResources();
+ assertSpillFilesWereCleanedUp();
+ }
+
+ @Test
+ public void forcedSpillingWithNotReadIterator() throws Exception {
+ final UnsafeExternalSorter sorter = newSorter();
+ long[] record = new long[100];
+ int recordSize = record.length * 8;
+ int n = (int) pageSizeBytes / recordSize * 3;
+ for (int i = 0; i < n; i++) {
+ record[0] = (long) i;
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ }
+ assert(sorter.getNumberOfAllocatedPages() >= 2);
+ UnsafeExternalSorter.SpillableIterator iter =
+ (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
+ assert(iter.spill() > 0);
+ assert(iter.spill() == 0);
+ for (int i = 0; i < n; i++) {
+ iter.hasNext();
+ iter.loadNext();
+ assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i);
+ }
+ sorter.cleanupResources();
+ assertSpillFilesWereCleanedUp();
+ }
+
+ @Test
public void testPeakMemoryUsed() throws Exception {
final long recordLengthBytes = 8;
final long pageSizeBytes = 256;
@@ -334,7 +391,7 @@ public class UnsafeExternalSorterSuite {
insertNumber(sorter, i);
newPeakMemory = sorter.getPeakMemoryUsedBytes();
// The first page is pre-allocated on instantiation
- if (i % numRecordsPerPage == 0 && i > 0) {
+ if (i % numRecordsPerPage == 0) {
// We allocated a new page for this record, so peak memory should change
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
} else {
@@ -358,21 +415,5 @@ public class UnsafeExternalSorterSuite {
}
}
- @Test
- public void testReservePageOnInstantiation() throws Exception {
- final UnsafeExternalSorter sorter = newSorter();
- try {
- assertEquals(1, sorter.getNumberOfAllocatedPages());
- // Inserting a new record doesn't allocate more memory since we already have a page
- long peakMemory = sorter.getPeakMemoryUsedBytes();
- insertNumber(sorter, 100);
- assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes());
- assertEquals(1, sorter.getNumberOfAllocatedPages());
- } finally {
- sorter.cleanupResources();
- assertSpillFilesWereCleanedUp();
- }
- }
-
}
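
The two forcedSpilling tests added above pin down the contract of UnsafeExternalSorter.SpillableIterator. A minimal sketch of that contract, assuming a sorter that already holds several pages of records (the suite's newSorter() helper is defined elsewhere in the file and not shown here):

    // Fragment only; `sorter` is an UnsafeExternalSorter with records already inserted.
    UnsafeExternalSorter.SpillableIterator iter =
        (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();

    long released = iter.spill();  // writes the remaining in-memory records to disk
    assert released > 0;           // and reports how many bytes were freed
    assert iter.spill() == 0;      // a second call is a no-op: nothing is left in memory

    // The record returned by the most recent loadNext() stays readable after the
    // spill, and iteration continues transparently from the spill files.
    while (iter.hasNext()) {
      iter.loadNext();
      long value = Platform.getLong(iter.getBaseObject(), iter.getBaseOffset());
    }
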
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index d5de56a051..642f6585f8 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -20,17 +20,19 @@ package org.apache.spark.util.collection.unsafe.sort;
import java.util.Arrays;
import org.junit.Test;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.*;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.mock;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.GrantEverythingMemoryManager;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThanOrEqualTo;
+import static org.hamcrest.Matchers.isIn;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
public class UnsafeInMemorySorterSuite {
@@ -44,7 +46,7 @@ public class UnsafeInMemorySorterSuite {
public void testSortingEmptyInput() {
final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
new TaskMemoryManager(
- new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
+ new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
mock(RecordComparator.class),
mock(PrefixComparator.class),
100);
@@ -66,8 +68,8 @@ public class UnsafeInMemorySorterSuite {
"Mango"
};
final TaskMemoryManager memoryManager = new TaskMemoryManager(
- new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
- final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+ new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+ final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
final Object baseObject = dataPage.getBaseObject();
// Write the records into the data page:
long position = dataPage.getBaseOffset();
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 0242cbc924..203dab934c 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -149,7 +149,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
// cause is preserved
val thrownDueToTaskFailure = intercept[SparkException] {
sc.parallelize(Seq(0)).mapPartitions { iter =>
- TaskContext.get().taskMemoryManager().allocatePage(128)
+ TaskContext.get().taskMemoryManager().allocatePage(128, null)
throw new Exception("intentional task failure")
iter
}.count()
@@ -159,7 +159,7 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
// If the task succeeded but memory was leaked, then the task should fail due to that leak
val thrownDueToMemoryLeak = intercept[SparkException] {
sc.parallelize(Seq(0)).mapPartitions { iter =>
- TaskContext.get().taskMemoryManager().allocatePage(128)
+ TaskContext.get().taskMemoryManager().allocatePage(128, null)
iter
}.count()
}
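
Both hunks above deliberately leave the 128-byte page allocated: the first checks that the task's own exception is preserved as the cause, the second that a task which finishes "successfully" while leaking managed memory is failed. A rough sketch of the post-task check this exercises (the shape is an assumption, not code from this patch):

    // Hypothetical executor-side check; names and message are illustrative.
    long leakedBytes = taskMemoryManager.cleanUpAllAllocatedMemory();
    if (leakedBytes > 0 && taskCompletedWithoutError) {
      throw new SparkException("Managed memory leak detected; size = " + leakedBytes + " bytes");
    }
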
diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
index 1265087743..4a9479cf49 100644
--- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala
@@ -145,20 +145,20 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
val manager = createMemoryManager(1000L)
val taskMemoryManager = new TaskMemoryManager(manager, 0)
- assert(taskMemoryManager.acquireExecutionMemory(100L) === 100L)
- assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
- assert(taskMemoryManager.acquireExecutionMemory(400L) === 400L)
- assert(taskMemoryManager.acquireExecutionMemory(200L) === 100L)
- assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
- assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+ assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 100L)
+ assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L)
+ assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L)
+ assert(taskMemoryManager.acquireExecutionMemory(200L, null) === 100L)
+ assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
+ assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
- taskMemoryManager.releaseExecutionMemory(500L)
- assert(taskMemoryManager.acquireExecutionMemory(300L) === 300L)
- assert(taskMemoryManager.acquireExecutionMemory(300L) === 200L)
+ taskMemoryManager.releaseExecutionMemory(500L, null)
+ assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 300L)
+ assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 200L)
taskMemoryManager.cleanUpAllAllocatedMemory()
- assert(taskMemoryManager.acquireExecutionMemory(1000L) === 1000L)
- assert(taskMemoryManager.acquireExecutionMemory(100L) === 0L)
+ assert(taskMemoryManager.acquireExecutionMemory(1000L, null) === 1000L)
+ assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L)
}
test("two tasks requesting full execution memory") {
@@ -168,15 +168,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
val futureTimeout: Duration = 20.seconds
// Have both tasks request 500 bytes, then wait until both requests have been granted:
- val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L) }
- val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
+ val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
assert(Await.result(t1Result1, futureTimeout) === 500L)
assert(Await.result(t2Result1, futureTimeout) === 500L)
// Have both tasks each request 500 bytes more; both should immediately return 0 as they are
// both now at 1 / N
- val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
- val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
+ val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
assert(Await.result(t1Result2, 200.millis) === 0L)
assert(Await.result(t2Result2, 200.millis) === 0L)
}
@@ -188,15 +188,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
val futureTimeout: Duration = 20.seconds
// Have both tasks request 250 bytes, then wait until both requests have been granted:
- val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L) }
- val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, null) }
+ val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) }
assert(Await.result(t1Result1, futureTimeout) === 250L)
assert(Await.result(t2Result1, futureTimeout) === 250L)
// Have both tasks each request 500 bytes more.
// We should only grant 250 bytes to each of them on this second request
- val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L) }
- val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) }
+ val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
assert(Await.result(t1Result2, futureTimeout) === 250L)
assert(Await.result(t2Result2, futureTimeout) === 250L)
}
@@ -208,17 +208,17 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
val futureTimeout: Duration = 20.seconds
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
- val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) }
assert(Await.result(t1Result1, futureTimeout) === 1000L)
- val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L) }
+ val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) }
// Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
// to make sure the other thread blocks for some time otherwise.
Thread.sleep(300)
- t1MemManager.releaseExecutionMemory(250L)
+ t1MemManager.releaseExecutionMemory(250L, null)
// The memory freed from t1 should now be granted to t2.
assert(Await.result(t2Result1, futureTimeout) === 250L)
// Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory.
- val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L) }
+ val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, null) }
assert(Await.result(t2Result2, 200.millis) === 0L)
}
@@ -229,18 +229,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
val futureTimeout: Duration = 20.seconds
// t1 grabs 1000 bytes and then waits until t2 is ready to make a request.
- val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L) }
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) }
assert(Await.result(t1Result1, futureTimeout) === 1000L)
- val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
// Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult
// to make sure the other thread blocks for some time otherwise.
Thread.sleep(300)
// t1 releases all of its memory, so t2 should be able to grab all of the memory
t1MemManager.cleanUpAllAllocatedMemory()
assert(Await.result(t2Result1, futureTimeout) === 500L)
- val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
assert(Await.result(t2Result2, futureTimeout) === 500L)
- val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L) }
+ val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, null) }
assert(Await.result(t2Result3, 200.millis) === 0L)
}
@@ -251,13 +251,13 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite {
val t2MemManager = new TaskMemoryManager(memoryManager, 2)
val futureTimeout: Duration = 20.seconds
- val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L) }
+ val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, null) }
assert(Await.result(t1Result1, futureTimeout) === 700L)
- val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L) }
+ val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, null) }
assert(Await.result(t2Result1, futureTimeout) === 300L)
- val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L) }
+ val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, null) }
assert(Await.result(t1Result2, 200.millis) === 0L)
}
}
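
These asserts all follow from the execution-memory fairness policy the suite exercises. Worked out for the two-task cases:

    maxExecutionMemory = 1000 bytes, N = 2 active tasks
    per-task cap   = 1000 / N       = 500 bytes  (a task already at its cap gets 0 back immediately)
    per-task floor = 1000 / (2 * N) = 250 bytes  (a task below the floor blocks instead of being denied)

So t2's first 250-byte request blocks until t1 releases memory, while its follow-up request returns 0 right away once t2 already holds 1/2N of the pool.
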
diff --git a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
index fe102d8aeb..77e43554ee 100644
--- a/core/src/test/scala/org/apache/spark/memory/GrantEverythingMemoryManager.scala
+++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala
@@ -22,16 +22,22 @@ import scala.collection.mutable
import org.apache.spark.SparkConf
import org.apache.spark.storage.{BlockStatus, BlockId}
-class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
+class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) {
private[memory] override def doAcquireExecutionMemory(
numBytes: Long,
evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized {
- if (oom) {
- oom = false
+ if (oomOnce) {
+ oomOnce = false
0
- } else {
+ } else if (available >= numBytes) {
_executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory
+ available -= numBytes
numBytes
+ } else {
+ _executionMemoryUsed += available
+ val grant = available
+ available = 0
+ grant
}
}
override def acquireStorageMemory(
@@ -42,13 +48,23 @@ class GrantEverythingMemoryManager(conf: SparkConf) extends MemoryManager(conf,
blockId: BlockId,
numBytes: Long,
evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true
- override def releaseStorageMemory(numBytes: Long): Unit = { }
+ override def releaseExecutionMemory(numBytes: Long): Unit = {
+ available += numBytes
+ _executionMemoryUsed -= numBytes
+ }
+ override def releaseStorageMemory(numBytes: Long): Unit = {}
override def maxExecutionMemory: Long = Long.MaxValue
override def maxStorageMemory: Long = Long.MaxValue
- private var oom = false
+ private var oomOnce = false
+ private var available = Long.MaxValue
- def markExecutionAsOutOfMemory(): Unit = {
- oom = true
+ def markExecutionAsOutOfMemoryOnce(): Unit = {
+ oomOnce = true
}
+
+ def limit(avail: Long): Unit = {
+ available = avail
+ }
+
}
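
The renamed TestMemoryManager now models a finite pool instead of granting everything: limit(n) caps what is still grantable (so partial grants become possible), and markExecutionAsOutOfMemoryOnce() fails exactly one acquisition. A minimal sketch of driving both knobs from a Java test, using only the classes and signatures shown in this patch:

    import org.apache.spark.SparkConf;
    import org.apache.spark.memory.TaskMemoryManager;
    import org.apache.spark.memory.TestMemoryManager;

    public class TestMemoryManagerSketch {
      public static void main(String[] args) {
        TestMemoryManager memoryManager =
            new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
        TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);

        memoryManager.limit(300);
        long got = taskMemoryManager.acquireExecutionMemory(500L, null);  // 300: partial grant
        got += taskMemoryManager.acquireExecutionMemory(100L, null);      // +0: pool exhausted
        taskMemoryManager.releaseExecutionMemory(got, null);              // hand the 300 bytes back

        memoryManager.markExecutionAsOutOfMemoryOnce();
        taskMemoryManager.acquireExecutionMemory(100L, null);             // 0: fails exactly once
      }
    }
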
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 810c74fd2f..f7063d1e5c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -96,15 +96,10 @@ final class UnsafeExternalRowSorter {
);
numRowsInserted++;
if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
- spill();
+ sorter.spill();
}
}
- @VisibleForTesting
- void spill() throws IOException {
- sorter.spill();
- }
-
/**
* Return the peak memory used so far, in bytes.
*/
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 82c645df28..889f970034 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -165,7 +165,7 @@ public final class UnsafeFixedWidthAggregationMap {
public KVIterator<UnsafeRow, UnsafeRow> iterator() {
return new KVIterator<UnsafeRow, UnsafeRow>() {
- private final BytesToBytesMap.BytesToBytesMapIterator mapLocationIterator =
+ private final BytesToBytesMap.MapIterator mapLocationIterator =
map.destructiveIterator();
private final UnsafeRow key = new UnsafeRow();
private final UnsafeRow value = new UnsafeRow();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 46301f0042..845f2ae685 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -17,13 +17,13 @@
package org.apache.spark.sql.execution;
-import java.io.IOException;
-
import javax.annotation.Nullable;
+import java.io.IOException;
import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.TaskContext;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
@@ -33,7 +33,6 @@ import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.util.collection.unsafe.sort.*;
/**
@@ -84,18 +83,16 @@ public final class UnsafeKVExternalSorter {
/* initialSize */ 4096,
pageSizeBytes);
} else {
- // Insert the records into the in-memory sorter.
- // We will use the number of elements in the map as the initialSize of the
- // UnsafeInMemorySorter. Because UnsafeInMemorySorter does not accept 0 as the initialSize,
- // we will use 1 as its initial size if the map is empty.
- // TODO: track pointer array memory used by this in-memory sorter! (SPARK-10474)
+      // The memory needed by UnsafeInMemorySorter should be less than the size of the longArray in the map.
+ map.freeArray();
+ // The memory used by UnsafeInMemorySorter will be counted later (end of this block)
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements()));
// We cannot use the destructive iterator here because we are reusing the existing memory
// pages in BytesToBytesMap to hold records during sorting.
// The only new memory we are allocating is the pointer/prefix array.
- BytesToBytesMap.BytesToBytesMapIterator iter = map.iterator();
+ BytesToBytesMap.MapIterator iter = map.iterator();
final int numKeyFields = keySchema.size();
UnsafeRow row = new UnsafeRow();
while (iter.hasNext()) {
@@ -117,7 +114,7 @@ public final class UnsafeKVExternalSorter {
}
sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
- taskContext.taskMemoryManager(),
+ taskMemoryManager,
blockManager,
taskContext,
new KVComparator(ordering, keySchema.length()),
@@ -128,6 +125,8 @@ public final class UnsafeKVExternalSorter {
sorter.spill();
map.free();
+      // Count the memory used by UnsafeInMemorySorter.
+ taskMemoryManager.acquireExecutionMemory(inMemSorter.getMemoryUsage(), sorter);
}
}
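
The rewritten else-branch changes who pays for the pointer array: the map's longArray is freed up front, the in-memory sorter reuses the map's existing data pages, and only after the spill is the sorter's own array charged against the external sorter, which is a MemoryConsumer. Condensed, with illustrative comments (the factory call is abbreviated):

    map.freeArray();                      // return the map's pointer array before allocating a new one
    UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
        taskMemoryManager, recordComparator, prefixComparator, Math.max(1, map.numElements()));

    // ... the map's records are re-inserted into inMemSorter; its data pages are reused, not copied ...
    // ... an UnsafeExternalSorter is built around inMemSorter via createWithExistingInMemorySorter ...

    sorter.spill();
    map.free();
    // Only now is the pointer array's footprint charged to the sorter:
    taskMemoryManager.acquireExecutionMemory(inMemSorter.getMemoryUsage(), sorter);
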
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index dbf4863b76..a38623623a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -24,7 +24,7 @@ import scala.util.{Try, Random}
import org.scalatest.Matchers
import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite}
-import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
import org.apache.spark.sql.test.SharedSQLContext
@@ -48,7 +48,7 @@ class UnsafeFixedWidthAggregationMapSuite
private def emptyAggregationBuffer: InternalRow = InternalRow(0)
private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
- private var memoryManager: GrantEverythingMemoryManager = null
+ private var memoryManager: TestMemoryManager = null
private var taskMemoryManager: TaskMemoryManager = null
def testWithMemoryLeakDetection(name: String)(f: => Unit) {
@@ -62,7 +62,7 @@ class UnsafeFixedWidthAggregationMapSuite
test(name) {
val conf = new SparkConf().set("spark.unsafe.offHeap", "false")
- memoryManager = new GrantEverythingMemoryManager(conf)
+ memoryManager = new TestMemoryManager(conf)
taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
TaskContext.setTaskContext(new TaskContextImpl(
@@ -193,10 +193,6 @@ class UnsafeFixedWidthAggregationMapSuite
// Convert the map into a sorter
val sorter = map.destructAndCreateExternalSorter()
- withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
- assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
- }
-
// Add more keys to the sorter and make sure the results come out sorted.
val additionalKeys = randomStrings(1024)
val keyConverter = UnsafeProjection.create(groupKeySchema)
@@ -208,7 +204,7 @@ class UnsafeFixedWidthAggregationMapSuite
sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
if ((i % 100) == 0) {
- memoryManager.markExecutionAsOutOfMemory()
+ memoryManager.markExecutionAsOutOfMemoryOnce()
sorter.closeCurrentPage()
}
}
@@ -251,7 +247,7 @@ class UnsafeFixedWidthAggregationMapSuite
sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
if ((i % 100) == 0) {
- memoryManager.markExecutionAsOutOfMemory()
+ memoryManager.markExecutionAsOutOfMemoryOnce()
sorter.closeCurrentPage()
}
}
@@ -294,16 +290,12 @@ class UnsafeFixedWidthAggregationMapSuite
// Convert the map into a sorter. Right now, it contains one record.
val sorter = map.destructAndCreateExternalSorter()
- withClue(s"destructAndCreateExternalSorter should release memory used by the map") {
- assert(taskMemoryManager.getMemoryConsumptionForThisTask() === initialMemoryConsumption)
- }
-
// Add more keys to the sorter and make sure the results come out sorted.
(1 to 4096).foreach { i =>
sorter.insertKV(UnsafeRow.createFromByteArray(0, 0), UnsafeRow.createFromByteArray(0, 0))
if ((i % 100) == 0) {
- memoryManager.markExecutionAsOutOfMemory()
+ memoryManager.markExecutionAsOutOfMemoryOnce()
sorter.closeCurrentPage()
}
}
@@ -342,7 +334,7 @@ class UnsafeFixedWidthAggregationMapSuite
buf.setInt(0, str.length)
}
// Simulate running out of space
- memoryManager.markExecutionAsOutOfMemory()
+ memoryManager.limit(0)
val str = rand.nextString(1024)
val buf = map.getAggregationBuffer(InternalRow(UTF8String.fromString(str)))
assert(buf == null)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 13dc1754c9..7b80963ec8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import scala.util.Random
import org.apache.spark._
-import org.apache.spark.memory.{TaskMemoryManager, GrantEverythingMemoryManager}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
@@ -109,7 +109,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
pageSize: Long,
spill: Boolean): Unit = {
val memoryManager =
- new GrantEverythingMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"))
+ new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"))
val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
@@ -128,7 +128,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
// 1% chance we will spill
if (rand.nextDouble() < 0.01 && spill) {
- memoryManager.markExecutionAsOutOfMemory()
+ memoryManager.markExecutionAsOutOfMemoryOnce()
sorter.closeCurrentPage()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
deleted file mode 100644
index 475037bd45..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark._
-import org.apache.spark.memory.TaskMemoryManager
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.test.SharedSQLContext
-
-class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
-
- test("memory acquired on construction") {
- val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.memoryManager, 0)
- val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
- TaskContext.setTaskContext(taskContext)
-
- // Assert that a page is allocated before processing starts
- var iter: TungstenAggregationIterator = null
- try {
- val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
- () => new InterpretedMutableProjection(expr, schema)
- }
- val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy")
- iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty,
- 0, Seq.empty, newMutableProjection, Seq.empty, None,
- dummyAccum, dummyAccum, dummyAccum, dummyAccum)
- val numPages = iter.getHashMap.getNumDataPages
- assert(numPages === 1)
- } finally {
- // Clean up
- if (iter != null) {
- iter.free()
- }
- TaskContext.unset()
- }
- }
-}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
index ebe90d9e63..09847cec9c 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -23,6 +23,8 @@ import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
+import org.apache.spark.unsafe.Platform;
+
/**
* A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array.
*/
@@ -45,9 +47,6 @@ public class HeapMemoryAllocator implements MemoryAllocator {
@Override
public MemoryBlock allocate(long size) throws OutOfMemoryError {
- if (size % 8 != 0) {
- throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
- }
if (shouldPool(size)) {
synchronized (this) {
final LinkedList<WeakReference<MemoryBlock>> pool = bufferPoolsBySize.get(size);
@@ -64,8 +63,8 @@ public class HeapMemoryAllocator implements MemoryAllocator {
}
}
}
- long[] array = new long[(int) (size / 8)];
- return MemoryBlock.fromLongArray(array);
+ long[] array = new long[(int) ((size + 7) / 8)];
+ return new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size);
}
@Override
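
With the multiple-of-8 check gone, the heap allocator now rounds the backing long[] up itself while keeping the MemoryBlock's logical length exact. A small worked example with an illustrative size:

    // Requesting 100 bytes on the heap:
    long[] array = new long[(int) ((100 + 7) / 8)];                   // 13 longs = 104 bytes of backing storage
    MemoryBlock block = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, 100);
    // The block still reports a length of 100 bytes; the 4 trailing bytes are slack
    // that callers never see. The off-heap allocator below needs no rounding at all,
    // since Platform.allocateMemory accepts arbitrary sizes.
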
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
index cda7826c8c..98ce711176 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -26,9 +26,6 @@ public class UnsafeMemoryAllocator implements MemoryAllocator {
@Override
public MemoryBlock allocate(long size) throws OutOfMemoryError {
- if (size % 8 != 0) {
- throw new IllegalArgumentException("Size " + size + " was not a multiple of 8");
- }
long address = Platform.allocateMemory(size);
return new MemoryBlock(null, address, size);
}