aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
Diffstat (limited to 'core')
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java10
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java31
2 files changed, 19 insertions, 22 deletions
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 9a7b2ad06c..2e40312674 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -468,6 +468,12 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
}
allocatedPages.clear();
}
+
+ // in-memory sorter will not be used after spilling
+ assert(inMemSorter != null);
+ released += inMemSorter.getMemoryUsage();
+ inMemSorter.free();
+ inMemSorter = null;
return released;
}
}
@@ -489,10 +495,6 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
}
upstream = nextUpstream;
nextUpstream = null;
-
- assert(inMemSorter != null);
- inMemSorter.free();
- inMemSorter = null;
}
numRecords--;
upstream.loadNext();
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index a218ad4623..dce1f15a29 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -108,6 +108,7 @@ public final class UnsafeInMemorySorter {
*/
public void free() {
consumer.freeArray(array);
+ array = null;
}
public void reset() {
@@ -160,28 +161,22 @@ public final class UnsafeInMemorySorter {
pos++;
}
- public static final class SortedIterator extends UnsafeSorterIterator {
+ public final class SortedIterator extends UnsafeSorterIterator {
- private final TaskMemoryManager memoryManager;
- private final int sortBufferInsertPosition;
- private final LongArray sortBuffer;
- private int position = 0;
+ private final int numRecords;
+ private int position;
private Object baseObject;
private long baseOffset;
private long keyPrefix;
private int recordLength;
- private SortedIterator(
- TaskMemoryManager memoryManager,
- int sortBufferInsertPosition,
- LongArray sortBuffer) {
- this.memoryManager = memoryManager;
- this.sortBufferInsertPosition = sortBufferInsertPosition;
- this.sortBuffer = sortBuffer;
+ private SortedIterator(int numRecords) {
+ this.numRecords = numRecords;
+ this.position = 0;
}
public SortedIterator clone () {
- SortedIterator iter = new SortedIterator(memoryManager, sortBufferInsertPosition, sortBuffer);
+ SortedIterator iter = new SortedIterator(numRecords);
iter.position = position;
iter.baseObject = baseObject;
iter.baseOffset = baseOffset;
@@ -192,21 +187,21 @@ public final class UnsafeInMemorySorter {
@Override
public boolean hasNext() {
- return position < sortBufferInsertPosition;
+ return position / 2 < numRecords;
}
public int numRecordsLeft() {
- return (sortBufferInsertPosition - position) / 2;
+ return numRecords - position / 2;
}
@Override
public void loadNext() {
// This pointer points to a 4-byte record length, followed by the record's bytes
- final long recordPointer = sortBuffer.get(position);
+ final long recordPointer = array.get(position);
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
recordLength = Platform.getInt(baseObject, baseOffset - 4);
- keyPrefix = sortBuffer.get(position + 1);
+ keyPrefix = array.get(position + 1);
position += 2;
}
@@ -229,6 +224,6 @@ public final class UnsafeInMemorySorter {
*/
public SortedIterator getSortedIterator() {
sorter.sort(array, 0, pos / 2, sortComparator);
- return new SortedIterator(memoryManager, pos, array);
+ return new SortedIterator(pos / 2);
}
}