aboutsummaryrefslogtreecommitdiff
path: root/core/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main')
-rw-r--r-- core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java 10
-rw-r--r-- core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java 3
-rw-r--r-- core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java 44
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java 114
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java 253
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java 13
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java 58
7 files changed, 411 insertions, 84 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
index f7a6c68be9..b36da80dbc 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
@@ -42,6 +42,16 @@ final class PackedRecordPointer {
*/
static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
+ /**
+ * The index of the first byte of the partition id, counting from the least significant byte.
+ */
+ static final int PARTITION_ID_START_BYTE_INDEX = 5;
+
+ /**
+ * The index of the last byte of the partition id, counting from the least significant byte.
+ */
+ static final int PARTITION_ID_END_BYTE_INDEX = 7;
+
/** Bit mask for the lower 40 bits of a long. */
private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 3c2980e442..c4041a97e8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -115,7 +115,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.writeMetrics = writeMetrics;
- this.inMemSorter = new ShuffleInMemorySorter(this, initialSize);
+ this.inMemSorter = new ShuffleInMemorySorter(
+ this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true));
this.peakMemoryUsedBytes = getMemoryUsage();
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index 76b0e6a304..68630946ac 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -17,12 +17,14 @@
package org.apache.spark.shuffle.sort;
+import java.lang.Long;
import java.util.Comparator;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.util.collection.unsafe.sort.RadixSort;
final class ShuffleInMemorySorter {
@@ -47,16 +49,29 @@ final class ShuffleInMemorySorter {
private LongArray array;
/**
+ * Whether to use radix sort for sorting in-memory partition ids. Radix sort is much faster
+ * but requires additional memory to be reserved as pointers are added.
+ */
+ private final boolean useRadixSort;
+
+ /**
+ * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x.
+ */
+ private final int memoryAllocationFactor;
+
+ /**
* The position in the pointer array where new records can be inserted.
*/
private int pos = 0;
private int initialSize;
- ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
+ ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize, boolean useRadixSort) {
this.consumer = consumer;
assert (initialSize > 0);
this.initialSize = initialSize;
+ this.useRadixSort = useRadixSort;
+ this.memoryAllocationFactor = useRadixSort ? 2 : 1;
this.array = consumer.allocateArray(initialSize);
this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
}
@@ -87,18 +102,18 @@ final class ShuffleInMemorySorter {
array.getBaseOffset(),
newArray.getBaseObject(),
newArray.getBaseOffset(),
- array.size() * 8L
+ array.size() * (Long.BYTES / memoryAllocationFactor)
);
consumer.freeArray(array);
array = newArray;
}
public boolean hasSpaceForAnotherRecord() {
- return pos < array.size();
+ return pos < array.size() / memoryAllocationFactor;
}
public long getMemoryUsage() {
- return array.size() * 8L;
+ return array.size() * Long.BYTES;
}
/**
@@ -125,17 +140,18 @@ final class ShuffleInMemorySorter {
public static final class ShuffleSorterIterator {
private final LongArray pointerArray;
- private final int numRecords;
+ private final int limit;
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
private int position = 0;
- ShuffleSorterIterator(int numRecords, LongArray pointerArray) {
- this.numRecords = numRecords;
+ ShuffleSorterIterator(int numRecords, LongArray pointerArray, int startingPosition) {
+ this.limit = numRecords + startingPosition;
this.pointerArray = pointerArray;
+ this.position = startingPosition;
}
public boolean hasNext() {
- return position < numRecords;
+ return position < limit;
}
public void loadNext() {
@@ -148,7 +164,15 @@ final class ShuffleInMemorySorter {
* Return an iterator over record pointers in sorted order.
*/
public ShuffleSorterIterator getSortedIterator() {
- sorter.sort(array, 0, pos, SORT_COMPARATOR);
- return new ShuffleSorterIterator(pos, array);
+ int offset = 0;
+ if (useRadixSort) {
+ offset = RadixSort.sort(
+ array, pos,
+ PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
+ PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
+ } else {
+ sorter.sort(array, 0, pos, SORT_COMPARATOR);
+ }
+ return new ShuffleSorterIterator(pos, array, offset);
}
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index c2a8f429be..21f2fde79d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -28,88 +28,92 @@ import org.apache.spark.util.Utils;
public class PrefixComparators {
private PrefixComparators() {}
- public static final StringPrefixComparator STRING = new StringPrefixComparator();
- public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
- public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator();
- public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc();
- public static final LongPrefixComparator LONG = new LongPrefixComparator();
- public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
- public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
- public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc();
-
- public static final class StringPrefixComparator extends PrefixComparator {
- @Override
- public int compare(long aPrefix, long bPrefix) {
- return UnsignedLongs.compare(aPrefix, bPrefix);
- }
-
+ public static final PrefixComparator STRING = new UnsignedPrefixComparator();
+ public static final PrefixComparator STRING_DESC = new UnsignedPrefixComparatorDesc();
+ public static final PrefixComparator BINARY = new UnsignedPrefixComparator();
+ public static final PrefixComparator BINARY_DESC = new UnsignedPrefixComparatorDesc();
+ public static final PrefixComparator LONG = new SignedPrefixComparator();
+ public static final PrefixComparator LONG_DESC = new SignedPrefixComparatorDesc();
+ public static final PrefixComparator DOUBLE = new UnsignedPrefixComparator();
+ public static final PrefixComparator DOUBLE_DESC = new UnsignedPrefixComparatorDesc();
+
+ public static final class StringPrefixComparator {
public static long computePrefix(UTF8String value) {
return value == null ? 0L : value.getPrefix();
}
}
- public static final class StringPrefixComparatorDesc extends PrefixComparator {
- @Override
- public int compare(long bPrefix, long aPrefix) {
- return UnsignedLongs.compare(aPrefix, bPrefix);
+ public static final class BinaryPrefixComparator {
+ public static long computePrefix(byte[] bytes) {
+ return ByteArray.getPrefix(bytes);
}
}
- public static final class BinaryPrefixComparator extends PrefixComparator {
- @Override
- public int compare(long aPrefix, long bPrefix) {
- return UnsignedLongs.compare(aPrefix, bPrefix);
+ public static final class DoublePrefixComparator {
+ /**
+ * Converts the double into a value that compares correctly as an unsigned long. For more
+ * details see http://stereopsis.com/radix.html.
+ */
+ public static long computePrefix(double value) {
+ // Java's doubleToLongBits already canonicalizes all NaN values to the smallest possible
+ // positive NaN, so there's nothing special we need to do for NaNs.
+ long bits = Double.doubleToLongBits(value);
+ // Negative doubles compare backwards due to their sign-magnitude representation, so flip all
+ // of their bits; for non-negative values flip only the sign bit so they sort above negatives.
+ long mask = -(bits >>> 63) | 0x8000000000000000L;
+ return bits ^ mask;
}
+ }
- public static long computePrefix(byte[] bytes) {
- return ByteArray.getPrefix(bytes);
- }
+ /**
+ * Provides radix sort parameters. Comparators implementing this also are indicating that the
+ * ordering they define is compatible with radix sort.
+ */
+ public static abstract class RadixSortSupport extends PrefixComparator {
+ /** @return Whether the sort should be descending in binary sort order. */
+ public abstract boolean sortDescending();
+
+ /** @return Whether the sort should take into account the sign bit. */
+ public abstract boolean sortSigned();
}
- public static final class BinaryPrefixComparatorDesc extends PrefixComparator {
+ //
+ // Standard prefix comparator implementations
+ //
+
+ public static final class UnsignedPrefixComparator extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return false; }
+ @Override public final boolean sortSigned() { return false; }
@Override
- public int compare(long bPrefix, long aPrefix) {
+ public final int compare(long aPrefix, long bPrefix) {
return UnsignedLongs.compare(aPrefix, bPrefix);
}
}
- public static final class LongPrefixComparator extends PrefixComparator {
+ public static final class UnsignedPrefixComparatorDesc extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return true; }
+ @Override public final boolean sortSigned() { return false; }
@Override
- public int compare(long a, long b) {
- return (a < b) ? -1 : (a > b) ? 1 : 0;
+ public final int compare(long bPrefix, long aPrefix) {
+ return UnsignedLongs.compare(aPrefix, bPrefix);
}
}
- public static final class LongPrefixComparatorDesc extends PrefixComparator {
+ public static final class SignedPrefixComparator extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return false; }
+ @Override public final boolean sortSigned() { return true; }
@Override
- public int compare(long b, long a) {
+ public final int compare(long a, long b) {
return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}
- public static final class DoublePrefixComparator extends PrefixComparator {
+ public static final class SignedPrefixComparatorDesc extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return true; }
+ @Override public final boolean sortSigned() { return true; }
@Override
- public int compare(long aPrefix, long bPrefix) {
- double a = Double.longBitsToDouble(aPrefix);
- double b = Double.longBitsToDouble(bPrefix);
- return Utils.nanSafeCompareDoubles(a, b);
- }
-
- public static long computePrefix(double value) {
- return Double.doubleToLongBits(value);
- }
- }
-
- public static final class DoublePrefixComparatorDesc extends PrefixComparator {
- @Override
- public int compare(long bPrefix, long aPrefix) {
- double a = Double.longBitsToDouble(aPrefix);
- double b = Double.longBitsToDouble(bPrefix);
- return Utils.nanSafeCompareDoubles(a, b);
- }
-
- public static long computePrefix(double value) {
- return Double.doubleToLongBits(value);
+ public final int compare(long b, long a) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
new file mode 100644
index 0000000000..3357b8e474
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+
+public class RadixSort {
+
+ /**
+ * Sorts a given array of longs using least-significant-digit radix sort. This routine assumes
+ * you have extra space at the end of the array at least equal to the number of records. The
+ * sort is destructive and may relocate the data positioned within the array.
+ *
+ * @param array array of long elements followed by at least that many empty slots.
+ * @param numRecords number of data records in the array.
+ * @param startByteIndex the first byte (in range [0, 7]) to sort each long by, counting from the
+ * least significant byte.
+ * @param endByteIndex the last byte (in range [0, 7]) to sort each long by, counting from the
+ * least significant byte. Must be greater than startByteIndex.
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort.
+ *
+ * @return The starting index of the sorted data within the given array. We return this instead
+ * of always copying the data back to position zero for efficiency.
+ */
+ public static int sort(
+ LongArray array, int numRecords, int startByteIndex, int endByteIndex,
+ boolean desc, boolean signed) {
+ assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
+ assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
+ assert endByteIndex > startByteIndex;
+ assert numRecords * 2 <= array.size();
+ int inIndex = 0;
+ int outIndex = numRecords;
+ if (numRecords > 0) {
+ long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (counts[i] != null) {
+ sortAtByte(
+ array, numRecords, counts[i], i, inIndex, outIndex,
+ desc, signed && i == endByteIndex);
+ int tmp = inIndex;
+ inIndex = outIndex;
+ outIndex = tmp;
+ }
+ }
+ }
+ return inIndex;
+ }
+
+ /**
+ * Performs a partial sort by copying data into destination offsets for each byte value at the
+ * specified byte offset.
+ *
+ * @param array array to partially sort.
+ * @param numRecords number of data records in the array.
+ * @param counts counts for each byte value. This routine destructively modifies this array.
+ * @param byteIdx the byte in a long to sort at, counting from the least significant byte.
+ * @param inIndex the starting index in the array where input data is located.
+ * @param outIndex the starting index where sorted output data should be written.
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort (only applies to last byte).
+ */
+ private static void sortAtByte(
+ LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ long[] offsets = transformCountsToOffsets(
+ counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed);
+ Object baseObject = array.getBaseObject();
+ long baseOffset = array.getBaseOffset() + inIndex * 8;
+ long maxOffset = baseOffset + numRecords * 8;
+ for (long offset = baseOffset; offset < maxOffset; offset += 8) {
+ long value = Platform.getLong(baseObject, offset);
+ int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
+ Platform.putLong(baseObject, offsets[bucket], value);
+ offsets[bucket] += 8;
+ }
+ }
+
+ /**
+ * Computes a value histogram for each byte in the given array.
+ *
+ * @param array array to count records in.
+ * @param numRecords number of data records in the array.
+ * @param startByteIndex the first byte to compute counts for (the prior are skipped).
+ * @param endByteIndex the last byte to compute counts for.
+ *
+ * @return an array of eight 256-element count arrays, one for each byte starting from the least
+ * significant byte. If the byte does not need sorting the array will be null.
+ */
+ private static long[][] getCounts(
+ LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+ long[][] counts = new long[8][];
+ // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
+ // If all the byte values at a particular index are the same we don't need to count it.
+ long bitwiseMax = 0;
+ long bitwiseMin = -1L;
+ long maxOffset = array.getBaseOffset() + numRecords * 8;
+ Object baseObject = array.getBaseObject();
+ for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
+ long value = Platform.getLong(baseObject, offset);
+ bitwiseMax |= value;
+ bitwiseMin &= value;
+ }
+ long bitsChanged = bitwiseMin ^ bitwiseMax;
+ // Compute counts for each byte index.
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
+ counts[i] = new long[256];
+ // TODO(ekl) consider computing all the counts in one pass.
+ for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
+ counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++;
+ }
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Transforms counts into the proper unsafe output offsets for the sort type.
+ *
+ * @param counts counts for each byte value. This routine destructively modifies this array.
+ * @param numRecords number of data records in the original data array.
+ * @param outputOffset output offset in bytes from the base array object.
+ * @param bytesPerRecord size of each record (8 for plain sort, 16 for key-prefix sort).
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort.
+ *
+ * @return the input counts array.
+ */
+ private static long[] transformCountsToOffsets(
+ long[] counts, int numRecords, long outputOffset, int bytesPerRecord,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ int start = signed ? 128 : 0; // output the negative records first (values 128-255).
+ if (desc) {
+ int pos = numRecords;
+ for (int i = start; i < start + 256; i++) {
+ pos -= counts[i & 0xff];
+ counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
+ }
+ } else {
+ int pos = 0;
+ for (int i = start; i < start + 256; i++) {
+ long tmp = counts[i & 0xff];
+ counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
+ pos += tmp;
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Specialization of sort() for key-prefix arrays. In this type of array, each record consists
+ * of two longs, only the second of which is sorted on.
+ */
+ public static int sortKeyPrefixArray(
+ LongArray array,
+ int numRecords,
+ int startByteIndex,
+ int endByteIndex,
+ boolean desc,
+ boolean signed) {
+ assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
+ assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
+ assert endByteIndex > startByteIndex;
+ assert numRecords * 4 <= array.size();
+ int inIndex = 0;
+ int outIndex = numRecords * 2;
+ if (numRecords > 0) {
+ long[][] counts = getKeyPrefixArrayCounts(array, numRecords, startByteIndex, endByteIndex);
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (counts[i] != null) {
+ sortKeyPrefixArrayAtByte(
+ array, numRecords, counts[i], i, inIndex, outIndex,
+ desc, signed && i == endByteIndex);
+ int tmp = inIndex;
+ inIndex = outIndex;
+ outIndex = tmp;
+ }
+ }
+ }
+ return inIndex;
+ }
+
+ /**
+ * Specialization of getCounts() for key-prefix arrays. We could probably combine this with
+ * getCounts with some added parameters but that seems to hurt in benchmarks.
+ */
+ private static long[][] getKeyPrefixArrayCounts(
+ LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+ long[][] counts = new long[8][];
+ long bitwiseMax = 0;
+ long bitwiseMin = -1L;
+ long limit = array.getBaseOffset() + numRecords * 16;
+ Object baseObject = array.getBaseObject();
+ for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ long value = Platform.getLong(baseObject, offset + 8);
+ bitwiseMax |= value;
+ bitwiseMin &= value;
+ }
+ long bitsChanged = bitwiseMin ^ bitwiseMax;
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
+ counts[i] = new long[256];
+ for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
+ }
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Specialization of sortAtByte() for key-prefix arrays.
+ */
+ private static void sortKeyPrefixArrayAtByte(
+ LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ long[] offsets = transformCountsToOffsets(
+ counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
+ Object baseObject = array.getBaseObject();
+ long baseOffset = array.getBaseOffset() + inIndex * 8;
+ long maxOffset = baseOffset + numRecords * 16;
+ for (long offset = baseOffset; offset < maxOffset; offset += 16) {
+ long key = Platform.getLong(baseObject, offset);
+ long prefix = Platform.getLong(baseObject, offset + 8);
+ int bucket = (int)((prefix >>> (byteIdx * 8)) & 0xff);
+ long dest = offsets[bucket];
+ Platform.putLong(baseObject, dest, key);
+ Platform.putLong(baseObject, dest + 8, prefix);
+ offsets[bucket] += 16;
+ }
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 3e32dd9d63..66a77982ad 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -89,7 +89,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
UnsafeInMemorySorter inMemorySorter) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
- pageSizeBytes, inMemorySorter);
+ pageSizeBytes, inMemorySorter, false /* ignored */);
sorter.spill(Long.MAX_VALUE, sorter);
// The external sorter will be used to insert records, in-memory sorter is not needed.
sorter.inMemSorter = null;
@@ -104,9 +104,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
- long pageSizeBytes) {
+ long pageSizeBytes,
+ boolean canUseRadixSort) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
- taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
+ taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null,
+ canUseRadixSort);
}
private UnsafeExternalSorter(
@@ -118,7 +120,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
- @Nullable UnsafeInMemorySorter existingInMemorySorter) {
+ @Nullable UnsafeInMemorySorter existingInMemorySorter,
+ boolean canUseRadixSort) {
super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
@@ -133,7 +136,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
if (existingInMemorySorter == null) {
this.inMemSorter = new UnsafeInMemorySorter(
- this, taskMemoryManager, recordComparator, prefixComparator, initialSize);
+ this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort);
} else {
this.inMemSorter = existingInMemorySorter;
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 01eae0e8dc..5f46ef9a81 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection.unsafe.sort;
+import java.lang.Long;
import java.util.Comparator;
import org.apache.avro.reflect.Nullable;
@@ -26,6 +27,7 @@ import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.util.collection.unsafe.sort.RadixSort;
/**
* Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
@@ -74,6 +76,17 @@ public final class UnsafeInMemorySorter {
private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
/**
+ * If non-null, specifies the radix sort parameters and that radix sort will be used.
+ */
+ @Nullable
+ private final PrefixComparators.RadixSortSupport radixSortSupport;
+
+ /**
+ * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x.
+ */
+ private final int memoryAllocationFactor;
+
+ /**
* Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*/
@@ -91,27 +104,36 @@ public final class UnsafeInMemorySorter {
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
- int initialSize) {
+ int initialSize,
+ boolean canUseRadixSort) {
this(consumer, memoryManager, recordComparator, prefixComparator,
- consumer.allocateArray(initialSize * 2));
+ consumer.allocateArray(initialSize * 2), canUseRadixSort);
}
public UnsafeInMemorySorter(
- final MemoryConsumer consumer,
+ final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
- LongArray array) {
+ LongArray array,
+ boolean canUseRadixSort) {
this.consumer = consumer;
this.memoryManager = memoryManager;
this.initialSize = array.size();
if (recordComparator != null) {
this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) {
+ this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator;
+ } else {
+ this.radixSortSupport = null;
+ }
} else {
this.sorter = null;
this.sortComparator = null;
+ this.radixSortSupport = null;
}
+ this.memoryAllocationFactor = this.radixSortSupport != null ? 2 : 1;
this.array = array;
}
@@ -141,11 +163,11 @@ public final class UnsafeInMemorySorter {
}
public long getMemoryUsage() {
- return array.size() * 8L;
+ return array.size() * Long.BYTES;
}
public boolean hasSpaceForAnotherRecord() {
- return pos + 2 <= array.size();
+ return pos + 1 < (array.size() / memoryAllocationFactor);
}
public void expandPointerArray(LongArray newArray) {
@@ -157,7 +179,7 @@ public final class UnsafeInMemorySorter {
array.getBaseOffset(),
newArray.getBaseObject(),
newArray.getBaseOffset(),
- array.size() * 8L);
+ array.size() * (Long.BYTES / memoryAllocationFactor));
consumer.freeArray(array);
array = newArray;
}
@@ -183,18 +205,20 @@ public final class UnsafeInMemorySorter {
private final int numRecords;
private int position;
+ private int offset;
private Object baseObject;
private long baseOffset;
private long keyPrefix;
private int recordLength;
- private SortedIterator(int numRecords) {
+ private SortedIterator(int numRecords, int offset) {
this.numRecords = numRecords;
this.position = 0;
+ this.offset = offset;
}
public SortedIterator clone() {
- SortedIterator iter = new SortedIterator(numRecords);
+ SortedIterator iter = new SortedIterator(numRecords, offset);
iter.position = position;
iter.baseObject = baseObject;
iter.baseOffset = baseOffset;
@@ -216,11 +240,11 @@ public final class UnsafeInMemorySorter {
@Override
public void loadNext() {
// This pointer points to a 4-byte record length, followed by the record's bytes
- final long recordPointer = array.get(position);
+ final long recordPointer = array.get(offset + position);
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
recordLength = Platform.getInt(baseObject, baseOffset - 4);
- keyPrefix = array.get(position + 1);
+ keyPrefix = array.get(offset + position + 1);
position += 2;
}
@@ -242,9 +266,17 @@ public final class UnsafeInMemorySorter {
* {@code next()} will return the same mutable object.
*/
public SortedIterator getSortedIterator() {
+ int offset = 0;
if (sorter != null) {
- sorter.sort(array, 0, pos / 2, sortComparator);
+ if (this.radixSortSupport != null) {
+ // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they
+ // force a full-width sort (and we cannot radix-sort nullable long fields at all).
+ offset = RadixSort.sortKeyPrefixArray(
+ array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
+ } else {
+ sorter.sort(array, 0, pos / 2, sortComparator);
+ }
}
- return new SortedIterator(pos / 2);
+ return new SortedIterator(pos / 2, offset);
}
}