author     Davies Liu <davies@databricks.com>       2015-11-05 19:02:18 -0800
committer  Josh Rosen <joshrosen@databricks.com>    2015-11-05 19:02:18 -0800
commit     eec74ba8bde7f9446cc38e687bda103e85669d35 (patch)
tree       5e656d6333afde0255e96d930b245df28994bf9b /core
parent     3cc2c053b5d68c747a30bd58cf388b87b1922f13 (diff)
download   spark-eec74ba8bde7f9446cc38e687bda103e85669d35.tar.gz
           spark-eec74ba8bde7f9446cc38e687bda103e85669d35.tar.bz2
           spark-eec74ba8bde7f9446cc38e687bda103e85669d35.zip
[SPARK-7542][SQL] Support off-heap index/sort buffer
This adds support for off-heap memory for the pointer arrays inside BytesToBytesMap and InMemorySorter, so that all execution memory can be allocated off-heap.

Closes #8068

Author: Davies Liu <davies@databricks.com>

Closes #9477 from davies/unsafe_timsort.
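For context, this patch replaces the raw acquireMemory/releaseMemory calls on MemoryConsumer with allocateArray/freeArray, which return a LongArray backed by a page from the TaskMemoryManager (on-heap or off-heap, depending on configuration). Below is a minimal sketch of a consumer using the new API, based only on the constructor and method signatures visible in this diff; the class name, page size, and demo method are illustrative and not part of the patch.

import java.io.IOException;

import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.array.LongArray;

// Hypothetical consumer; only allocateArray/freeArray/spill come from the patch.
class ExamplePointerBuffer extends MemoryConsumer {
  ExamplePointerBuffer(TaskMemoryManager manager) {
    super(manager, 4 * 1024 * 1024);  // page size is an arbitrary example value
  }

  @Override
  public long spill(long size, MemoryConsumer trigger) throws IOException {
    // Nothing to spill in this sketch; a real consumer would write its pages to disk here.
    return 0L;
  }

  void demo() {
    // Allocates 1024 longs from execution memory; the backing page may be on- or off-heap.
    LongArray array = allocateArray(1024);
    array.set(0, 42L);            // element access goes through LongArray instead of long[]
    long value = array.get(0);
    freeArray(array);             // returns the memory to the TaskMemoryManager
  }
}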
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/java/org/apache/spark/memory/MemoryConsumer.java                                  | 36
-rw-r--r--  core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java                               |  6
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java                     | 26
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java                     | 67
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java                     | 38
-rw-r--r--  core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java                             | 18
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java       | 28
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java       | 66
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java       | 47
-rw-r--r--  core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java                          | 23
-rw-r--r--  core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java                              | 45
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java                | 16
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java  |  1
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java  | 12
14 files changed, 242 insertions, 187 deletions
diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
index 008799cc77..8fbdb72832 100644
--- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -20,6 +20,7 @@ package org.apache.spark.memory;
import java.io.IOException;
+import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
@@ -28,9 +29,9 @@ import org.apache.spark.unsafe.memory.MemoryBlock;
*/
public abstract class MemoryConsumer {
- private final TaskMemoryManager taskMemoryManager;
+ protected final TaskMemoryManager taskMemoryManager;
private final long pageSize;
- private long used;
+ protected long used;
protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) {
this.taskMemoryManager = taskMemoryManager;
@@ -74,26 +75,29 @@ public abstract class MemoryConsumer {
public abstract long spill(long size, MemoryConsumer trigger) throws IOException;
/**
- * Acquire `size` bytes memory.
- *
- * If there is not enough memory, throws OutOfMemoryError.
+ * Allocates a LongArray of `size`.
*/
- protected void acquireMemory(long size) {
- long got = taskMemoryManager.acquireExecutionMemory(size, this);
- if (got < size) {
- taskMemoryManager.releaseExecutionMemory(got, this);
+ public LongArray allocateArray(long size) {
+ long required = size * 8L;
+ MemoryBlock page = taskMemoryManager.allocatePage(required, this);
+ if (page == null || page.size() < required) {
+ long got = 0;
+ if (page != null) {
+ got = page.size();
+ taskMemoryManager.freePage(page, this);
+ }
taskMemoryManager.showMemoryUsage();
- throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got);
+ throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
}
- used += got;
+ used += required;
+ return new LongArray(page);
}
/**
- * Release `size` bytes memory.
+ * Frees a LongArray.
*/
- protected void releaseMemory(long size) {
- used -= size;
- taskMemoryManager.releaseExecutionMemory(size, this);
+ public void freeArray(LongArray array) {
+ freePage(array.memoryBlock());
}
/**
@@ -109,7 +113,7 @@ public abstract class MemoryConsumer {
long got = 0;
if (page != null) {
got = page.size();
- freePage(page);
+ taskMemoryManager.freePage(page, this);
}
taskMemoryManager.showMemoryUsage();
throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got);
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 4230575446..6440f9c0f3 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -137,7 +137,7 @@ public class TaskMemoryManager {
if (got < required) {
// Call spill() on other consumers to release memory
for (MemoryConsumer c: consumers) {
- if (c != null && c != consumer && c.getUsed() > 0) {
+ if (c != consumer && c.getUsed() > 0) {
try {
long released = c.spill(required - got, consumer);
if (released > 0) {
@@ -173,7 +173,9 @@ public class TaskMemoryManager {
}
}
- consumers.add(consumer);
+ if (consumer != null) {
+ consumers.add(consumer);
+ }
logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer);
return got;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 400d852001..9affff8014 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.TempShuffleBlockId;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.Utils;
@@ -114,8 +115,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.writeMetrics = writeMetrics;
- acquireMemory(initialSize * 8L);
- this.inMemSorter = new ShuffleInMemorySorter(initialSize);
+ this.inMemSorter = new ShuffleInMemorySorter(this, initialSize);
this.peakMemoryUsedBytes = getMemoryUsage();
}
@@ -301,9 +301,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
public void cleanupResources() {
freeMemory();
if (inMemSorter != null) {
- long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+ inMemSorter.free();
inMemSorter = null;
- releaseMemory(sorterMemoryUsage);
}
for (SpillInfo spill : spills) {
if (spill.file.exists() && !spill.file.delete()) {
@@ -321,9 +320,10 @@ final class ShuffleExternalSorter extends MemoryConsumer {
assert(inMemSorter != null);
if (!inMemSorter.hasSpaceForAnotherRecord()) {
long used = inMemSorter.getMemoryUsage();
- long needed = used + inMemSorter.getMemoryToExpand();
+ LongArray array;
try {
- acquireMemory(needed); // could trigger spilling
+ // could trigger spilling
+ array = allocateArray(used / 8 * 2);
} catch (OutOfMemoryError e) {
// should have triggered spilling
assert(inMemSorter.hasSpaceForAnotherRecord());
@@ -331,16 +331,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
}
// check if spilling is triggered or not
if (inMemSorter.hasSpaceForAnotherRecord()) {
- releaseMemory(needed);
+ freeArray(array);
} else {
- try {
- inMemSorter.expandPointerArray();
- releaseMemory(used);
- } catch (OutOfMemoryError oom) {
- // Just in case that JVM had run out of memory
- releaseMemory(needed);
- spill();
- }
+ inMemSorter.expandPointerArray(array);
}
}
}
@@ -404,9 +397,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
// Do not count the final file towards the spill count.
writeSortedFile(true);
freeMemory();
- long sorterMemoryUsage = inMemSorter.getMemoryUsage();
+ inMemSorter.free();
inMemSorter = null;
- releaseMemory(sorterMemoryUsage);
}
return spills.toArray(new SpillInfo[spills.size()]);
} catch (IOException e) {
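Taken together, the hunks above assemble into the following grow-on-demand pattern for the pointer array, which both ShuffleExternalSorter and UnsafeExternalSorter use after this patch. This is a sketch reconstructed from the diff; the enclosing method name is not shown in the hunk headers and is assumed here.

// Sketch of the pattern the hunks above implement; method name assumed.
private void growPointerArrayIfNecessary() {
  assert(inMemSorter != null);
  if (!inMemSorter.hasSpaceForAnotherRecord()) {
    long used = inMemSorter.getMemoryUsage();
    LongArray array;
    try {
      // Allocating a doubled array may force this or other consumers to spill.
      array = allocateArray(used / 8 * 2);
    } catch (OutOfMemoryError e) {
      // Spilling this sorter already freed room for another record, so growing is unnecessary.
      assert(inMemSorter.hasSpaceForAnotherRecord());
      return;
    }
    if (inMemSorter.hasSpaceForAnotherRecord()) {
      // A spill was triggered and made room; discard the larger array.
      freeArray(array);
    } else {
      // Copy the old pointers into the new array and free the old one.
      inMemSorter.expandPointerArray(array);
    }
  }
}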
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index e630575d1a..58ad88e1ed 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -19,11 +19,14 @@ package org.apache.spark.shuffle.sort;
import java.util.Comparator;
+import org.apache.spark.memory.MemoryConsumer;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;
final class ShuffleInMemorySorter {
- private final Sorter<PackedRecordPointer, long[]> sorter;
+ private final Sorter<PackedRecordPointer, LongArray> sorter;
private static final class SortComparator implements Comparator<PackedRecordPointer> {
@Override
public int compare(PackedRecordPointer left, PackedRecordPointer right) {
@@ -32,24 +35,34 @@ final class ShuffleInMemorySorter {
}
private static final SortComparator SORT_COMPARATOR = new SortComparator();
+ private final MemoryConsumer consumer;
+
/**
* An array of record pointers and partition ids that have been encoded by
* {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
* records.
*/
- private long[] array;
+ private LongArray array;
/**
* The position in the pointer array where new records can be inserted.
*/
private int pos = 0;
- public ShuffleInMemorySorter(int initialSize) {
+ public ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
+ this.consumer = consumer;
assert (initialSize > 0);
- this.array = new long[initialSize];
+ this.array = consumer.allocateArray(initialSize);
this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
}
+ public void free() {
+ if (array != null) {
+ consumer.freeArray(array);
+ array = null;
+ }
+ }
+
public int numRecords() {
return pos;
}
@@ -58,30 +71,25 @@ final class ShuffleInMemorySorter {
pos = 0;
}
- private int newLength() {
- // Guard against overflow:
- return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
- }
-
- /**
- * Returns the memory needed to expand
- */
- public long getMemoryToExpand() {
- return ((long) (newLength() - array.length)) * 8;
- }
-
- public void expandPointerArray() {
- final long[] oldArray = array;
- array = new long[newLength()];
- System.arraycopy(oldArray, 0, array, 0, oldArray.length);
+ public void expandPointerArray(LongArray newArray) {
+ assert(newArray.size() > array.size());
+ Platform.copyMemory(
+ array.getBaseObject(),
+ array.getBaseOffset(),
+ newArray.getBaseObject(),
+ newArray.getBaseOffset(),
+ array.size() * 8L
+ );
+ consumer.freeArray(array);
+ array = newArray;
}
public boolean hasSpaceForAnotherRecord() {
- return pos < array.length;
+ return pos < array.size();
}
public long getMemoryUsage() {
- return array.length * 8L;
+ return array.size() * 8L;
}
/**
@@ -96,14 +104,9 @@ final class ShuffleInMemorySorter {
*/
public void insertRecord(long recordPointer, int partitionId) {
if (!hasSpaceForAnotherRecord()) {
- if (array.length == Integer.MAX_VALUE) {
- throw new IllegalStateException("Sort pointer array has reached maximum size");
- } else {
- expandPointerArray();
- }
+ expandPointerArray(consumer.allocateArray(array.size() * 2));
}
- array[pos] =
- PackedRecordPointer.packPointer(recordPointer, partitionId);
+ array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId));
pos++;
}
@@ -112,12 +115,12 @@ final class ShuffleInMemorySorter {
*/
public static final class ShuffleSorterIterator {
- private final long[] pointerArray;
+ private final LongArray pointerArray;
private final int numRecords;
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
private int position = 0;
- public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
+ public ShuffleSorterIterator(int numRecords, LongArray pointerArray) {
this.numRecords = numRecords;
this.pointerArray = pointerArray;
}
@@ -127,7 +130,7 @@ final class ShuffleInMemorySorter {
}
public void loadNext() {
- packedRecordPointer.set(pointerArray[position]);
+ packedRecordPointer.set(pointerArray.get(position));
position++;
}
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
index 8a1e5aec6f..8f4e322997 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
@@ -17,16 +17,19 @@
package org.apache.spark.shuffle.sort;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.SortDataFormat;
-final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, LongArray> {
public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
private ShuffleSortDataFormat() { }
@Override
- public PackedRecordPointer getKey(long[] data, int pos) {
+ public PackedRecordPointer getKey(LongArray data, int pos) {
// Since we re-use keys, this method shouldn't be called.
throw new UnsupportedOperationException();
}
@@ -37,31 +40,38 @@ final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, lo
}
@Override
- public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
- reuse.set(data[pos]);
+ public PackedRecordPointer getKey(LongArray data, int pos, PackedRecordPointer reuse) {
+ reuse.set(data.get(pos));
return reuse;
}
@Override
- public void swap(long[] data, int pos0, int pos1) {
- final long temp = data[pos0];
- data[pos0] = data[pos1];
- data[pos1] = temp;
+ public void swap(LongArray data, int pos0, int pos1) {
+ final long temp = data.get(pos0);
+ data.set(pos0, data.get(pos1));
+ data.set(pos1, temp);
}
@Override
- public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
- dst[dstPos] = src[srcPos];
+ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {
+ dst.set(dstPos, src.get(srcPos));
}
@Override
- public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
- System.arraycopy(src, srcPos, dst, dstPos, length);
+ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) {
+ Platform.copyMemory(
+ src.getBaseObject(),
+ src.getBaseOffset() + srcPos * 8,
+ dst.getBaseObject(),
+ dst.getBaseOffset() + dstPos * 8,
+ length * 8
+ );
}
@Override
- public long[] allocate(int length) {
- return new long[length];
+ public LongArray allocate(int length) {
+ // This buffer is used temporarily (and is usually small), so it's fine to allocate it from the JVM heap.
+ return new LongArray(MemoryBlock.fromLongArray(new long[length]));
}
}
diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
index 6656fd1d0b..04694dc544 100644
--- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
+++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -20,7 +20,6 @@ package org.apache.spark.unsafe.map;
import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
-import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
@@ -724,11 +723,10 @@ public final class BytesToBytesMap extends MemoryConsumer {
*/
private void allocate(int capacity) {
assert (capacity >= 0);
- // The capacity needs to be divisible by 64 so that our bit set can be sized properly
capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64);
assert (capacity <= MAX_CAPACITY);
- acquireMemory(capacity * 16);
- longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2]));
+ longArray = allocateArray(capacity * 2);
+ longArray.zeroOut();
this.growthThreshold = (int) (capacity * loadFactor);
this.mask = capacity - 1;
@@ -743,9 +741,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
public void free() {
updatePeakMemoryUsed();
if (longArray != null) {
- long used = longArray.memoryBlock().size();
+ freeArray(longArray);
longArray = null;
- releaseMemory(used);
}
Iterator<MemoryBlock> dataPagesIterator = dataPages.iterator();
while (dataPagesIterator.hasNext()) {
@@ -834,9 +831,9 @@ public final class BytesToBytesMap extends MemoryConsumer {
/**
* Returns the underlying long[] of longArray.
*/
- public long[] getArray() {
+ public LongArray getArray() {
assert(longArray != null);
- return (long[]) longArray.memoryBlock().getBaseObject();
+ return longArray;
}
/**
@@ -844,7 +841,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
*/
public void reset() {
numElements = 0;
- Arrays.fill(getArray(), 0);
+ longArray.zeroOut();
+
while (dataPages.size() > 0) {
MemoryBlock dataPage = dataPages.removeLast();
freePage(dataPage);
@@ -887,7 +885,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
longArray.set(newPos * 2, keyPointer);
longArray.set(newPos * 2 + 1, hashcode);
}
- releaseMemory(oldLongArray.memoryBlock().size());
+ freeArray(oldLongArray);
if (enablePerfMetrics) {
timeSpentResizingNs += System.nanoTime() - resizeStartTime;
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index cba043bc48..9a7b2ad06c 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -32,6 +32,7 @@ import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.TaskCompletionListener;
import org.apache.spark.util.Utils;
@@ -123,9 +124,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
this.writeMetrics = new ShuffleWriteMetrics();
if (existingInMemorySorter == null) {
- this.inMemSorter =
- new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize);
- acquireMemory(inMemSorter.getMemoryUsage());
+ this.inMemSorter = new UnsafeInMemorySorter(
+ this, taskMemoryManager, recordComparator, prefixComparator, initialSize);
} else {
this.inMemSorter = existingInMemorySorter;
}
@@ -277,9 +277,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
deleteSpillFiles();
freeMemory();
if (inMemSorter != null) {
- long used = inMemSorter.getMemoryUsage();
+ inMemSorter.free();
inMemSorter = null;
- releaseMemory(used);
}
}
}
@@ -293,9 +292,10 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
assert(inMemSorter != null);
if (!inMemSorter.hasSpaceForAnotherRecord()) {
long used = inMemSorter.getMemoryUsage();
- long needed = used + inMemSorter.getMemoryToExpand();
+ LongArray array;
try {
- acquireMemory(needed); // could trigger spilling
+ // could trigger spilling
+ array = allocateArray(used / 8 * 2);
} catch (OutOfMemoryError e) {
// should have triggered spilling
assert(inMemSorter.hasSpaceForAnotherRecord());
@@ -303,16 +303,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
}
// check if spilling is triggered or not
if (inMemSorter.hasSpaceForAnotherRecord()) {
- releaseMemory(needed);
+ freeArray(array);
} else {
- try {
- inMemSorter.expandPointerArray();
- releaseMemory(used);
- } catch (OutOfMemoryError oom) {
- // Just in case that JVM had run out of memory
- releaseMemory(needed);
- spill();
- }
+ inMemSorter.expandPointerArray(array);
}
}
}
@@ -498,9 +491,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
nextUpstream = null;
assert(inMemSorter != null);
- long used = inMemSorter.getMemoryUsage();
+ inMemSorter.free();
inMemSorter = null;
- releaseMemory(used);
}
numRecords--;
upstream.loadNext();
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index d57213b9b8..a218ad4623 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -19,8 +19,10 @@ package org.apache.spark.util.collection.unsafe.sort;
import java.util.Comparator;
+import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;
/**
@@ -62,15 +64,16 @@ public final class UnsafeInMemorySorter {
}
}
+ private final MemoryConsumer consumer;
private final TaskMemoryManager memoryManager;
- private final Sorter<RecordPointerAndKeyPrefix, long[]> sorter;
+ private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter;
private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
/**
* Within this buffer, position {@code 2 * i} holds a pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*/
- private long[] array;
+ private LongArray array;
/**
* The position in the sort buffer where new records can be inserted.
@@ -78,22 +81,33 @@ public final class UnsafeInMemorySorter {
private int pos = 0;
public UnsafeInMemorySorter(
+ final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
int initialSize) {
- this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]);
+ this(consumer, memoryManager, recordComparator, prefixComparator,
+ consumer.allocateArray(initialSize * 2));
}
public UnsafeInMemorySorter(
+ final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
- long[] array) {
- this.array = array;
+ LongArray array) {
+ this.consumer = consumer;
this.memoryManager = memoryManager;
this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ this.array = array;
+ }
+
+ /**
+ * Frees the memory used by the pointer array.
+ */
+ public void free() {
+ consumer.freeArray(array);
}
public void reset() {
@@ -107,26 +121,26 @@ public final class UnsafeInMemorySorter {
return pos / 2;
}
- private int newLength() {
- return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE;
- }
-
- public long getMemoryToExpand() {
- return (long) (newLength() - array.length) * 8L;
- }
-
public long getMemoryUsage() {
- return array.length * 8L;
+ return array.size() * 8L;
}
public boolean hasSpaceForAnotherRecord() {
- return pos + 2 <= array.length;
+ return pos + 2 <= array.size();
}
- public void expandPointerArray() {
- final long[] oldArray = array;
- array = new long[newLength()];
- System.arraycopy(oldArray, 0, array, 0, oldArray.length);
+ public void expandPointerArray(LongArray newArray) {
+ if (newArray.size() < array.size()) {
+ throw new OutOfMemoryError("Not enough memory to grow pointer array");
+ }
+ Platform.copyMemory(
+ array.getBaseObject(),
+ array.getBaseOffset(),
+ newArray.getBaseObject(),
+ newArray.getBaseOffset(),
+ array.size() * 8L);
+ consumer.freeArray(array);
+ array = newArray;
}
/**
@@ -138,11 +152,11 @@ public final class UnsafeInMemorySorter {
*/
public void insertRecord(long recordPointer, long keyPrefix) {
if (!hasSpaceForAnotherRecord()) {
- expandPointerArray();
+ expandPointerArray(consumer.allocateArray(array.size() * 2));
}
- array[pos] = recordPointer;
+ array.set(pos, recordPointer);
pos++;
- array[pos] = keyPrefix;
+ array.set(pos, keyPrefix);
pos++;
}
@@ -150,7 +164,7 @@ public final class UnsafeInMemorySorter {
private final TaskMemoryManager memoryManager;
private final int sortBufferInsertPosition;
- private final long[] sortBuffer;
+ private final LongArray sortBuffer;
private int position = 0;
private Object baseObject;
private long baseOffset;
@@ -160,7 +174,7 @@ public final class UnsafeInMemorySorter {
private SortedIterator(
TaskMemoryManager memoryManager,
int sortBufferInsertPosition,
- long[] sortBuffer) {
+ LongArray sortBuffer) {
this.memoryManager = memoryManager;
this.sortBufferInsertPosition = sortBufferInsertPosition;
this.sortBuffer = sortBuffer;
@@ -188,11 +202,11 @@ public final class UnsafeInMemorySorter {
@Override
public void loadNext() {
// This pointer points to a 4-byte record length, followed by the record's bytes
- final long recordPointer = sortBuffer[position];
+ final long recordPointer = sortBuffer.get(position);
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
recordLength = Platform.getInt(baseObject, baseOffset - 4);
- keyPrefix = sortBuffer[position + 1];
+ keyPrefix = sortBuffer.get(position + 1);
position += 2;
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
index d09c728a7a..d3137f5f31 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java
@@ -17,6 +17,9 @@
package org.apache.spark.util.collection.unsafe.sort;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.collection.SortDataFormat;
/**
@@ -26,14 +29,14 @@ import org.apache.spark.util.collection.SortDataFormat;
* Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*/
-final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, long[]> {
+final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefix, LongArray> {
public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
private UnsafeSortDataFormat() { }
@Override
- public RecordPointerAndKeyPrefix getKey(long[] data, int pos) {
+ public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) {
// Since we re-use keys, this method shouldn't be called.
throw new UnsupportedOperationException();
}
@@ -44,37 +47,43 @@ final class UnsafeSortDataFormat extends SortDataFormat<RecordPointerAndKeyPrefi
}
@Override
- public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) {
- reuse.recordPointer = data[pos * 2];
- reuse.keyPrefix = data[pos * 2 + 1];
+ public RecordPointerAndKeyPrefix getKey(LongArray data, int pos, RecordPointerAndKeyPrefix reuse) {
+ reuse.recordPointer = data.get(pos * 2);
+ reuse.keyPrefix = data.get(pos * 2 + 1);
return reuse;
}
@Override
- public void swap(long[] data, int pos0, int pos1) {
- long tempPointer = data[pos0 * 2];
- long tempKeyPrefix = data[pos0 * 2 + 1];
- data[pos0 * 2] = data[pos1 * 2];
- data[pos0 * 2 + 1] = data[pos1 * 2 + 1];
- data[pos1 * 2] = tempPointer;
- data[pos1 * 2 + 1] = tempKeyPrefix;
+ public void swap(LongArray data, int pos0, int pos1) {
+ long tempPointer = data.get(pos0 * 2);
+ long tempKeyPrefix = data.get(pos0 * 2 + 1);
+ data.set(pos0 * 2, data.get(pos1 * 2));
+ data.set(pos0 * 2 + 1, data.get(pos1 * 2 + 1));
+ data.set(pos1 * 2, tempPointer);
+ data.set(pos1 * 2 + 1, tempKeyPrefix);
}
@Override
- public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
- dst[dstPos * 2] = src[srcPos * 2];
- dst[dstPos * 2 + 1] = src[srcPos * 2 + 1];
+ public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) {
+ dst.set(dstPos * 2, src.get(srcPos * 2));
+ dst.set(dstPos * 2 + 1, src.get(srcPos * 2 + 1));
}
@Override
- public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
- System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2);
+ public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) {
+ Platform.copyMemory(
+ src.getBaseObject(),
+ src.getBaseOffset() + srcPos * 16,
+ dst.getBaseObject(),
+ dst.getBaseOffset() + dstPos * 16,
+ length * 16);
}
@Override
- public long[] allocate(int length) {
+ public LongArray allocate(int length) {
assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large";
- return new long[length * 2];
+ // This is used as a temporary buffer, so it's fine to allocate it from the JVM heap.
+ return new LongArray(MemoryBlock.fromLongArray(new long[length * 2]));
}
}
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index dab7b0592c..c731317395 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -17,8 +17,6 @@
package org.apache.spark.memory;
-import java.io.IOException;
-
import org.junit.Assert;
import org.junit.Test;
@@ -27,27 +25,6 @@ import org.apache.spark.unsafe.memory.MemoryBlock;
public class TaskMemoryManagerSuite {
- class TestMemoryConsumer extends MemoryConsumer {
- TestMemoryConsumer(TaskMemoryManager memoryManager) {
- super(memoryManager);
- }
-
- @Override
- public long spill(long size, MemoryConsumer trigger) throws IOException {
- long used = getUsed();
- releaseMemory(used);
- return used;
- }
-
- void use(long size) {
- acquireMemory(size);
- }
-
- void free(long size) {
- releaseMemory(size);
- }
- }
-
@Test
public void leakedPageMemoryIsDetected() {
final TaskMemoryManager manager = new TaskMemoryManager(
diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
new file mode 100644
index 0000000000..8ae3642738
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.memory;
+
+import java.io.IOException;
+
+public class TestMemoryConsumer extends MemoryConsumer {
+ public TestMemoryConsumer(TaskMemoryManager memoryManager) {
+ super(memoryManager);
+ }
+
+ @Override
+ public long spill(long size, MemoryConsumer trigger) throws IOException {
+ long used = getUsed();
+ free(used);
+ return used;
+ }
+
+ void use(long size) {
+ long got = taskMemoryManager.acquireExecutionMemory(size, this);
+ used += got;
+ }
+
+ void free(long size) {
+ used -= size;
+ taskMemoryManager.releaseExecutionMemory(size, this);
+ }
+}
+
+
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 2293b1bbc1..faa5a863ee 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -25,13 +25,19 @@ import org.junit.Test;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
-import org.apache.spark.unsafe.Platform;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryConsumer;
import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.memory.TaskMemoryManager;
public class ShuffleInMemorySorterSuite {
+ final TestMemoryManager memoryManager =
+ new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false"));
+ final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
+ final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager);
+
private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
final byte[] strBytes = new byte[strLength];
Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
@@ -40,7 +46,7 @@ public class ShuffleInMemorySorterSuite {
@Test
public void testSortingEmptyInput() {
- final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100);
+ final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100);
final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
assert(!iter.hasNext());
}
@@ -63,7 +69,7 @@ public class ShuffleInMemorySorterSuite {
new TaskMemoryManager(new TestMemoryManager(conf), 0);
final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
final Object baseObject = dataPage.getBaseObject();
- final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
+ final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4);
final HashPartitioner hashPartitioner = new HashPartitioner(4);
// Write the records into the data page and store pointers into the sorter
@@ -104,7 +110,7 @@ public class ShuffleInMemorySorterSuite {
@Test
public void testSortingManyNumbers() throws Exception {
- ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
+ ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4);
int[] numbersToSort = new int[128000];
Random random = new Random(16);
for (int i = 0; i < numbersToSort.length; i++) {
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index cfead0e592..11c3a7be38 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -390,7 +390,6 @@ public class UnsafeExternalSorterSuite {
for (int i = 0; i < numRecordsPerPage * 10; i++) {
insertNumber(sorter, i);
newPeakMemory = sorter.getPeakMemoryUsedBytes();
- // The first page is pre-allocated on instantiation
if (i % numRecordsPerPage == 0) {
// We allocated a new page for this record, so peak memory should change
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 642f6585f8..a203a09648 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -23,6 +23,7 @@ import org.junit.Test;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
+import org.apache.spark.memory.TestMemoryConsumer;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
@@ -44,9 +45,11 @@ public class UnsafeInMemorySorterSuite {
@Test
public void testSortingEmptyInput() {
- final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(
- new TaskMemoryManager(
- new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0),
+ final TaskMemoryManager memoryManager = new TaskMemoryManager(
+ new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+ final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
+ final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer,
+ memoryManager,
mock(RecordComparator.class),
mock(PrefixComparator.class),
100);
@@ -69,6 +72,7 @@ public class UnsafeInMemorySorterSuite {
};
final TaskMemoryManager memoryManager = new TaskMemoryManager(
new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0);
+ final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
final Object baseObject = dataPage.getBaseObject();
// Write the records into the data page:
@@ -102,7 +106,7 @@ public class UnsafeInMemorySorterSuite {
return (int) prefix1 - (int) prefix2;
}
};
- UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator,
+ UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, recordComparator,
prefixComparator, dataToSort.length);
// Given a page of records, insert those records into the sorter one-by-one:
position = dataPage.getBaseOffset();