aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2015-07-31 19:19:27 -0700
committerJosh Rosen <joshrosen@databricks.com>2015-07-31 19:19:27 -0700
commit8cb415a4b9bc1f82127ccce4a5579d433f4e8f83 (patch)
treef9071996c6485e77700463106ae86c1a346508c9
parentf51fd6fbb4d9822502f98b312251e317d757bc3a (diff)
downloadspark-8cb415a4b9bc1f82127ccce4a5579d433f4e8f83.tar.gz
spark-8cb415a4b9bc1f82127ccce4a5579d433f4e8f83.tar.bz2
spark-8cb415a4b9bc1f82127ccce4a5579d433f4e8f83.zip
[SPARK-9451] [SQL] Support entries larger than default page size in BytesToBytesMap & integrate with ShuffleMemoryManager
This patch adds support for entries larger than the default page size in BytesToBytesMap. These large rows are handled by allocating special overflow pages to hold individual entries. In addition, this patch integrates BytesToBytesMap with the ShuffleMemoryManager: - Move BytesToBytesMap from `unsafe` to `core` so that it can import `ShuffleMemoryManager`. - Before allocating new data pages, ask the ShuffleMemoryManager to reserve the memory: - `putNewKey()` now returns a boolean to indicate whether the insert succeeded or failed due to a lack of memory. The caller can use this value to respond to the memory pressure (e.g. by spilling). - `UnsafeFixedWidthAggregationMap. getAggregationBuffer()` now returns `null` to signal failure due to a lack of memory. - Updated all uses of these classes to handle these error conditions. - Added new tests for allocating large records and for allocations which fail due to memory pressure. - Extended the `afterAll()` test teardown methods to detect ShuffleMemoryManager leaks. Author: Josh Rosen <joshrosen@databricks.com> Closes #7762 from JoshRosen/large-rows and squashes the following commits: ae7bc56 [Josh Rosen] Fix compilation 82fc657 [Josh Rosen] Merge remote-tracking branch 'origin/master' into large-rows 34ab943 [Josh Rosen] Remove semi 31a525a [Josh Rosen] Integrate BytesToBytesMap with ShuffleMemoryManager. 626b33c [Josh Rosen] Move code to sql/core and spark/core packages so that ShuffleMemoryManager can be integrated ec4484c [Josh Rosen] Move BytesToBytesMap from unsafe package to core. 642ed69 [Josh Rosen] Rename size to numElements bea1152 [Josh Rosen] Add basic test. 2cd3570 [Josh Rosen] Remove accidental duplicated code 07ff9ef [Josh Rosen] Basic support for large rows in BytesToBytesMap.
-rw-r--r--core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java (renamed from unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java)170
-rw-r--r--core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java (renamed from unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java)0
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala8
-rw-r--r--core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java (renamed from unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java)204
-rw-r--r--core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java (renamed from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java)0
-rw-r--r--core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java (renamed from unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java)0
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java (renamed from sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java)25
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala27
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala)36
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java2
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java1
12 files changed, 353 insertions, 126 deletions
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 198e0684f3..0f42950e6e 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -17,6 +17,7 @@
package org.apache.spark.unsafe.map;
+import java.io.IOException;
import java.lang.Override;
import java.lang.UnsupportedOperationException;
import java.util.Iterator;
@@ -24,7 +25,10 @@ import java.util.LinkedList;
import java.util.List;
import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.unsafe.*;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.array.LongArray;
@@ -45,6 +49,8 @@ import org.apache.spark.unsafe.memory.*;
*/
public final class BytesToBytesMap {
+ private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
+
private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
@@ -54,7 +60,9 @@ public final class BytesToBytesMap {
*/
private static final int END_OF_PAGE_MARKER = -1;
- private final TaskMemoryManager memoryManager;
+ private final TaskMemoryManager taskMemoryManager;
+
+ private final ShuffleMemoryManager shuffleMemoryManager;
/**
* A linked list for tracking all allocated data pages so that we can free all of our memory.
@@ -120,7 +128,7 @@ public final class BytesToBytesMap {
/**
* Number of keys defined in the map.
*/
- private int size;
+ private int numElements;
/**
* The map will be expanded once the number of keys exceeds this threshold.
@@ -150,12 +158,14 @@ public final class BytesToBytesMap {
private long numHashCollisions = 0;
public BytesToBytesMap(
- TaskMemoryManager memoryManager,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
double loadFactor,
long pageSizeBytes,
boolean enablePerfMetrics) {
- this.memoryManager = memoryManager;
+ this.taskMemoryManager = taskMemoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
this.loadFactor = loadFactor;
this.loc = new Location();
this.pageSizeBytes = pageSizeBytes;
@@ -175,24 +185,32 @@ public final class BytesToBytesMap {
}
public BytesToBytesMap(
- TaskMemoryManager memoryManager,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes) {
- this(memoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+ this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
}
public BytesToBytesMap(
- TaskMemoryManager memoryManager,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes,
boolean enablePerfMetrics) {
- this(memoryManager, initialCapacity, 0.70, pageSizeBytes, enablePerfMetrics);
+ this(
+ taskMemoryManager,
+ shuffleMemoryManager,
+ initialCapacity,
+ 0.70,
+ pageSizeBytes,
+ enablePerfMetrics);
}
/**
* Returns the number of keys defined in the map.
*/
- public int size() { return size; }
+ public int numElements() { return numElements; }
private static final class BytesToBytesMapIterator implements Iterator<Location> {
@@ -252,7 +270,7 @@ public final class BytesToBytesMap {
* `lookup()`, the behavior of the returned iterator is undefined.
*/
public Iterator<Location> iterator() {
- return new BytesToBytesMapIterator(size, dataPages.iterator(), loc);
+ return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc);
}
/**
@@ -330,7 +348,8 @@ public final class BytesToBytesMap {
private void updateAddressesAndSizes(long fullKeyAddress) {
updateAddressesAndSizes(
- memoryManager.getPage(fullKeyAddress), memoryManager.getOffsetInPage(fullKeyAddress));
+ taskMemoryManager.getPage(fullKeyAddress),
+ taskMemoryManager.getOffsetInPage(fullKeyAddress));
}
private void updateAddressesAndSizes(Object page, long keyOffsetInPage) {
@@ -411,7 +430,8 @@ public final class BytesToBytesMap {
/**
* Store a new key and value. This method may only be called once for a given key; if you want
* to update the value associated with a key, then you can directly manipulate the bytes stored
- * at the value address.
+ * at the value address. The return value indicates whether the put succeeded or whether it
+ * failed because additional memory could not be acquired.
* <p>
* It is only valid to call this method immediately after calling `lookup()` using the same key.
* </p>
@@ -428,14 +448,19 @@ public final class BytesToBytesMap {
* <pre>
* Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
* if (!loc.isDefined()) {
- * loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)
+ * if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+ * // handle failure to grow map (by spilling, for example)
+ * }
* }
* </pre>
* <p>
* Unspecified behavior if the key is not defined.
* </p>
+ *
+ * @return true if the put() was successful and false if the put() failed because memory could
+ * not be acquired.
*/
- public void putNewKey(
+ public boolean putNewKey(
Object keyBaseObject,
long keyBaseOffset,
int keyLengthBytes,
@@ -445,63 +470,110 @@ public final class BytesToBytesMap {
assert (!isDefined) : "Can only set value once for a key";
assert (keyLengthBytes % 8 == 0);
assert (valueLengthBytes % 8 == 0);
- if (size == MAX_CAPACITY) {
+ if (numElements == MAX_CAPACITY) {
throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
}
+
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
// (8 byte key length) (key) (8 byte value length) (value)
final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
- assert (requiredSize <= pageSizeBytes - 8); // Reserve 8 bytes for the end-of-page marker.
- size++;
- bitset.set(pos);
- // If there's not enough space in the current page, allocate a new page (8 bytes are reserved
- // for the end-of-page marker).
- if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
+ // --- Figure out where to insert the new record ---------------------------------------------
+
+ final MemoryBlock dataPage;
+ final Object dataPageBaseObject;
+ final long dataPageInsertOffset;
+ boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
+ if (useOverflowPage) {
+ // The record is larger than the page size, so allocate a special overflow page just to hold
+ // that record.
+ final long memoryRequested = requiredSize + 8;
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryGranted != memoryRequested) {
+ shuffleMemoryManager.release(memoryGranted);
+ logger.debug("Failed to acquire {} bytes of memory", memoryRequested);
+ return false;
+ }
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested);
+ dataPages.add(overflowPage);
+ dataPage = overflowPage;
+ dataPageBaseObject = overflowPage.getBaseObject();
+ dataPageInsertOffset = overflowPage.getBaseOffset();
+ } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
+ // The record can fit in a data page, but either we have not allocated any pages yet or
+ // the current page does not have enough space.
if (currentDataPage != null) {
// There wasn't enough space in the current page, so write an end-of-page marker:
final Object pageBaseObject = currentDataPage.getBaseObject();
final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
}
- MemoryBlock newPage = memoryManager.allocatePage(pageSizeBytes);
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryGranted != pageSizeBytes) {
+ shuffleMemoryManager.release(memoryGranted);
+ logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+ return false;
+ }
+ MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
dataPages.add(newPage);
pageCursor = 0;
currentDataPage = newPage;
+ dataPage = currentDataPage;
+ dataPageBaseObject = currentDataPage.getBaseObject();
+ dataPageInsertOffset = currentDataPage.getBaseOffset();
+ } else {
+ // There is enough space in the current data page.
+ dataPage = currentDataPage;
+ dataPageBaseObject = currentDataPage.getBaseObject();
+ dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
}
+ // --- Append the key and value data to the current data page --------------------------------
+
+ long insertCursor = dataPageInsertOffset;
+
// Compute all of our offsets up-front:
- final Object pageBaseObject = currentDataPage.getBaseObject();
- final long pageBaseOffset = currentDataPage.getBaseOffset();
- final long keySizeOffsetInPage = pageBaseOffset + pageCursor;
- pageCursor += 8; // word used to store the key size
- final long keyDataOffsetInPage = pageBaseOffset + pageCursor;
- pageCursor += keyLengthBytes;
- final long valueSizeOffsetInPage = pageBaseOffset + pageCursor;
- pageCursor += 8; // word used to store the value size
- final long valueDataOffsetInPage = pageBaseOffset + pageCursor;
- pageCursor += valueLengthBytes;
+ final long keySizeOffsetInPage = insertCursor;
+ insertCursor += 8; // word used to store the key size
+ final long keyDataOffsetInPage = insertCursor;
+ insertCursor += keyLengthBytes;
+ final long valueSizeOffsetInPage = insertCursor;
+ insertCursor += 8; // word used to store the value size
+ final long valueDataOffsetInPage = insertCursor;
+ insertCursor += valueLengthBytes; // word used to store the value size
// Copy the key
- PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes);
+ PlatformDependent.UNSAFE.putLong(dataPageBaseObject, keySizeOffsetInPage, keyLengthBytes);
PlatformDependent.copyMemory(
- keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes);
+ keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
// Copy the value
- PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
- PlatformDependent.copyMemory(
- valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes);
+ PlatformDependent.UNSAFE.putLong(dataPageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
+ PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
+ valueDataOffsetInPage, valueLengthBytes);
+
+ // --- Update bookeeping data structures -----------------------------------------------------
+
+ if (useOverflowPage) {
+ // Store the end-of-page marker at the end of the data page
+ PlatformDependent.UNSAFE.putLong(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
+ } else {
+ pageCursor += requiredSize;
+ }
- final long storedKeyAddress = memoryManager.encodePageNumberAndOffset(
- currentDataPage, keySizeOffsetInPage);
+ numElements++;
+ bitset.set(pos);
+ final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
+ dataPage, keySizeOffsetInPage);
longArray.set(pos * 2, storedKeyAddress);
longArray.set(pos * 2 + 1, keyHashcode);
updateAddressesAndSizes(storedKeyAddress);
isDefined = true;
- if (size > growthThreshold && longArray.size() < MAX_CAPACITY) {
+ if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
growAndRehash();
}
+ return true;
}
}
@@ -516,7 +588,7 @@ public final class BytesToBytesMap {
// The capacity needs to be divisible by 64 so that our bit set can be sized properly
capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64);
assert (capacity <= MAX_CAPACITY);
- longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2));
+ longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
this.growthThreshold = (int) (capacity * loadFactor);
@@ -530,18 +602,14 @@ public final class BytesToBytesMap {
* This method is idempotent.
*/
public void free() {
- if (longArray != null) {
- memoryManager.free(longArray.memoryBlock());
- longArray = null;
- }
- if (bitset != null) {
- // The bitset's heap memory isn't managed by a memory manager, so no need to free it here.
- bitset = null;
- }
+ longArray = null;
+ bitset = null;
Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
while (dataPagesIterator.hasNext()) {
- memoryManager.freePage(dataPagesIterator.next());
+ MemoryBlock dataPage = dataPagesIterator.next();
dataPagesIterator.remove();
+ taskMemoryManager.freePage(dataPage);
+ shuffleMemoryManager.release(dataPage.size());
}
assert(dataPages.isEmpty());
}
@@ -628,8 +696,6 @@ public final class BytesToBytesMap {
}
}
- // Deallocate the old data structures.
- memoryManager.free(oldLongArray.memoryBlock());
if (enablePerfMetrics) {
timeSpentResizingNs += System.nanoTime() - resizeStartTime;
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
index 20654e4eea..20654e4eea 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index f038b72295..00c1e078a4 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -85,7 +85,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
return toGrant
} else {
logInfo(
- s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
+ s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
wait()
}
} else {
@@ -116,6 +116,12 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
taskMemory.remove(taskAttemptId)
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
+
+ /** Returns the memory consumption, in bytes, for the current task */
+ def getMemoryConsumptionForThisTask(): Long = synchronized {
+ val taskAttemptId = currentTaskAttemptId()
+ taskMemory.getOrElse(taskAttemptId, 0L)
+ }
}
private object ShuffleMemoryManager {
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 0be94ad371..60f483acbc 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -21,15 +21,14 @@ import java.lang.Exception;
import java.nio.ByteBuffer;
import java.util.*;
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.*;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
+import static org.hamcrest.Matchers.greaterThan;
import static org.mockito.AdditionalMatchers.geq;
import static org.mockito.Mockito.*;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.memory.*;
import org.apache.spark.unsafe.PlatformDependent;
@@ -41,32 +40,39 @@ public abstract class AbstractBytesToBytesMapSuite {
private final Random rand = new Random(42);
- private TaskMemoryManager memoryManager;
- private TaskMemoryManager sizeLimitedMemoryManager;
+ private ShuffleMemoryManager shuffleMemoryManager;
+ private TaskMemoryManager taskMemoryManager;
+ private TaskMemoryManager sizeLimitedTaskMemoryManager;
private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
@Before
public void setup() {
- memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
+ shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE);
+ taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
// Mocked memory manager for tests that check the maximum array size, since actually allocating
// such large arrays will cause us to run out of memory in our tests.
- sizeLimitedMemoryManager = spy(memoryManager);
- when(sizeLimitedMemoryManager.allocate(geq(1L << 20))).thenAnswer(new Answer<MemoryBlock>() {
- @Override
- public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
- if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
- throw new OutOfMemoryError("Requested array size exceeds VM limit");
+ sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class);
+ when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer(
+ new Answer<MemoryBlock>() {
+ @Override
+ public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
+ if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
+ throw new OutOfMemoryError("Requested array size exceeds VM limit");
+ }
+ return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]);
}
- return memoryManager.allocate(1L << 20);
}
- });
+ );
}
@After
public void tearDown() {
- if (memoryManager != null) {
- memoryManager.cleanUpAllAllocatedMemory();
- memoryManager = null;
+ if (taskMemoryManager != null) {
+ long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
+ Assert.assertEquals(0, taskMemoryManager.cleanUpAllAllocatedMemory());
+ Assert.assertEquals(0, leakedShuffleMemory);
+ shuffleMemoryManager = null;
+ taskMemoryManager = null;
}
}
@@ -85,7 +91,7 @@ public abstract class AbstractBytesToBytesMapSuite {
}
private byte[] getRandomByteArray(int numWords) {
- Assert.assertTrue(numWords > 0);
+ Assert.assertTrue(numWords >= 0);
final int lengthInBytes = numWords * 8;
final byte[] bytes = new byte[lengthInBytes];
rand.nextBytes(bytes);
@@ -111,9 +117,10 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void emptyMap() {
- BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES);
+ BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
try {
- Assert.assertEquals(0, map.size());
+ Assert.assertEquals(0, map.numElements());
final int keyLengthInWords = 10;
final int keyLengthInBytes = keyLengthInWords * 8;
final byte[] key = getRandomByteArray(keyLengthInWords);
@@ -126,7 +133,8 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void setAndRetrieveAKey() {
- BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64, PAGE_SIZE_BYTES);
+ BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
final int recordLengthWords = 10;
final int recordLengthBytes = recordLengthWords * 8;
final byte[] keyData = getRandomByteArray(recordLengthWords);
@@ -135,14 +143,14 @@ public abstract class AbstractBytesToBytesMapSuite {
final BytesToBytesMap.Location loc =
map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes);
Assert.assertFalse(loc.isDefined());
- loc.putNewKey(
+ Assert.assertTrue(loc.putNewKey(
keyData,
BYTE_ARRAY_OFFSET,
recordLengthBytes,
valueData,
BYTE_ARRAY_OFFSET,
recordLengthBytes
- );
+ ));
// After storing the key and value, the other location methods should return results that
// reflect the result of this store without us having to call lookup() again on the same key.
Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
@@ -158,14 +166,14 @@ public abstract class AbstractBytesToBytesMapSuite {
Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
try {
- loc.putNewKey(
+ Assert.assertTrue(loc.putNewKey(
keyData,
BYTE_ARRAY_OFFSET,
recordLengthBytes,
valueData,
BYTE_ARRAY_OFFSET,
recordLengthBytes
- );
+ ));
Assert.fail("Should not be able to set a new value for a key");
} catch (AssertionError e) {
// Expected exception; do nothing.
@@ -178,7 +186,8 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void iteratorTest() throws Exception {
final int size = 4096;
- BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2, PAGE_SIZE_BYTES);
+ BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES);
try {
for (long i = 0; i < size; i++) {
final long[] value = new long[] { i };
@@ -187,23 +196,23 @@ public abstract class AbstractBytesToBytesMapSuite {
Assert.assertFalse(loc.isDefined());
// Ensure that we store some zero-length keys
if (i % 5 == 0) {
- loc.putNewKey(
+ Assert.assertTrue(loc.putNewKey(
null,
PlatformDependent.LONG_ARRAY_OFFSET,
0,
value,
PlatformDependent.LONG_ARRAY_OFFSET,
8
- );
+ ));
} else {
- loc.putNewKey(
+ Assert.assertTrue(loc.putNewKey(
value,
PlatformDependent.LONG_ARRAY_OFFSET,
8,
value,
PlatformDependent.LONG_ARRAY_OFFSET,
8
- );
+ ));
}
}
final java.util.BitSet valuesSeen = new java.util.BitSet(size);
@@ -236,7 +245,8 @@ public abstract class AbstractBytesToBytesMapSuite {
final int NUM_ENTRIES = 1000 * 1000;
final int KEY_LENGTH = 16;
final int VALUE_LENGTH = 40;
- final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
+ final BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
// Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte
// pages won't be evenly-divisible by records of this size, which will cause us to waste some
// space at the end of the page. This is necessary in order for us to take the end-of-record
@@ -251,14 +261,14 @@ public abstract class AbstractBytesToBytesMapSuite {
KEY_LENGTH
);
Assert.assertFalse(loc.isDefined());
- loc.putNewKey(
+ Assert.assertTrue(loc.putNewKey(
key,
LONG_ARRAY_OFFSET,
KEY_LENGTH,
value,
LONG_ARRAY_OFFSET,
VALUE_LENGTH
- );
+ ));
}
Assert.assertEquals(2, map.getNumDataPages());
@@ -305,7 +315,8 @@ public abstract class AbstractBytesToBytesMapSuite {
// Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
// into ByteBuffers in order to use them as keys here.
final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
- final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size, PAGE_SIZE_BYTES);
+ final BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES);
try {
// Fill the map to 90% full so that we can trigger probing
@@ -320,14 +331,14 @@ public abstract class AbstractBytesToBytesMapSuite {
key.length
);
Assert.assertFalse(loc.isDefined());
- loc.putNewKey(
+ Assert.assertTrue(loc.putNewKey(
key,
BYTE_ARRAY_OFFSET,
key.length,
value,
BYTE_ARRAY_OFFSET,
value.length
- );
+ ));
// After calling putNewKey, the following should be true, even before calling
// lookup():
Assert.assertTrue(loc.isDefined());
@@ -352,9 +363,101 @@ public abstract class AbstractBytesToBytesMapSuite {
}
@Test
+ public void randomizedTestWithRecordsLargerThanPageSize() {
+ final long pageSizeBytes = 128;
+ final BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes);
+ // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
+ // into ByteBuffers in order to use them as keys here.
+ final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
+ try {
+ for (int i = 0; i < 1000; i++) {
+ final byte[] key = getRandomByteArray(rand.nextInt(128));
+ final byte[] value = getRandomByteArray(rand.nextInt(128));
+ if (!expected.containsKey(ByteBuffer.wrap(key))) {
+ expected.put(ByteBuffer.wrap(key), value);
+ final BytesToBytesMap.Location loc = map.lookup(
+ key,
+ BYTE_ARRAY_OFFSET,
+ key.length
+ );
+ Assert.assertFalse(loc.isDefined());
+ Assert.assertTrue(loc.putNewKey(
+ key,
+ BYTE_ARRAY_OFFSET,
+ key.length,
+ value,
+ BYTE_ARRAY_OFFSET,
+ value.length
+ ));
+ // After calling putNewKey, the following should be true, even before calling
+ // lookup():
+ Assert.assertTrue(loc.isDefined());
+ Assert.assertEquals(key.length, loc.getKeyLength());
+ Assert.assertEquals(value.length, loc.getValueLength());
+ Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
+ Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+ }
+ }
+ for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
+ final byte[] key = entry.getKey().array();
+ final byte[] value = entry.getValue();
+ final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length);
+ Assert.assertTrue(loc.isDefined());
+ Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
+ Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+ }
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void failureToAllocateFirstPage() {
+ shuffleMemoryManager = new ShuffleMemoryManager(1024);
+ BytesToBytesMap map =
+ new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+ try {
+ final long[] emptyArray = new long[0];
+ final BytesToBytesMap.Location loc =
+ map.lookup(emptyArray, PlatformDependent.LONG_ARRAY_OFFSET, 0);
+ Assert.assertFalse(loc.isDefined());
+ Assert.assertFalse(loc.putNewKey(
+ emptyArray, LONG_ARRAY_OFFSET, 0,
+ emptyArray, LONG_ARRAY_OFFSET, 0
+ ));
+ } finally {
+ map.free();
+ }
+ }
+
+
+ @Test
+ public void failureToGrow() {
+ shuffleMemoryManager = new ShuffleMemoryManager(1024 * 10);
+ BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024);
+ try {
+ boolean success = true;
+ int i;
+ for (i = 0; i < 1024; i++) {
+ final long[] arr = new long[]{i};
+ final BytesToBytesMap.Location loc = map.lookup(arr, PlatformDependent.LONG_ARRAY_OFFSET, 8);
+ success = loc.putNewKey(arr, LONG_ARRAY_OFFSET, 8, arr, LONG_ARRAY_OFFSET, 8);
+ if (!success) {
+ break;
+ }
+ }
+ Assert.assertThat(i, greaterThan(0));
+ Assert.assertFalse(success);
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
public void initialCapacityBoundsChecking() {
try {
- new BytesToBytesMap(sizeLimitedMemoryManager, 0, PAGE_SIZE_BYTES);
+ new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES);
Assert.fail("Expected IllegalArgumentException to be thrown");
} catch (IllegalArgumentException e) {
// expected exception
@@ -362,23 +465,34 @@ public abstract class AbstractBytesToBytesMapSuite {
try {
new BytesToBytesMap(
- sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES);
+ sizeLimitedTaskMemoryManager,
+ shuffleMemoryManager,
+ BytesToBytesMap.MAX_CAPACITY + 1,
+ PAGE_SIZE_BYTES);
Assert.fail("Expected IllegalArgumentException to be thrown");
} catch (IllegalArgumentException e) {
// expected exception
}
- // Can allocate _at_ the max capacity
- BytesToBytesMap map =
- new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY, PAGE_SIZE_BYTES);
- map.free();
+ // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
+ // Can allocate _at_ the max capacity
+ // BytesToBytesMap map = new BytesToBytesMap(
+ // sizeLimitedTaskMemoryManager,
+ // shuffleMemoryManager,
+ // BytesToBytesMap.MAX_CAPACITY,
+ // PAGE_SIZE_BYTES);
+ // map.free();
}
- @Test
+ // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
+ @Ignore
public void resizingLargeMap() {
// As long as a map's capacity is below the max, we should be able to resize up to the max
BytesToBytesMap map = new BytesToBytesMap(
- sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64, PAGE_SIZE_BYTES);
+ sizeLimitedTaskMemoryManager,
+ shuffleMemoryManager,
+ BytesToBytesMap.MAX_CAPACITY - 64,
+ PAGE_SIZE_BYTES);
map.growAndRehash();
map.free();
}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
index 5a10de49f5..5a10de49f5 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
index 12cc9b25d9..12cc9b25d9 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index f3b462778d..66012e3c94 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -15,11 +15,15 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.expressions;
+package org.apache.spark.sql.execution;
+import java.io.IOException;
import java.util.Iterator;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.StructField;
@@ -87,7 +91,9 @@ public final class UnsafeFixedWidthAggregationMap {
* @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
* @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
* @param groupingKeySchema the schema of the grouping key, used for row conversion.
- * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
+ * @param taskMemoryManager the memory manager used to allocate our Unsafe memory structures.
+ * @param shuffleMemoryManager the shuffle memory manager, for coordinating our memory usage with
+ * other tasks.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
* @param pageSizeBytes the data page size, in bytes; limits the maximum record size.
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
@@ -96,15 +102,16 @@ public final class UnsafeFixedWidthAggregationMap {
InternalRow emptyAggregationBuffer,
StructType aggregationBufferSchema,
StructType groupingKeySchema,
- TaskMemoryManager memoryManager,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
int initialCapacity,
long pageSizeBytes,
boolean enablePerfMetrics) {
this.aggregationBufferSchema = aggregationBufferSchema;
this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
this.groupingKeySchema = groupingKeySchema;
- this.map =
- new BytesToBytesMap(memoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
+ this.map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, initialCapacity, pageSizeBytes, enablePerfMetrics);
this.enablePerfMetrics = enablePerfMetrics;
// Initialize the buffer for aggregation value
@@ -116,7 +123,8 @@ public final class UnsafeFixedWidthAggregationMap {
/**
* Return the aggregation buffer for the current group. For efficiency, all calls to this method
- * return the same object.
+ * return the same object. If additional memory could not be allocated, then this method will
+ * signal an error by returning null.
*/
public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
final UnsafeRow unsafeGroupingKeyRow = this.groupingKeyProjection.apply(groupingKey);
@@ -129,7 +137,7 @@ public final class UnsafeFixedWidthAggregationMap {
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
- loc.putNewKey(
+ boolean putSucceeded = loc.putNewKey(
unsafeGroupingKeyRow.getBaseObject(),
unsafeGroupingKeyRow.getBaseOffset(),
unsafeGroupingKeyRow.getSizeInBytes(),
@@ -137,6 +145,9 @@ public final class UnsafeFixedWidthAggregationMap {
PlatformDependent.BYTE_ARRAY_OFFSET,
emptyAggregationBuffer.length
);
+ if (!putSucceeded) {
+ return null;
+ }
}
// Reset the pointer to point to the value that we just stored or looked up:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index d851eae3fc..469de6ca8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.io.IOException
+
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
@@ -266,6 +268,7 @@ case class GeneratedAggregate(
aggregationBufferSchema,
groupKeySchema,
TaskContext.get.taskMemoryManager(),
+ SparkEnv.get.shuffleMemoryManager,
1024 * 16, // initial capacity
pageSizeBytes,
false // disable tracking of performance metrics
@@ -275,6 +278,9 @@ case class GeneratedAggregate(
val currentRow: InternalRow = iter.next()
val groupKey: InternalRow = groupProjection(currentRow)
val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
+ if (aggregationBuffer == null) {
+ throw new IOException("Could not allocate memory to grow aggregation buffer")
+ }
updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index f88a45f48a..cc8bbfd2f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.execution.joins
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{IOException, Externalizable, ObjectInput, ObjectOutput}
import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}
+import org.apache.spark.shuffle.ShuffleMemoryManager
import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -28,6 +29,7 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.CompactBuffer
@@ -217,7 +219,7 @@ private[joins] final class UnsafeHashedRelation(
}
}
- override def writeExternal(out: ObjectOutput): Unit = {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(hashTable.size())
val iter = hashTable.entrySet().iterator()
@@ -256,16 +258,26 @@ private[joins] final class UnsafeHashedRelation(
}
}
- override def readExternal(in: ObjectInput): Unit = {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val nKeys = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
- val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ val taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+
+ // Dummy shuffle memory manager which always grants all memory allocation requests.
+ // We use this because it doesn't make sense count shared broadcast variables' memory usage
+ // towards individual tasks' quotas. In the future, we should devise a better way of handling
+ // this.
+ val shuffleMemoryManager = new ShuffleMemoryManager(new SparkConf()) {
+ override def tryToAcquire(numBytes: Long): Long = numBytes
+ override def release(numBytes: Long): Unit = {}
+ }
val pageSizeBytes = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
.getSizeAsBytes("spark.buffer.pageSize", "64m")
binaryMap = new BytesToBytesMap(
- memoryManager,
+ taskMemoryManager,
+ shuffleMemoryManager,
nKeys * 2, // reduce hash collision
pageSizeBytes)
@@ -287,8 +299,11 @@ private[joins] final class UnsafeHashedRelation(
// put it into binary map
val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
assert(!loc.isDefined, "Duplicated key found!")
- loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
+ val putSuceeded = loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
+ if (!putSuceeded) {
+ throw new IOException("Could not allocate memory to grow BytesToBytesMap")
+ }
i += 1
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index c6b4c729de..79fd52dacd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -15,17 +15,18 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.expressions
+package org.apache.spark.sql.execution
+
+import org.scalatest.{BeforeAndAfterEach, Matchers}
import scala.collection.JavaConverters._
import scala.util.Random
-import org.scalatest.{BeforeAndAfterEach, Matchers}
-
import org.apache.spark.SparkFunSuite
+import org.apache.spark.shuffle.ShuffleMemoryManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
@@ -41,16 +42,20 @@ class UnsafeFixedWidthAggregationMapSuite
private def emptyAggregationBuffer: InternalRow = InternalRow(0)
private val PAGE_SIZE_BYTES: Long = 1L << 26; // 64 megabytes
- private var memoryManager: TaskMemoryManager = null
+ private var taskMemoryManager: TaskMemoryManager = null
+ private var shuffleMemoryManager: ShuffleMemoryManager = null
override def beforeEach(): Unit = {
- memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
+ shuffleMemoryManager = new ShuffleMemoryManager(Long.MaxValue)
}
override def afterEach(): Unit = {
- if (memoryManager != null) {
- memoryManager.cleanUpAllAllocatedMemory()
- memoryManager = null
+ if (taskMemoryManager != null) {
+ val leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask
+ assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0)
+ assert(leakedShuffleMemory === 0)
+ taskMemoryManager = null
}
}
@@ -69,7 +74,8 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
- memoryManager,
+ taskMemoryManager,
+ shuffleMemoryManager,
1024, // initial capacity,
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -83,7 +89,8 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
- memoryManager,
+ taskMemoryManager,
+ shuffleMemoryManager,
1024, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -91,7 +98,7 @@ class UnsafeFixedWidthAggregationMapSuite
val groupKey = InternalRow(UTF8String.fromString("cats"))
// Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
- map.getAggregationBuffer(groupKey)
+ assert(map.getAggregationBuffer(groupKey) != null)
val iter = map.iterator()
val entry = iter.next()
assert(!iter.hasNext)
@@ -110,7 +117,8 @@ class UnsafeFixedWidthAggregationMapSuite
emptyAggregationBuffer,
aggBufferSchema,
groupKeySchema,
- memoryManager,
+ taskMemoryManager,
+ shuffleMemoryManager,
128, // initial capacity
PAGE_SIZE_BYTES,
false // disable perf metrics
@@ -118,7 +126,7 @@ class UnsafeFixedWidthAggregationMapSuite
val rand = new Random(42)
val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet
groupKeys.foreach { keyString =>
- map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString)))
+ assert(map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) != null)
}
val seenKeys: Set[String] = map.iterator().asScala.map { entry =>
entry.key.getString(0)
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
index 3dc82d8c2e..91be46ba21 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -34,7 +34,7 @@ public class MemoryBlock extends MemoryLocation {
*/
int pageNumber = -1;
- MemoryBlock(@Nullable Object obj, long offset, long length) {
+ public MemoryBlock(@Nullable Object obj, long offset, long length) {
super(obj, offset);
this.length = length;
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index dd70df3b1f..358bb37250 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -164,6 +164,7 @@ public class TaskMemoryManager {
* top-level Javadoc for more details).
*/
public MemoryBlock allocate(long size) throws OutOfMemoryError {
+ assert(size > 0) : "Size must be positive, but got " + size;
final MemoryBlock memory = executorMemoryManager.allocate(size);
allocatedNonPageMemory.add(memory);
return memory;