aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--unsafe/pom.xml5
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java151
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java4
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java2
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java165
5 files changed, 274 insertions, 53 deletions
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 9e151fc7a9..2fd17267ac 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -65,6 +65,11 @@
<artifactId>junit-interface</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.mockito</groupId>
+ <artifactId>mockito-all</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 19d6a169fd..bd4ca74cc7 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -23,6 +23,8 @@ import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
+import com.google.common.annotations.VisibleForTesting;
+
import org.apache.spark.unsafe.*;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.array.LongArray;
@@ -36,9 +38,8 @@ import org.apache.spark.unsafe.memory.*;
* This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
* which is guaranteed to exhaust the space.
* <p>
- * The map can support up to 2^31 keys because we use 32 bit MurmurHash. If the key cardinality is
- * higher than this, you should probably be using sorting instead of hashing for better cache
- * locality.
+ * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should
+ * probably be using sorting instead of hashing for better cache locality.
* <p>
* This class is not thread safe.
*/
@@ -48,6 +49,11 @@ public final class BytesToBytesMap {
private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
+ /**
+ * Special record length that is placed after the last record in a data page.
+ */
+ private static final int END_OF_PAGE_MARKER = -1;
+
private final TaskMemoryManager memoryManager;
/**
@@ -64,7 +70,7 @@ public final class BytesToBytesMap {
/**
* Offset into `currentDataPage` that points to the location where new data can be inserted into
- * the page.
+ * the page. This does not incorporate the page's base offset.
*/
private long pageCursor = 0;
@@ -74,6 +80,15 @@ public final class BytesToBytesMap {
*/
private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
+ /**
+ * The maximum number of keys that BytesToBytesMap supports. The hash table has to be
+ * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since
+ * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array
+ * entries per key, giving us a maximum capacity of (1 << 29).
+ */
+ @VisibleForTesting
+ static final int MAX_CAPACITY = (1 << 29);
+
// This choice of page table size and page size means that we can address up to 500 gigabytes
// of memory.
@@ -143,6 +158,13 @@ public final class BytesToBytesMap {
this.loadFactor = loadFactor;
this.loc = new Location();
this.enablePerfMetrics = enablePerfMetrics;
+ if (initialCapacity <= 0) {
+ throw new IllegalArgumentException("Initial capacity must be greater than 0");
+ }
+ if (initialCapacity > MAX_CAPACITY) {
+ throw new IllegalArgumentException(
+ "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY);
+ }
allocate(initialCapacity);
}
@@ -162,6 +184,55 @@ public final class BytesToBytesMap {
*/
public int size() { return size; }
+ private static final class BytesToBytesMapIterator implements Iterator<Location> {
+
+ private final int numRecords;
+ private final Iterator<MemoryBlock> dataPagesIterator;
+ private final Location loc;
+
+ private int currentRecordNumber = 0;
+ private Object pageBaseObject;
+ private long offsetInPage;
+
+ BytesToBytesMapIterator(int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc) {
+ this.numRecords = numRecords;
+ this.dataPagesIterator = dataPagesIterator;
+ this.loc = loc;
+ if (dataPagesIterator.hasNext()) {
+ advanceToNextPage();
+ }
+ }
+
+ private void advanceToNextPage() {
+ final MemoryBlock currentPage = dataPagesIterator.next();
+ pageBaseObject = currentPage.getBaseObject();
+ offsetInPage = currentPage.getBaseOffset();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return currentRecordNumber != numRecords;
+ }
+
+ @Override
+ public Location next() {
+ int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+ if (keyLength == END_OF_PAGE_MARKER) {
+ advanceToNextPage();
+ keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+ }
+ loc.with(pageBaseObject, offsetInPage);
+ offsetInPage += 8 + 8 + keyLength + loc.getValueLength();
+ currentRecordNumber++;
+ return loc;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
/**
* Returns an iterator for iterating over the entries of this map.
*
@@ -171,27 +242,7 @@ public final class BytesToBytesMap {
* `lookup()`, the behavior of the returned iterator is undefined.
*/
public Iterator<Location> iterator() {
- return new Iterator<Location>() {
-
- private int nextPos = bitset.nextSetBit(0);
-
- @Override
- public boolean hasNext() {
- return nextPos != -1;
- }
-
- @Override
- public Location next() {
- final int pos = nextPos;
- nextPos = bitset.nextSetBit(nextPos + 1);
- return loc.with(pos, 0, true);
- }
-
- @Override
- public void remove() {
- throw new UnsupportedOperationException();
- }
- };
+ return new BytesToBytesMapIterator(size, dataPages.iterator(), loc);
}
/**
@@ -268,8 +319,11 @@ public final class BytesToBytesMap {
private int valueLength;
private void updateAddressesAndSizes(long fullKeyAddress) {
- final Object page = memoryManager.getPage(fullKeyAddress);
- final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress);
+ updateAddressesAndSizes(
+ memoryManager.getPage(fullKeyAddress), memoryManager.getOffsetInPage(fullKeyAddress));
+ }
+
+ private void updateAddressesAndSizes(Object page, long keyOffsetInPage) {
long position = keyOffsetInPage;
keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
position += 8; // word used to store the key size
@@ -291,6 +345,12 @@ public final class BytesToBytesMap {
return this;
}
+ Location with(Object page, long keyOffsetInPage) {
+ this.isDefined = true;
+ updateAddressesAndSizes(page, keyOffsetInPage);
+ return this;
+ }
+
/**
* Returns true if the key is defined at this position, and false otherwise.
*/
@@ -345,6 +405,8 @@ public final class BytesToBytesMap {
* <p>
* It is only valid to call this method immediately after calling `lookup()` using the same key.
* <p>
+ * The key and value must be word-aligned (that is, their sizes must be multiples of 8).
+ * <p>
* After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
* will return information on the data stored by this `putNewKey` call.
* <p>
@@ -370,17 +432,27 @@ public final class BytesToBytesMap {
isDefined = true;
assert (keyLengthBytes % 8 == 0);
assert (valueLengthBytes % 8 == 0);
+ if (size == MAX_CAPACITY) {
+ throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
+ }
// Here, we'll copy the data into our data pages. Because we only store a relative offset from
// the key address instead of storing the absolute address of the value, the key and value
// must be stored in the same memory page.
// (8 byte key length) (key) (8 byte value length) (value)
final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
- assert(requiredSize <= PAGE_SIZE_BYTES);
+ assert (requiredSize <= PAGE_SIZE_BYTES - 8); // Reserve 8 bytes for the end-of-page marker.
size++;
bitset.set(pos);
- // If there's not enough space in the current page, allocate a new page:
- if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) {
+ // If there's not enough space in the current page, allocate a new page (8 bytes are reserved
+ // for the end-of-page marker).
+ if (currentDataPage == null || PAGE_SIZE_BYTES - 8 - pageCursor < requiredSize) {
+ if (currentDataPage != null) {
+ // There wasn't enough space in the current page, so write an end-of-page marker:
+ final Object pageBaseObject = currentDataPage.getBaseObject();
+ final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
+ PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
+ }
MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES);
dataPages.add(newPage);
pageCursor = 0;
@@ -414,7 +486,7 @@ public final class BytesToBytesMap {
longArray.set(pos * 2 + 1, keyHashcode);
updateAddressesAndSizes(storedKeyAddress);
isDefined = true;
- if (size > growthThreshold) {
+ if (size > growthThreshold && longArray.size() < MAX_CAPACITY) {
growAndRehash();
}
}
@@ -427,8 +499,11 @@ public final class BytesToBytesMap {
* @param capacity the new map capacity
*/
private void allocate(int capacity) {
- capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64);
- longArray = new LongArray(memoryManager.allocate(capacity * 8 * 2));
+ assert (capacity >= 0);
+ // The capacity needs to be divisible by 64 so that our bit set can be sized properly
+ capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64);
+ assert (capacity <= MAX_CAPACITY);
+ longArray = new LongArray(memoryManager.allocate(capacity * 8L * 2));
bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
this.growthThreshold = (int) (capacity * loadFactor);
@@ -494,10 +569,16 @@ public final class BytesToBytesMap {
return numHashCollisions;
}
+ @VisibleForTesting
+ int getNumDataPages() {
+ return dataPages.size();
+ }
+
/**
* Grows the size of the hash table and re-hash everything.
*/
- private void growAndRehash() {
+ @VisibleForTesting
+ void growAndRehash() {
long resizeStartTime = -1;
if (enablePerfMetrics) {
resizeStartTime = System.nanoTime();
@@ -508,7 +589,7 @@ public final class BytesToBytesMap {
final int oldCapacity = (int) oldBitSet.capacity();
// Allocate the new data structures
- allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity)));
+ allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
// Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
index 7c321baffe..20654e4eea 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
@@ -32,7 +32,9 @@ public interface HashMapGrowthStrategy {
class Doubling implements HashMapGrowthStrategy {
@Override
public int nextCapacity(int currentCapacity) {
- return currentCapacity * 2;
+ assert (currentCapacity > 0);
+ // Guard against overflow
+ return (currentCapacity * 2 > 0) ? (currentCapacity * 2) : Integer.MAX_VALUE;
}
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 2906ac8aba..10881969db 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -44,7 +44,7 @@ import org.slf4j.LoggerFactory;
* maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is
* approximately 35 terabytes of memory.
*/
-public final class TaskMemoryManager {
+public class TaskMemoryManager {
private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 7a5c0622d1..81315f7c94 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -25,24 +25,40 @@ import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.mockito.AdditionalMatchers.geq;
+import static org.mockito.Mockito.*;
import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.memory.*;
import org.apache.spark.unsafe.PlatformDependent;
import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.MemoryLocation;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET;
+
public abstract class AbstractBytesToBytesMapSuite {
private final Random rand = new Random(42);
private TaskMemoryManager memoryManager;
+ private TaskMemoryManager sizeLimitedMemoryManager;
@Before
public void setup() {
memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
+ // Mocked memory manager for tests that check the maximum array size, since actually allocating
+ // such large arrays will cause us to run out of memory in our tests.
+ sizeLimitedMemoryManager = spy(memoryManager);
+ when(sizeLimitedMemoryManager.allocate(geq(1L << 20))).thenAnswer(new Answer<MemoryBlock>() {
+ @Override
+ public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
+ if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
+ throw new OutOfMemoryError("Requested array size exceeds VM limit");
+ }
+ return memoryManager.allocate(1L << 20);
+ }
+ });
}
@After
@@ -101,6 +117,7 @@ public abstract class AbstractBytesToBytesMapSuite {
final int keyLengthInBytes = keyLengthInWords * 8;
final byte[] key = getRandomByteArray(keyLengthInWords);
Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined());
+ Assert.assertFalse(map.iterator().hasNext());
} finally {
map.free();
}
@@ -159,7 +176,7 @@ public abstract class AbstractBytesToBytesMapSuite {
@Test
public void iteratorTest() throws Exception {
- final int size = 128;
+ final int size = 4096;
BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2);
try {
for (long i = 0; i < size; i++) {
@@ -167,14 +184,26 @@ public abstract class AbstractBytesToBytesMapSuite {
final BytesToBytesMap.Location loc =
map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8);
Assert.assertFalse(loc.isDefined());
- loc.putNewKey(
- value,
- PlatformDependent.LONG_ARRAY_OFFSET,
- 8,
- value,
- PlatformDependent.LONG_ARRAY_OFFSET,
- 8
- );
+ // Ensure that we store some zero-length keys
+ if (i % 5 == 0) {
+ loc.putNewKey(
+ null,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 0,
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8
+ );
+ } else {
+ loc.putNewKey(
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8,
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8
+ );
+ }
}
final java.util.BitSet valuesSeen = new java.util.BitSet(size);
final Iterator<BytesToBytesMap.Location> iter = map.iterator();
@@ -183,11 +212,16 @@ public abstract class AbstractBytesToBytesMapSuite {
Assert.assertTrue(loc.isDefined());
final MemoryLocation keyAddress = loc.getKeyAddress();
final MemoryLocation valueAddress = loc.getValueAddress();
- final long key = PlatformDependent.UNSAFE.getLong(
- keyAddress.getBaseObject(), keyAddress.getBaseOffset());
final long value = PlatformDependent.UNSAFE.getLong(
valueAddress.getBaseObject(), valueAddress.getBaseOffset());
- Assert.assertEquals(key, value);
+ final long keyLength = loc.getKeyLength();
+ if (keyLength == 0) {
+ Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
+ } else {
+ final long key = PlatformDependent.UNSAFE.getLong(
+ keyAddress.getBaseObject(), keyAddress.getBaseOffset());
+ Assert.assertEquals(value, key);
+ }
valuesSeen.set((int) value);
}
Assert.assertEquals(size, valuesSeen.cardinality());
@@ -197,6 +231,74 @@ public abstract class AbstractBytesToBytesMapSuite {
}
@Test
+ public void iteratingOverDataPagesWithWastedSpace() throws Exception {
+ final int NUM_ENTRIES = 1000 * 1000;
+ final int KEY_LENGTH = 16;
+ final int VALUE_LENGTH = 40;
+ final BytesToBytesMap map = new BytesToBytesMap(memoryManager, NUM_ENTRIES);
+ // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte
+ // pages won't be evenly-divisible by records of this size, which will cause us to waste some
+ // space at the end of the page. This is necessary in order for us to take the end-of-page
+ // handling branch in iterator().
+ try {
+ for (int i = 0; i < NUM_ENTRIES; i++) {
+ final long[] key = new long[] { i, i }; // 2 * 8 = 16 bytes
+ final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes
+ final BytesToBytesMap.Location loc = map.lookup(
+ key,
+ LONG_ARRAY_OFFSET,
+ KEY_LENGTH
+ );
+ Assert.assertFalse(loc.isDefined());
+ loc.putNewKey(
+ key,
+ LONG_ARRAY_OFFSET,
+ KEY_LENGTH,
+ value,
+ LONG_ARRAY_OFFSET,
+ VALUE_LENGTH
+ );
+ }
+ Assert.assertEquals(2, map.getNumDataPages());
+
+ final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES);
+ final Iterator<BytesToBytesMap.Location> iter = map.iterator();
+ final long key[] = new long[KEY_LENGTH / 8];
+ final long value[] = new long[VALUE_LENGTH / 8];
+ while (iter.hasNext()) {
+ final BytesToBytesMap.Location loc = iter.next();
+ Assert.assertTrue(loc.isDefined());
+ Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
+ Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
+ PlatformDependent.copyMemory(
+ loc.getKeyAddress().getBaseObject(),
+ loc.getKeyAddress().getBaseOffset(),
+ key,
+ LONG_ARRAY_OFFSET,
+ KEY_LENGTH
+ );
+ PlatformDependent.copyMemory(
+ loc.getValueAddress().getBaseObject(),
+ loc.getValueAddress().getBaseOffset(),
+ value,
+ LONG_ARRAY_OFFSET,
+ VALUE_LENGTH
+ );
+ for (long j : key) {
+ Assert.assertEquals(key[0], j);
+ }
+ for (long j : value) {
+ Assert.assertEquals(key[0], j);
+ }
+ valuesSeen.set((int) key[0]);
+ }
+ Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality());
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
public void randomizedStressTest() {
final int size = 65536;
// Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
@@ -247,4 +349,35 @@ public abstract class AbstractBytesToBytesMapSuite {
map.free();
}
}
+
+ @Test
+ public void initialCapacityBoundsChecking() {
+ try {
+ new BytesToBytesMap(sizeLimitedMemoryManager, 0);
+ Assert.fail("Expected IllegalArgumentException to be thrown");
+ } catch (IllegalArgumentException e) {
+ // expected exception
+ }
+
+ try {
+ new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY + 1);
+ Assert.fail("Expected IllegalArgumentException to be thrown");
+ } catch (IllegalArgumentException e) {
+ // expected exception
+ }
+
+ // Can allocate _at_ the max capacity
+ BytesToBytesMap map =
+ new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY);
+ map.free();
+ }
+
+ @Test
+ public void resizingLargeMap() {
+ // As long as a map's capacity is below the max, we should be able to resize up to the max
+ BytesToBytesMap map =
+ new BytesToBytesMap(sizeLimitedMemoryManager, BytesToBytesMap.MAX_CAPACITY - 64);
+ map.growAndRehash();
+ map.free();
+ }
}