author    Josh Rosen <joshrosen@databricks.com>    2015-07-31 19:19:27 -0700
committer Josh Rosen <joshrosen@databricks.com>    2015-07-31 19:19:27 -0700
commit    8cb415a4b9bc1f82127ccce4a5579d433f4e8f83 (patch)
tree      f9071996c6485e77700463106ae86c1a346508c9 /core
parent    f51fd6fbb4d9822502f98b312251e317d757bc3a (diff)
[SPARK-9451] [SQL] Support entries larger than default page size in BytesToBytesMap & integrate with ShuffleMemoryManager
This patch adds support for entries larger than the default page size in BytesToBytesMap. These large rows are handled by allocating special overflow pages to hold individual entries.

In addition, this patch integrates BytesToBytesMap with the ShuffleMemoryManager:

- Move BytesToBytesMap from `unsafe` to `core` so that it can import `ShuffleMemoryManager`.
- Before allocating new data pages, ask the ShuffleMemoryManager to reserve the memory:
  - `putNewKey()` now returns a boolean to indicate whether the insert succeeded or failed due to a lack of memory. The caller can use this value to respond to the memory pressure (e.g. by spilling).
  - `UnsafeFixedWidthAggregationMap.getAggregationBuffer()` now returns `null` to signal failure due to a lack of memory.
- Updated all uses of these classes to handle these error conditions.
- Added new tests for allocating large records and for allocations which fail due to memory pressure.
- Extended the `afterAll()` test teardown methods to detect ShuffleMemoryManager leaks.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7762 from JoshRosen/large-rows and squashes the following commits:

ae7bc56 [Josh Rosen] Fix compilation
82fc657 [Josh Rosen] Merge remote-tracking branch 'origin/master' into large-rows
34ab943 [Josh Rosen] Remove semi
31a525a [Josh Rosen] Integrate BytesToBytesMap with ShuffleMemoryManager.
626b33c [Josh Rosen] Move code to sql/core and spark/core packages so that ShuffleMemoryManager can be integrated
ec4484c [Josh Rosen] Move BytesToBytesMap from unsafe package to core.
642ed69 [Josh Rosen] Rename size to numElements
bea1152 [Josh Rosen] Add basic test.
2cd3570 [Josh Rosen] Remove accidental duplicated code
07ff9ef [Josh Rosen] Basic support for large rows in BytesToBytesMap.
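As an illustration of the new contract (not part of the patch itself), a minimal caller-side sketch in Java; the map, key/value variables, and the spill() helper are hypothetical placeholders:

    // Hedged sketch: how a caller might react to putNewKey() returning false.
    BytesToBytesMap.Location loc = map.lookup(keyBase, keyOffset, keyLength);
    if (!loc.isDefined()) {
      boolean inserted = loc.putNewKey(
        keyBase, keyOffset, keyLength,
        valueBase, valueOffset, valueLength);
      if (!inserted) {
        // Could not reserve enough shuffle memory for a new data page;
        // respond to memory pressure, e.g. by spilling, then retry.
        spill();
      }
    }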
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java            709
-rw-r--r-- core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java       41
-rw-r--r-- core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala          8
-rw-r--r-- core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java 499
-rw-r--r-- core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java  29
-rw-r--r-- core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java   29
6 files changed, 1314 insertions, 1 deletion
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
new file mode 100644
index 0000000000..0f42950e6e
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -0,0 +1,709 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import java.io.IOException;
+import java.lang.Override;
+import java.lang.UnsupportedOperationException;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.unsafe.*;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.bitset.BitSet;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.memory.*;
+
+/**
+ * An append-only hash map where keys and values are contiguous regions of bytes.
+ * <p>
+ * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
+ * which is guaranteed to exhaust the space.
+ * <p>
+ * The map can support up to 2^29 keys. If the key cardinality is higher than this, you should
+ * probably be using sorting instead of hashing for better cache locality.
+ * <p>
+ * This class is not thread safe.
+ */
+public final class BytesToBytesMap {
+
+ private final Logger logger = LoggerFactory.getLogger(BytesToBytesMap.class);
+
+ private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);
+
+ private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;
+
+ /**
+ * Special record length that is placed after the last record in a data page.
+ */
+ private static final int END_OF_PAGE_MARKER = -1;
+
+ private final TaskMemoryManager taskMemoryManager;
+
+ private final ShuffleMemoryManager shuffleMemoryManager;
+
+ /**
+ * A linked list for tracking all allocated data pages so that we can free all of our memory.
+ */
+ private final List<MemoryBlock> dataPages = new LinkedList<MemoryBlock>();
+
+ /**
+ * The data page that will be used to store keys and values for new hashtable entries. When this
+ * page becomes full, a new page will be allocated and this pointer will change to point to that
+ * new page.
+ */
+ private MemoryBlock currentDataPage = null;
+
+ /**
+ * Offset into `currentDataPage` that points to the location where new data can be inserted into
+ * the page. This does not incorporate the page's base offset.
+ */
+ private long pageCursor = 0;
+
+ /**
+ * The maximum number of keys that BytesToBytesMap supports. The hash table has to be
+ * power-of-2-sized and its backing Java array can contain at most (1 << 30) elements, since
+ * that's the largest power-of-2 that's less than Integer.MAX_VALUE. We need two long array
+ * entries per key, giving us a maximum capacity of (1 << 29).
+ */
+ @VisibleForTesting
+ static final int MAX_CAPACITY = (1 << 29);
+
+ // This choice of page table size and page size means that we can address up to 500 gigabytes
+ // of memory.
+
+ /**
+ * A single array to store the key address and the key's hashcode.
+ *
+ * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i},
+ * while position {@code 2 * i + 1} in the array holds the key's full 32-bit hashcode.
+ */
+ private LongArray longArray;
+ // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode
+ // and exploit word-alignment to use fewer bits to hold the address. This might let us store
+ // only one long per map entry, increasing the chance that this array will fit in cache at the
+ // expense of maybe performing more lookups if we have hash collisions. Say that we stored only
+ // 27 bits of the hashcode and 37 bits of the address. 37 bits is enough to address 1 terabyte
+ // of RAM given word-alignment. If we use 13 bits of this for our page table, that gives us a
+ // maximum page size of 2^24 * 8 = ~134 megabytes per page. This change will require us to store
+ // full base addresses in the page table for off-heap mode so that we can reconstruct the full
+ // absolute memory addresses.
+
+ /**
+ * A {@link BitSet} used to track the positions in the map where a key is defined.
+ * The size of the bitset should be half of the size of the long array.
+ */
+ private BitSet bitset;
+
+ private final double loadFactor;
+
+ /**
+ * The size of the data pages that hold key and value data. Entries never span multiple pages;
+ * records larger than this are stored in dedicated overflow pages.
+ */
+ private final long pageSizeBytes;
+
+ /**
+ * Number of keys defined in the map.
+ */
+ private int numElements;
+
+ /**
+ * The map will be expanded once the number of keys exceeds this threshold.
+ */
+ private int growthThreshold;
+
+ /**
+ * Mask for truncating hashcodes so that they do not exceed the long array's size.
+ * This is a strength reduction optimization; we're essentially performing a modulus operation,
+ * but doing so with a bitmask because this is a power-of-2-sized hash map.
+ */
+ private int mask;
+
+ /**
+ * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}.
+ */
+ private final Location loc;
+
+ private final boolean enablePerfMetrics;
+
+ private long timeSpentResizingNs = 0;
+
+ private long numProbes = 0;
+
+ private long numKeyLookups = 0;
+
+ private long numHashCollisions = 0;
+
+ public BytesToBytesMap(
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ int initialCapacity,
+ double loadFactor,
+ long pageSizeBytes,
+ boolean enablePerfMetrics) {
+ this.taskMemoryManager = taskMemoryManager;
+ this.shuffleMemoryManager = shuffleMemoryManager;
+ this.loadFactor = loadFactor;
+ this.loc = new Location();
+ this.pageSizeBytes = pageSizeBytes;
+ this.enablePerfMetrics = enablePerfMetrics;
+ if (initialCapacity <= 0) {
+ throw new IllegalArgumentException("Initial capacity must be greater than 0");
+ }
+ if (initialCapacity > MAX_CAPACITY) {
+ throw new IllegalArgumentException(
+ "Initial capacity " + initialCapacity + " exceeds maximum capacity of " + MAX_CAPACITY);
+ }
+ if (pageSizeBytes > TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES) {
+ throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " +
+ TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES);
+ }
+ allocate(initialCapacity);
+ }
+
+ public BytesToBytesMap(
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ int initialCapacity,
+ long pageSizeBytes) {
+ this(taskMemoryManager, shuffleMemoryManager, initialCapacity, 0.70, pageSizeBytes, false);
+ }
+
+ public BytesToBytesMap(
+ TaskMemoryManager taskMemoryManager,
+ ShuffleMemoryManager shuffleMemoryManager,
+ int initialCapacity,
+ long pageSizeBytes,
+ boolean enablePerfMetrics) {
+ this(
+ taskMemoryManager,
+ shuffleMemoryManager,
+ initialCapacity,
+ 0.70,
+ pageSizeBytes,
+ enablePerfMetrics);
+ }
+
+ /**
+ * Returns the number of keys defined in the map.
+ */
+ public int numElements() { return numElements; }
+
+ private static final class BytesToBytesMapIterator implements Iterator<Location> {
+
+ private final int numRecords;
+ private final Iterator<MemoryBlock> dataPagesIterator;
+ private final Location loc;
+
+ private int currentRecordNumber = 0;
+ private Object pageBaseObject;
+ private long offsetInPage;
+
+ BytesToBytesMapIterator(int numRecords, Iterator<MemoryBlock> dataPagesIterator, Location loc) {
+ this.numRecords = numRecords;
+ this.dataPagesIterator = dataPagesIterator;
+ this.loc = loc;
+ if (dataPagesIterator.hasNext()) {
+ advanceToNextPage();
+ }
+ }
+
+ private void advanceToNextPage() {
+ final MemoryBlock currentPage = dataPagesIterator.next();
+ pageBaseObject = currentPage.getBaseObject();
+ offsetInPage = currentPage.getBaseOffset();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return currentRecordNumber != numRecords;
+ }
+
+ @Override
+ public Location next() {
+ int keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+ if (keyLength == END_OF_PAGE_MARKER) {
+ advanceToNextPage();
+ keyLength = (int) PlatformDependent.UNSAFE.getLong(pageBaseObject, offsetInPage);
+ }
+ loc.with(pageBaseObject, offsetInPage);
+ offsetInPage += 8 + 8 + keyLength + loc.getValueLength();
+ currentRecordNumber++;
+ return loc;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ /**
+ * Returns an iterator for iterating over the entries of this map.
+ *
+ * For efficiency, all calls to `next()` will return the same {@link Location} object.
+ *
+ * If any other lookups or operations are performed on this map while iterating over it, including
+ * `lookup()`, the behavior of the returned iterator is undefined.
+ */
+ public Iterator<Location> iterator() {
+ return new BytesToBytesMapIterator(numElements, dataPages.iterator(), loc);
+ }
+
+ /**
+ * Looks up a key, and returns a {@link Location} handle that can be used to test existence
+ * and read/write values.
+ *
+ * This function always returns the same {@link Location} instance to avoid object allocation.
+ */
+ public Location lookup(
+ Object keyBaseObject,
+ long keyBaseOffset,
+ int keyRowLengthBytes) {
+ if (enablePerfMetrics) {
+ numKeyLookups++;
+ }
+ final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes);
+ int pos = hashcode & mask;
+ int step = 1;
+ while (true) {
+ if (enablePerfMetrics) {
+ numProbes++;
+ }
+ if (!bitset.isSet(pos)) {
+ // This is a new key.
+ return loc.with(pos, hashcode, false);
+ } else {
+ long stored = longArray.get(pos * 2 + 1);
+ if ((int) (stored) == hashcode) {
+ // Full hash code matches. Let's compare the keys for equality.
+ loc.with(pos, hashcode, true);
+ if (loc.getKeyLength() == keyRowLengthBytes) {
+ final MemoryLocation keyAddress = loc.getKeyAddress();
+ final Object storedKeyBaseObject = keyAddress.getBaseObject();
+ final long storedKeyBaseOffset = keyAddress.getBaseOffset();
+ final boolean areEqual = ByteArrayMethods.arrayEquals(
+ keyBaseObject,
+ keyBaseOffset,
+ storedKeyBaseObject,
+ storedKeyBaseOffset,
+ keyRowLengthBytes
+ );
+ if (areEqual) {
+ return loc;
+ } else {
+ if (enablePerfMetrics) {
+ numHashCollisions++;
+ }
+ }
+ }
+ }
+ }
+ pos = (pos + step) & mask;
+ step++;
+ }
+ }
+
+ /**
+ * Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function.
+ */
+ public final class Location {
+ /** An index into the hash map's Long array */
+ private int pos;
+ /** True if this location points to a position where a key is defined, false otherwise */
+ private boolean isDefined;
+ /**
+ * The hashcode of the most recent key passed to
+ * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to
+ * avoid re-hashing the key when storing a value for that key.
+ */
+ private int keyHashcode;
+ private final MemoryLocation keyMemoryLocation = new MemoryLocation();
+ private final MemoryLocation valueMemoryLocation = new MemoryLocation();
+ private int keyLength;
+ private int valueLength;
+
+ private void updateAddressesAndSizes(long fullKeyAddress) {
+ updateAddressesAndSizes(
+ taskMemoryManager.getPage(fullKeyAddress),
+ taskMemoryManager.getOffsetInPage(fullKeyAddress));
+ }
+
+ private void updateAddressesAndSizes(Object page, long keyOffsetInPage) {
+ long position = keyOffsetInPage;
+ keyLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
+ position += 8; // word used to store the key size
+ keyMemoryLocation.setObjAndOffset(page, position);
+ position += keyLength;
+ valueLength = (int) PlatformDependent.UNSAFE.getLong(page, position);
+ position += 8; // word used to store the value size
+ valueMemoryLocation.setObjAndOffset(page, position);
+ }
+
+ Location with(int pos, int keyHashcode, boolean isDefined) {
+ this.pos = pos;
+ this.isDefined = isDefined;
+ this.keyHashcode = keyHashcode;
+ if (isDefined) {
+ final long fullKeyAddress = longArray.get(pos * 2);
+ updateAddressesAndSizes(fullKeyAddress);
+ }
+ return this;
+ }
+
+ Location with(Object page, long keyOffsetInPage) {
+ this.isDefined = true;
+ updateAddressesAndSizes(page, keyOffsetInPage);
+ return this;
+ }
+
+ /**
+ * Returns true if the key is defined at this position, and false otherwise.
+ */
+ public boolean isDefined() {
+ return isDefined;
+ }
+
+ /**
+ * Returns the address of the key defined at this position.
+ * This points to the first byte of the key data.
+ * Unspecified behavior if the key is not defined.
+ * For efficiency reasons, calls to this method always return the same MemoryLocation object.
+ */
+ public MemoryLocation getKeyAddress() {
+ assert (isDefined);
+ return keyMemoryLocation;
+ }
+
+ /**
+ * Returns the length of the key defined at this position.
+ * Unspecified behavior if the key is not defined.
+ */
+ public int getKeyLength() {
+ assert (isDefined);
+ return keyLength;
+ }
+
+ /**
+ * Returns the address of the value defined at this position.
+ * This points to the first byte of the value data.
+ * Unspecified behavior if the key is not defined.
+ * For efficiency reasons, calls to this method always return the same MemoryLocation object.
+ */
+ public MemoryLocation getValueAddress() {
+ assert (isDefined);
+ return valueMemoryLocation;
+ }
+
+ /**
+ * Returns the length of the value defined at this position.
+ * Unspecified behavior if the key is not defined.
+ */
+ public int getValueLength() {
+ assert (isDefined);
+ return valueLength;
+ }
+
+ /**
+ * Store a new key and value. This method may only be called once for a given key; if you want
+ * to update the value associated with a key, then you can directly manipulate the bytes stored
+ * at the value address. The return value indicates whether the put succeeded or whether it
+ * failed because additional memory could not be acquired.
+ * <p>
+ * It is only valid to call this method immediately after calling `lookup()` using the same key.
+ * </p>
+ * <p>
+ * The key and value must be word-aligned (that is, their sizes must be multiples of 8).
+ * </p>
+ * <p>
+ * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
+ * will return information on the data stored by this `putNewKey` call.
+ * </p>
+ * <p>
+ * As an example usage, here's the proper way to store a new key:
+ * </p>
+ * <pre>
+ * Location loc = map.lookup(keyBaseObject, keyBaseOffset, keyLengthInBytes);
+ * if (!loc.isDefined()) {
+ * if (!loc.putNewKey(keyBaseObject, keyBaseOffset, keyLengthInBytes, ...)) {
+ * // handle failure to grow map (by spilling, for example)
+ * }
+ * }
+ * </pre>
+ * <p>
+ * Unspecified behavior if the key is not defined.
+ * </p>
+ *
+ * @return true if the put() was successful and false if the put() failed because memory could
+ * not be acquired.
+ */
+ public boolean putNewKey(
+ Object keyBaseObject,
+ long keyBaseOffset,
+ int keyLengthBytes,
+ Object valueBaseObject,
+ long valueBaseOffset,
+ int valueLengthBytes) {
+ assert (!isDefined) : "Can only set value once for a key";
+ assert (keyLengthBytes % 8 == 0);
+ assert (valueLengthBytes % 8 == 0);
+ if (numElements == MAX_CAPACITY) {
+ throw new IllegalStateException("BytesToBytesMap has reached maximum capacity");
+ }
+
+ // Here, we'll copy the data into our data pages. Because we only store a relative offset from
+ // the key address instead of storing the absolute address of the value, the key and value
+ // must be stored in the same memory page.
+ // (8 byte key length) (key) (8 byte value length) (value)
+ final long requiredSize = 8 + keyLengthBytes + 8 + valueLengthBytes;
+
+ // --- Figure out where to insert the new record ---------------------------------------------
+
+ final MemoryBlock dataPage;
+ final Object dataPageBaseObject;
+ final long dataPageInsertOffset;
+ boolean useOverflowPage = requiredSize > pageSizeBytes - 8;
+ if (useOverflowPage) {
+ // The record is larger than the page size, so allocate a special overflow page just to hold
+ // that record.
+ final long memoryRequested = requiredSize + 8;
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(memoryRequested);
+ if (memoryGranted != memoryRequested) {
+ shuffleMemoryManager.release(memoryGranted);
+ logger.debug("Failed to acquire {} bytes of memory", memoryRequested);
+ return false;
+ }
+ MemoryBlock overflowPage = taskMemoryManager.allocatePage(memoryRequested);
+ dataPages.add(overflowPage);
+ dataPage = overflowPage;
+ dataPageBaseObject = overflowPage.getBaseObject();
+ dataPageInsertOffset = overflowPage.getBaseOffset();
+ } else if (currentDataPage == null || pageSizeBytes - 8 - pageCursor < requiredSize) {
+ // The record can fit in a data page, but either we have not allocated any pages yet or
+ // the current page does not have enough space.
+ if (currentDataPage != null) {
+ // There wasn't enough space in the current page, so write an end-of-page marker:
+ final Object pageBaseObject = currentDataPage.getBaseObject();
+ final long lengthOffsetInPage = currentDataPage.getBaseOffset() + pageCursor;
+ PlatformDependent.UNSAFE.putLong(pageBaseObject, lengthOffsetInPage, END_OF_PAGE_MARKER);
+ }
+ final long memoryGranted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryGranted != pageSizeBytes) {
+ shuffleMemoryManager.release(memoryGranted);
+ logger.debug("Failed to acquire {} bytes of memory", pageSizeBytes);
+ return false;
+ }
+ MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ dataPages.add(newPage);
+ pageCursor = 0;
+ currentDataPage = newPage;
+ dataPage = currentDataPage;
+ dataPageBaseObject = currentDataPage.getBaseObject();
+ dataPageInsertOffset = currentDataPage.getBaseOffset();
+ } else {
+ // There is enough space in the current data page.
+ dataPage = currentDataPage;
+ dataPageBaseObject = currentDataPage.getBaseObject();
+ dataPageInsertOffset = currentDataPage.getBaseOffset() + pageCursor;
+ }
+
+ // --- Append the key and value data to the current data page --------------------------------
+
+ long insertCursor = dataPageInsertOffset;
+
+ // Compute all of our offsets up-front:
+ final long keySizeOffsetInPage = insertCursor;
+ insertCursor += 8; // word used to store the key size
+ final long keyDataOffsetInPage = insertCursor;
+ insertCursor += keyLengthBytes;
+ final long valueSizeOffsetInPage = insertCursor;
+ insertCursor += 8; // word used to store the value size
+ final long valueDataOffsetInPage = insertCursor;
+ insertCursor += valueLengthBytes; // space occupied by the value data
+
+ // Copy the key
+ PlatformDependent.UNSAFE.putLong(dataPageBaseObject, keySizeOffsetInPage, keyLengthBytes);
+ PlatformDependent.copyMemory(
+ keyBaseObject, keyBaseOffset, dataPageBaseObject, keyDataOffsetInPage, keyLengthBytes);
+ // Copy the value
+ PlatformDependent.UNSAFE.putLong(dataPageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
+ PlatformDependent.copyMemory(valueBaseObject, valueBaseOffset, dataPageBaseObject,
+ valueDataOffsetInPage, valueLengthBytes);
+
+ // --- Update bookkeeping data structures ----------------------------------------------------
+
+ if (useOverflowPage) {
+ // Store the end-of-page marker at the end of the data page
+ PlatformDependent.UNSAFE.putLong(dataPageBaseObject, insertCursor, END_OF_PAGE_MARKER);
+ } else {
+ pageCursor += requiredSize;
+ }
+
+ numElements++;
+ bitset.set(pos);
+ final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
+ dataPage, keySizeOffsetInPage);
+ longArray.set(pos * 2, storedKeyAddress);
+ longArray.set(pos * 2 + 1, keyHashcode);
+ updateAddressesAndSizes(storedKeyAddress);
+ isDefined = true;
+ if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
+ growAndRehash();
+ }
+ return true;
+ }
+ }
+
+ /**
+ * Allocate new data structures for this map. When calling this outside of the constructor,
+ * make sure to keep references to the old data structures so that you can free them.
+ *
+ * @param capacity the new map capacity
+ */
+ private void allocate(int capacity) {
+ assert (capacity >= 0);
+ // The capacity needs to be divisible by 64 so that our bit set can be sized properly
+ capacity = Math.max((int) Math.min(MAX_CAPACITY, nextPowerOf2(capacity)), 64);
+ assert (capacity <= MAX_CAPACITY);
+ longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
+ bitset = new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]));
+
+ this.growthThreshold = (int) (capacity * loadFactor);
+ this.mask = capacity - 1;
+ }
+
+ /**
+ * Free all allocated memory associated with this map, including the storage for keys and values
+ * as well as the hash map array itself.
+ *
+ * This method is idempotent.
+ */
+ public void free() {
+ longArray = null;
+ bitset = null;
+ Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
+ while (dataPagesIterator.hasNext()) {
+ MemoryBlock dataPage = dataPagesIterator.next();
+ dataPagesIterator.remove();
+ taskMemoryManager.freePage(dataPage);
+ shuffleMemoryManager.release(dataPage.size());
+ }
+ assert(dataPages.isEmpty());
+ }
+
+ /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
+ public long getTotalMemoryConsumption() {
+ long totalDataPagesSize = 0L;
+ for (MemoryBlock dataPage : dataPages) {
+ totalDataPagesSize += dataPage.size();
+ }
+ return totalDataPagesSize + bitset.memoryBlock().size() + longArray.memoryBlock().size();
+ }
+
+ /**
+ * Returns the total amount of time spent resizing this map (in nanoseconds).
+ */
+ public long getTimeSpentResizingNs() {
+ if (!enablePerfMetrics) {
+ throw new IllegalStateException();
+ }
+ return timeSpentResizingNs;
+ }
+
+
+ /**
+ * Returns the average number of probes per key lookup.
+ */
+ public double getAverageProbesPerLookup() {
+ if (!enablePerfMetrics) {
+ throw new IllegalStateException();
+ }
+ return (1.0 * numProbes) / numKeyLookups;
+ }
+
+ public long getNumHashCollisions() {
+ if (!enablePerfMetrics) {
+ throw new IllegalStateException();
+ }
+ return numHashCollisions;
+ }
+
+ @VisibleForTesting
+ int getNumDataPages() {
+ return dataPages.size();
+ }
+
+ /**
+ * Grows the size of the hash table and re-hashes everything.
+ */
+ @VisibleForTesting
+ void growAndRehash() {
+ long resizeStartTime = -1;
+ if (enablePerfMetrics) {
+ resizeStartTime = System.nanoTime();
+ }
+ // Store references to the old data structures to be used when we re-hash
+ final LongArray oldLongArray = longArray;
+ final BitSet oldBitSet = bitset;
+ final int oldCapacity = (int) oldBitSet.capacity();
+
+ // Allocate the new data structures
+ allocate(Math.min(growthStrategy.nextCapacity(oldCapacity), MAX_CAPACITY));
+
+ // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
+ for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
+ final long keyPointer = oldLongArray.get(pos * 2);
+ final int hashcode = (int) oldLongArray.get(pos * 2 + 1);
+ int newPos = hashcode & mask;
+ int step = 1;
+ boolean keepGoing = true;
+
+ // No need to check for key equality here when we insert, so this has one less if branch than
+ // the similar probing code path in lookup().
+ while (keepGoing) {
+ if (!bitset.isSet(newPos)) {
+ bitset.set(newPos);
+ longArray.set(newPos * 2, keyPointer);
+ longArray.set(newPos * 2 + 1, hashcode);
+ keepGoing = false;
+ } else {
+ newPos = (newPos + step) & mask;
+ step++;
+ }
+ }
+ }
+
+ if (enablePerfMetrics) {
+ timeSpentResizingNs += System.nanoTime() - resizeStartTime;
+ }
+ }
+
+ /** Returns the smallest power of 2 that is greater than or equal to num. */
+ private static long nextPowerOf2(long num) {
+ final long highBit = Long.highestOneBit(num);
+ return (highBit == num) ? num : highBit << 1;
+ }
+}
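
A minimal end-to-end usage sketch of the class above, mirroring the pattern exercised by the new test suite further down; this is a condensed fragment (same imports as the tests) and the managers, sizes, and values are illustrative assumptions, not part of the patch:

    ShuffleMemoryManager shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE);
    TaskMemoryManager taskMemoryManager =
      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
    BytesToBytesMap map = new BytesToBytesMap(
      taskMemoryManager, shuffleMemoryManager, 64, 1L << 26 /* 64 MB pages */);
    try {
      final long[] key = new long[] { 42L };    // 8 bytes, word-aligned
      final long[] value = new long[] { 99L };  // 8 bytes, word-aligned
      BytesToBytesMap.Location loc =
        map.lookup(key, PlatformDependent.LONG_ARRAY_OFFSET, 8);
      if (!loc.isDefined()) {
        if (!loc.putNewKey(key, PlatformDependent.LONG_ARRAY_OFFSET, 8,
                           value, PlatformDependent.LONG_ARRAY_OFFSET, 8)) {
          // Insufficient shuffle memory: the caller is expected to spill or back off.
        }
      }
      // Iterate over all entries; every call to next() reuses the same Location object.
      Iterator<BytesToBytesMap.Location> iter = map.iterator();
      while (iter.hasNext()) {
        BytesToBytesMap.Location entry = iter.next();
        // Read entry.getKeyAddress() / entry.getValueAddress() as needed.
      }
    } finally {
      map.free();  // releases data pages and returns reserved shuffle memory
    }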
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
new file mode 100644
index 0000000000..20654e4eea
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+/**
+ * Interface that defines how we can grow the size of a hash map when it is over a threshold.
+ */
+public interface HashMapGrowthStrategy {
+
+ int nextCapacity(int currentCapacity);
+
+ /**
+ * Double the size of the hash map every time.
+ */
+ HashMapGrowthStrategy DOUBLING = new Doubling();
+
+ class Doubling implements HashMapGrowthStrategy {
+ @Override
+ public int nextCapacity(int currentCapacity) {
+ assert (currentCapacity > 0);
+ // Guard against overflow
+ return (currentCapacity * 2 > 0) ? (currentCapacity * 2) : Integer.MAX_VALUE;
+ }
+ }
+
+}
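
To illustrate how the interface is meant to be used, here is a hypothetical alternative strategy; the patch itself only defines the DOUBLING strategy above:

    // Hypothetical 1.5x growth strategy, shown only as an example of the interface.
    class GrowByHalf implements HashMapGrowthStrategy {
      @Override
      public int nextCapacity(int currentCapacity) {
        assert (currentCapacity > 0);
        // Widen to long before adding so the sum cannot overflow an int.
        final long next = (long) currentCapacity + currentCapacity / 2;
        // Guard against exceeding the int range, as Doubling does.
        return (next > Integer.MAX_VALUE) ? Integer.MAX_VALUE : (int) next;
      }
    }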
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index f038b72295..00c1e078a4 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -85,7 +85,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
return toGrant
} else {
logInfo(
- s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
+ s"TID $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free")
wait()
}
} else {
@@ -116,6 +116,12 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
taskMemory.remove(taskAttemptId)
notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed
}
+
+ /** Returns the memory consumption, in bytes, for the current task */
+ def getMemoryConsumptionForThisTask(): Long = synchronized {
+ val taskAttemptId = currentTaskAttemptId()
+ taskMemory.getOrElse(taskAttemptId, 0L)
+ }
}
private object ShuffleMemoryManager {
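
For reference, the reserve-then-allocate pattern that BytesToBytesMap now follows before creating any data page, condensed from putNewKey() above (a sketch of existing logic in this patch, not additional code; the surrounding fields come from that class):

    // Ask the ShuffleMemoryManager for the page before asking the TaskMemoryManager to allocate it.
    long granted = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
    if (granted != pageSizeBytes) {
      shuffleMemoryManager.release(granted);  // give back any partial grant
      return false;                           // signal the caller to spill
    }
    MemoryBlock newPage = taskMemoryManager.allocatePage(pageSizeBytes);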
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
new file mode 100644
index 0000000000..60f483acbc
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -0,0 +1,499 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import java.lang.Exception;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import org.junit.*;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.mockito.AdditionalMatchers.geq;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.memory.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET;
+import static org.apache.spark.unsafe.PlatformDependent.LONG_ARRAY_OFFSET;
+
+
+public abstract class AbstractBytesToBytesMapSuite {
+
+ private final Random rand = new Random(42);
+
+ private ShuffleMemoryManager shuffleMemoryManager;
+ private TaskMemoryManager taskMemoryManager;
+ private TaskMemoryManager sizeLimitedTaskMemoryManager;
+ private final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes
+
+ @Before
+ public void setup() {
+ shuffleMemoryManager = new ShuffleMemoryManager(Long.MAX_VALUE);
+ taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(getMemoryAllocator()));
+ // Mocked memory manager for tests that check the maximum array size, since actually allocating
+ // such large arrays will cause us to run out of memory in our tests.
+ sizeLimitedTaskMemoryManager = mock(TaskMemoryManager.class);
+ when(sizeLimitedTaskMemoryManager.allocate(geq(1L << 20))).thenAnswer(
+ new Answer<MemoryBlock>() {
+ @Override
+ public MemoryBlock answer(InvocationOnMock invocation) throws Throwable {
+ if (((Long) invocation.getArguments()[0] / 8) > Integer.MAX_VALUE) {
+ throw new OutOfMemoryError("Requested array size exceeds VM limit");
+ }
+ return new MemoryBlock(null, 0, (Long) invocation.getArguments()[0]);
+ }
+ }
+ );
+ }
+
+ @After
+ public void tearDown() {
+ if (taskMemoryManager != null) {
+ long leakedShuffleMemory = shuffleMemoryManager.getMemoryConsumptionForThisTask();
+ Assert.assertEquals(0, taskMemoryManager.cleanUpAllAllocatedMemory());
+ Assert.assertEquals(0, leakedShuffleMemory);
+ shuffleMemoryManager = null;
+ taskMemoryManager = null;
+ }
+ }
+
+ protected abstract MemoryAllocator getMemoryAllocator();
+
+ private static byte[] getByteArray(MemoryLocation loc, int size) {
+ final byte[] arr = new byte[size];
+ PlatformDependent.copyMemory(
+ loc.getBaseObject(),
+ loc.getBaseOffset(),
+ arr,
+ BYTE_ARRAY_OFFSET,
+ size
+ );
+ return arr;
+ }
+
+ private byte[] getRandomByteArray(int numWords) {
+ Assert.assertTrue(numWords >= 0);
+ final int lengthInBytes = numWords * 8;
+ final byte[] bytes = new byte[lengthInBytes];
+ rand.nextBytes(bytes);
+ return bytes;
+ }
+
+ /**
+ * Fast equality checking for byte arrays, since these comparisons are a bottleneck
+ * in our stress tests.
+ */
+ private static boolean arrayEquals(
+ byte[] expected,
+ MemoryLocation actualAddr,
+ long actualLengthBytes) {
+ return (actualLengthBytes == expected.length) && ByteArrayMethods.arrayEquals(
+ expected,
+ BYTE_ARRAY_OFFSET,
+ actualAddr.getBaseObject(),
+ actualAddr.getBaseOffset(),
+ expected.length
+ );
+ }
+
+ @Test
+ public void emptyMap() {
+ BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
+ try {
+ Assert.assertEquals(0, map.numElements());
+ final int keyLengthInWords = 10;
+ final int keyLengthInBytes = keyLengthInWords * 8;
+ final byte[] key = getRandomByteArray(keyLengthInWords);
+ Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined());
+ Assert.assertFalse(map.iterator().hasNext());
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void setAndRetrieveAKey() {
+ BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, 64, PAGE_SIZE_BYTES);
+ final int recordLengthWords = 10;
+ final int recordLengthBytes = recordLengthWords * 8;
+ final byte[] keyData = getRandomByteArray(recordLengthWords);
+ final byte[] valueData = getRandomByteArray(recordLengthWords);
+ try {
+ final BytesToBytesMap.Location loc =
+ map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes);
+ Assert.assertFalse(loc.isDefined());
+ Assert.assertTrue(loc.putNewKey(
+ keyData,
+ BYTE_ARRAY_OFFSET,
+ recordLengthBytes,
+ valueData,
+ BYTE_ARRAY_OFFSET,
+ recordLengthBytes
+ ));
+ // After storing the key and value, the other location methods should return results that
+ // reflect the result of this store without us having to call lookup() again on the same key.
+ Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
+ Assert.assertEquals(recordLengthBytes, loc.getValueLength());
+ Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
+ Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+
+ // After calling lookup() the location should still point to the correct data.
+ Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
+ Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
+ Assert.assertEquals(recordLengthBytes, loc.getValueLength());
+ Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
+ Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+
+ try {
+ Assert.assertTrue(loc.putNewKey(
+ keyData,
+ BYTE_ARRAY_OFFSET,
+ recordLengthBytes,
+ valueData,
+ BYTE_ARRAY_OFFSET,
+ recordLengthBytes
+ ));
+ Assert.fail("Should not be able to set a new value for a key");
+ } catch (AssertionError e) {
+ // Expected exception; do nothing.
+ }
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void iteratorTest() throws Exception {
+ final int size = 4096;
+ BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, size / 2, PAGE_SIZE_BYTES);
+ try {
+ for (long i = 0; i < size; i++) {
+ final long[] value = new long[] { i };
+ final BytesToBytesMap.Location loc =
+ map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8);
+ Assert.assertFalse(loc.isDefined());
+ // Ensure that we store some zero-length keys
+ if (i % 5 == 0) {
+ Assert.assertTrue(loc.putNewKey(
+ null,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 0,
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8
+ ));
+ } else {
+ Assert.assertTrue(loc.putNewKey(
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8,
+ value,
+ PlatformDependent.LONG_ARRAY_OFFSET,
+ 8
+ ));
+ }
+ }
+ final java.util.BitSet valuesSeen = new java.util.BitSet(size);
+ final Iterator<BytesToBytesMap.Location> iter = map.iterator();
+ while (iter.hasNext()) {
+ final BytesToBytesMap.Location loc = iter.next();
+ Assert.assertTrue(loc.isDefined());
+ final MemoryLocation keyAddress = loc.getKeyAddress();
+ final MemoryLocation valueAddress = loc.getValueAddress();
+ final long value = PlatformDependent.UNSAFE.getLong(
+ valueAddress.getBaseObject(), valueAddress.getBaseOffset());
+ final long keyLength = loc.getKeyLength();
+ if (keyLength == 0) {
+ Assert.assertTrue("value " + value + " was not divisible by 5", value % 5 == 0);
+ } else {
+ final long key = PlatformDependent.UNSAFE.getLong(
+ keyAddress.getBaseObject(), keyAddress.getBaseOffset());
+ Assert.assertEquals(value, key);
+ }
+ valuesSeen.set((int) value);
+ }
+ Assert.assertEquals(size, valuesSeen.cardinality());
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void iteratingOverDataPagesWithWastedSpace() throws Exception {
+ final int NUM_ENTRIES = 1000 * 1000;
+ final int KEY_LENGTH = 16;
+ final int VALUE_LENGTH = 40;
+ final BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, NUM_ENTRIES, PAGE_SIZE_BYTES);
+ // Each record will take 8 + 8 + 16 + 40 = 72 bytes of space in the data page. Our 64-megabyte
+ // pages won't be evenly-divisible by records of this size, which will cause us to waste some
+ // space at the end of the page. This is necessary in order for us to take the end-of-page-marker
+ // handling branch in iterator().
+ try {
+ for (int i = 0; i < NUM_ENTRIES; i++) {
+ final long[] key = new long[] { i, i }; // 2 * 8 = 16 bytes
+ final long[] value = new long[] { i, i, i, i, i }; // 5 * 8 = 40 bytes
+ final BytesToBytesMap.Location loc = map.lookup(
+ key,
+ LONG_ARRAY_OFFSET,
+ KEY_LENGTH
+ );
+ Assert.assertFalse(loc.isDefined());
+ Assert.assertTrue(loc.putNewKey(
+ key,
+ LONG_ARRAY_OFFSET,
+ KEY_LENGTH,
+ value,
+ LONG_ARRAY_OFFSET,
+ VALUE_LENGTH
+ ));
+ }
+ Assert.assertEquals(2, map.getNumDataPages());
+
+ final java.util.BitSet valuesSeen = new java.util.BitSet(NUM_ENTRIES);
+ final Iterator<BytesToBytesMap.Location> iter = map.iterator();
+ final long key[] = new long[KEY_LENGTH / 8];
+ final long value[] = new long[VALUE_LENGTH / 8];
+ while (iter.hasNext()) {
+ final BytesToBytesMap.Location loc = iter.next();
+ Assert.assertTrue(loc.isDefined());
+ Assert.assertEquals(KEY_LENGTH, loc.getKeyLength());
+ Assert.assertEquals(VALUE_LENGTH, loc.getValueLength());
+ PlatformDependent.copyMemory(
+ loc.getKeyAddress().getBaseObject(),
+ loc.getKeyAddress().getBaseOffset(),
+ key,
+ LONG_ARRAY_OFFSET,
+ KEY_LENGTH
+ );
+ PlatformDependent.copyMemory(
+ loc.getValueAddress().getBaseObject(),
+ loc.getValueAddress().getBaseOffset(),
+ value,
+ LONG_ARRAY_OFFSET,
+ VALUE_LENGTH
+ );
+ for (long j : key) {
+ Assert.assertEquals(key[0], j);
+ }
+ for (long j : value) {
+ Assert.assertEquals(key[0], j);
+ }
+ valuesSeen.set((int) key[0]);
+ }
+ Assert.assertEquals(NUM_ENTRIES, valuesSeen.cardinality());
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void randomizedStressTest() {
+ final int size = 65536;
+ // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
+ // into ByteBuffers in order to use them as keys here.
+ final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
+ final BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, size, PAGE_SIZE_BYTES);
+
+ try {
+ // Fill the map to 90% full so that we can trigger probing
+ for (int i = 0; i < size * 0.9; i++) {
+ final byte[] key = getRandomByteArray(rand.nextInt(256) + 1);
+ final byte[] value = getRandomByteArray(rand.nextInt(512) + 1);
+ if (!expected.containsKey(ByteBuffer.wrap(key))) {
+ expected.put(ByteBuffer.wrap(key), value);
+ final BytesToBytesMap.Location loc = map.lookup(
+ key,
+ BYTE_ARRAY_OFFSET,
+ key.length
+ );
+ Assert.assertFalse(loc.isDefined());
+ Assert.assertTrue(loc.putNewKey(
+ key,
+ BYTE_ARRAY_OFFSET,
+ key.length,
+ value,
+ BYTE_ARRAY_OFFSET,
+ value.length
+ ));
+ // After calling putNewKey, the following should be true, even before calling
+ // lookup():
+ Assert.assertTrue(loc.isDefined());
+ Assert.assertEquals(key.length, loc.getKeyLength());
+ Assert.assertEquals(value.length, loc.getValueLength());
+ Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
+ Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+ }
+ }
+
+ for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
+ final byte[] key = entry.getKey().array();
+ final byte[] value = entry.getValue();
+ final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length);
+ Assert.assertTrue(loc.isDefined());
+ Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
+ Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+ }
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void randomizedTestWithRecordsLargerThanPageSize() {
+ final long pageSizeBytes = 128;
+ final BytesToBytesMap map = new BytesToBytesMap(
+ taskMemoryManager, shuffleMemoryManager, 64, pageSizeBytes);
+ // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
+ // into ByteBuffers in order to use them as keys here.
+ final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
+ try {
+ for (int i = 0; i < 1000; i++) {
+ final byte[] key = getRandomByteArray(rand.nextInt(128));
+ final byte[] value = getRandomByteArray(rand.nextInt(128));
+ if (!expected.containsKey(ByteBuffer.wrap(key))) {
+ expected.put(ByteBuffer.wrap(key), value);
+ final BytesToBytesMap.Location loc = map.lookup(
+ key,
+ BYTE_ARRAY_OFFSET,
+ key.length
+ );
+ Assert.assertFalse(loc.isDefined());
+ Assert.assertTrue(loc.putNewKey(
+ key,
+ BYTE_ARRAY_OFFSET,
+ key.length,
+ value,
+ BYTE_ARRAY_OFFSET,
+ value.length
+ ));
+ // After calling putNewKey, the following should be true, even before calling
+ // lookup():
+ Assert.assertTrue(loc.isDefined());
+ Assert.assertEquals(key.length, loc.getKeyLength());
+ Assert.assertEquals(value.length, loc.getValueLength());
+ Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
+ Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+ }
+ }
+ for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
+ final byte[] key = entry.getKey().array();
+ final byte[] value = entry.getValue();
+ final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length);
+ Assert.assertTrue(loc.isDefined());
+ Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
+ Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+ }
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void failureToAllocateFirstPage() {
+ shuffleMemoryManager = new ShuffleMemoryManager(1024);
+ BytesToBytesMap map =
+ new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, PAGE_SIZE_BYTES);
+ try {
+ final long[] emptyArray = new long[0];
+ final BytesToBytesMap.Location loc =
+ map.lookup(emptyArray, PlatformDependent.LONG_ARRAY_OFFSET, 0);
+ Assert.assertFalse(loc.isDefined());
+ Assert.assertFalse(loc.putNewKey(
+ emptyArray, LONG_ARRAY_OFFSET, 0,
+ emptyArray, LONG_ARRAY_OFFSET, 0
+ ));
+ } finally {
+ map.free();
+ }
+ }
+
+
+ @Test
+ public void failureToGrow() {
+ shuffleMemoryManager = new ShuffleMemoryManager(1024 * 10);
+ BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, shuffleMemoryManager, 1, 1024);
+ try {
+ boolean success = true;
+ int i;
+ for (i = 0; i < 1024; i++) {
+ final long[] arr = new long[]{i};
+ final BytesToBytesMap.Location loc = map.lookup(arr, PlatformDependent.LONG_ARRAY_OFFSET, 8);
+ success = loc.putNewKey(arr, LONG_ARRAY_OFFSET, 8, arr, LONG_ARRAY_OFFSET, 8);
+ if (!success) {
+ break;
+ }
+ }
+ Assert.assertThat(i, greaterThan(0));
+ Assert.assertFalse(success);
+ } finally {
+ map.free();
+ }
+ }
+
+ @Test
+ public void initialCapacityBoundsChecking() {
+ try {
+ new BytesToBytesMap(sizeLimitedTaskMemoryManager, shuffleMemoryManager, 0, PAGE_SIZE_BYTES);
+ Assert.fail("Expected IllegalArgumentException to be thrown");
+ } catch (IllegalArgumentException e) {
+ // expected exception
+ }
+
+ try {
+ new BytesToBytesMap(
+ sizeLimitedTaskMemoryManager,
+ shuffleMemoryManager,
+ BytesToBytesMap.MAX_CAPACITY + 1,
+ PAGE_SIZE_BYTES);
+ Assert.fail("Expected IllegalArgumentException to be thrown");
+ } catch (IllegalArgumentException e) {
+ // expected exception
+ }
+
+ // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
+ // Can allocate _at_ the max capacity
+ // BytesToBytesMap map = new BytesToBytesMap(
+ // sizeLimitedTaskMemoryManager,
+ // shuffleMemoryManager,
+ // BytesToBytesMap.MAX_CAPACITY,
+ // PAGE_SIZE_BYTES);
+ // map.free();
+ }
+
+ // Ignored because this can OOM now that we allocate the long array w/o a TaskMemoryManager
+ @Ignore
+ public void resizingLargeMap() {
+ // As long as a map's capacity is below the max, we should be able to resize up to the max
+ BytesToBytesMap map = new BytesToBytesMap(
+ sizeLimitedTaskMemoryManager,
+ shuffleMemoryManager,
+ BytesToBytesMap.MAX_CAPACITY - 64,
+ PAGE_SIZE_BYTES);
+ map.growAndRehash();
+ map.free();
+ }
+}
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
new file mode 100644
index 0000000000..5a10de49f5
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOffHeapSuite.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+
+public class BytesToBytesMapOffHeapSuite extends AbstractBytesToBytesMapSuite {
+
+ @Override
+ protected MemoryAllocator getMemoryAllocator() {
+ return MemoryAllocator.UNSAFE;
+ }
+
+}
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
new file mode 100644
index 0000000000..12cc9b25d9
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/unsafe/map/BytesToBytesMapOnHeapSuite.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+
+public class BytesToBytesMapOnHeapSuite extends AbstractBytesToBytesMapSuite {
+
+ @Override
+ protected MemoryAllocator getMemoryAllocator() {
+ return MemoryAllocator.HEAP;
+ }
+
+}