diff options
Diffstat (limited to 'core/src/main')
7 files changed, 411 insertions, 84 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java index f7a6c68be9..b36da80dbc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java @@ -42,6 +42,16 @@ final class PackedRecordPointer { */ static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215 + /** + * The index of the first byte of the partition id, counting from the least significant byte. + */ + static final int PARTITION_ID_START_BYTE_INDEX = 5; + + /** + * The index of the last byte of the partition id, counting from the least significant byte. + */ + static final int PARTITION_ID_END_BYTE_INDEX = 7; + /** Bit mask for the lower 40 bits of a long. */ private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 3c2980e442..c4041a97e8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -115,7 +115,8 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.writeMetrics = writeMetrics; - this.inMemSorter = new ShuffleInMemorySorter(this, initialSize); + this.inMemSorter = new ShuffleInMemorySorter( + this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true)); this.peakMemoryUsedBytes = getMemoryUsage(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 76b0e6a304..68630946ac 100644 --- 
a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -17,12 +17,14 @@ package org.apache.spark.shuffle.sort; +import java.lang.Long; import java.util.Comparator; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; +import org.apache.spark.util.collection.unsafe.sort.RadixSort; final class ShuffleInMemorySorter { @@ -47,16 +49,29 @@ final class ShuffleInMemorySorter { private LongArray array; /** + * Whether to use radix sort for sorting in-memory partition ids. Radix sort is much faster + * but requires additional memory to be reserved as pointers are added. + */ + private final boolean useRadixSort; + + /** + * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x. + */ + private final int memoryAllocationFactor; + + /** * The position in the pointer array where new records can be inserted. */ private int pos = 0; private int initialSize; - ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) { + ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize, boolean useRadixSort) { this.consumer = consumer; assert (initialSize > 0); this.initialSize = initialSize; + this.useRadixSort = useRadixSort; + this.memoryAllocationFactor = useRadixSort ? 
2 : 1; this.array = consumer.allocateArray(initialSize); this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } @@ -87,18 +102,18 @@ final class ShuffleInMemorySorter { array.getBaseOffset(), newArray.getBaseObject(), newArray.getBaseOffset(), - array.size() * 8L + array.size() * (Long.BYTES / memoryAllocationFactor) ); consumer.freeArray(array); array = newArray; } public boolean hasSpaceForAnotherRecord() { - return pos < array.size(); + return pos < array.size() / memoryAllocationFactor; } public long getMemoryUsage() { - return array.size() * 8L; + return array.size() * Long.BYTES; } /** @@ -125,17 +140,18 @@ final class ShuffleInMemorySorter { public static final class ShuffleSorterIterator { private final LongArray pointerArray; - private final int numRecords; + private final int limit; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - ShuffleSorterIterator(int numRecords, LongArray pointerArray) { - this.numRecords = numRecords; + ShuffleSorterIterator(int numRecords, LongArray pointerArray, int startingPosition) { + this.limit = numRecords + startingPosition; this.pointerArray = pointerArray; + this.position = startingPosition; } public boolean hasNext() { - return position < numRecords; + return position < limit; } public void loadNext() { @@ -148,7 +164,15 @@ final class ShuffleInMemorySorter { * Return an iterator over record pointers in sorted order. 
*/ public ShuffleSorterIterator getSortedIterator() { - sorter.sort(array, 0, pos, SORT_COMPARATOR); - return new ShuffleSorterIterator(pos, array); + int offset = 0; + if (useRadixSort) { + offset = RadixSort.sort( + array, pos, + PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX, + PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false); + } else { + sorter.sort(array, 0, pos, SORT_COMPARATOR); + } + return new ShuffleSorterIterator(pos, array, offset); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index c2a8f429be..21f2fde79d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -28,88 +28,92 @@ import org.apache.spark.util.Utils; public class PrefixComparators { private PrefixComparators() {} - public static final StringPrefixComparator STRING = new StringPrefixComparator(); - public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); - public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator(); - public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc(); - public static final LongPrefixComparator LONG = new LongPrefixComparator(); - public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); - public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); - public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); - - public static final class StringPrefixComparator extends PrefixComparator { - @Override - public int compare(long aPrefix, long bPrefix) { - return UnsignedLongs.compare(aPrefix, bPrefix); - } - + public static final PrefixComparator STRING = new UnsignedPrefixComparator(); + 
public static final PrefixComparator STRING_DESC = new UnsignedPrefixComparatorDesc(); + public static final PrefixComparator BINARY = new UnsignedPrefixComparator(); + public static final PrefixComparator BINARY_DESC = new UnsignedPrefixComparatorDesc(); + public static final PrefixComparator LONG = new SignedPrefixComparator(); + public static final PrefixComparator LONG_DESC = new SignedPrefixComparatorDesc(); + public static final PrefixComparator DOUBLE = new UnsignedPrefixComparator(); + public static final PrefixComparator DOUBLE_DESC = new UnsignedPrefixComparatorDesc(); + + public static final class StringPrefixComparator { public static long computePrefix(UTF8String value) { return value == null ? 0L : value.getPrefix(); } } - public static final class StringPrefixComparatorDesc extends PrefixComparator { - @Override - public int compare(long bPrefix, long aPrefix) { - return UnsignedLongs.compare(aPrefix, bPrefix); + public static final class BinaryPrefixComparator { + public static long computePrefix(byte[] bytes) { + return ByteArray.getPrefix(bytes); } } - public static final class BinaryPrefixComparator extends PrefixComparator { - @Override - public int compare(long aPrefix, long bPrefix) { - return UnsignedLongs.compare(aPrefix, bPrefix); + public static final class DoublePrefixComparator { + /** + * Converts the double into a value that compares correctly as an unsigned long. For more + * details see http://stereopsis.com/radix.html. + */ + public static long computePrefix(double value) { + // Java's doubleToLongBits already canonicalizes all NaN values to the smallest possible + // positive NaN, so there's nothing special we need to do for NaNs. + long bits = Double.doubleToLongBits(value); + // Negative floats compare backwards due to their sign-magnitude representation, so flip + // all the bits in this case. 
+ long mask = -(bits >>> 63) | 0x8000000000000000L; + return bits ^ mask; } + } - public static long computePrefix(byte[] bytes) { - return ByteArray.getPrefix(bytes); - } + /** + * Provides radix sort parameters. Comparators implementing this also are indicating that the + * ordering they define is compatible with radix sort. + */ + public static abstract class RadixSortSupport extends PrefixComparator { + /** @return Whether the sort should be descending in binary sort order. */ + public abstract boolean sortDescending(); + + /** @return Whether the sort should take into account the sign bit. */ + public abstract boolean sortSigned(); } - public static final class BinaryPrefixComparatorDesc extends PrefixComparator { + // + // Standard prefix comparator implementations + // + + public static final class UnsignedPrefixComparator extends RadixSortSupport { + @Override public final boolean sortDescending() { return false; } + @Override public final boolean sortSigned() { return false; } @Override - public int compare(long bPrefix, long aPrefix) { + public final int compare(long aPrefix, long bPrefix) { return UnsignedLongs.compare(aPrefix, bPrefix); } } - public static final class LongPrefixComparator extends PrefixComparator { + public static final class UnsignedPrefixComparatorDesc extends RadixSortSupport { + @Override public final boolean sortDescending() { return true; } + @Override public final boolean sortSigned() { return false; } @Override - public int compare(long a, long b) { - return (a < b) ? -1 : (a > b) ? 
1 : 0; + public final int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); } } - public static final class LongPrefixComparatorDesc extends PrefixComparator { + public static final class SignedPrefixComparator extends RadixSortSupport { + @Override public final boolean sortDescending() { return false; } + @Override public final boolean sortSigned() { return true; } @Override - public int compare(long b, long a) { + public final int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } } - public static final class DoublePrefixComparator extends PrefixComparator { + public static final class SignedPrefixComparatorDesc extends RadixSortSupport { + @Override public final boolean sortDescending() { return true; } + @Override public final boolean sortSigned() { return true; } @Override - public int compare(long aPrefix, long bPrefix) { - double a = Double.longBitsToDouble(aPrefix); - double b = Double.longBitsToDouble(bPrefix); - return Utils.nanSafeCompareDoubles(a, b); - } - - public static long computePrefix(double value) { - return Double.doubleToLongBits(value); - } - } - - public static final class DoublePrefixComparatorDesc extends PrefixComparator { - @Override - public int compare(long bPrefix, long aPrefix) { - double a = Double.longBitsToDouble(aPrefix); - double b = Double.longBitsToDouble(bPrefix); - return Utils.nanSafeCompareDoubles(a, b); - } - - public static long computePrefix(double value) { - return Double.doubleToLongBits(value); + public final int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 
1 : 0; } } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java new file mode 100644 index 0000000000..3357b8e474 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; + +public class RadixSort { + + /** + * Sorts a given array of longs using least-significant-digit radix sort. This routine assumes + * you have extra space at the end of the array at least equal to the number of records. The + * sort is destructive and may relocate the data positioned within the array. + * + * @param array array of long elements followed by at least that many empty slots. + * @param numRecords number of data records in the array. + * @param startByteIndex the first byte (in range [0, 7]) to sort each long by, counting from the + * least significant byte. 
+ * @param endByteIndex the last byte (in range [0, 7]) to sort each long by, counting from the + * least significant byte. Must be greater than startByteIndex. + * @param desc whether this is a descending (binary-order) sort. + * @param signed whether this is a signed (two's complement) sort. + * + * @return The starting index of the sorted data within the given array. We return this instead + * of always copying the data back to position zero for efficiency. + */ + public static int sort( + LongArray array, int numRecords, int startByteIndex, int endByteIndex, + boolean desc, boolean signed) { + assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; + assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; + assert endByteIndex > startByteIndex; + assert numRecords * 2 <= array.size(); + int inIndex = 0; + int outIndex = numRecords; + if (numRecords > 0) { + long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); + for (int i = startByteIndex; i <= endByteIndex; i++) { + if (counts[i] != null) { + sortAtByte( + array, numRecords, counts[i], i, inIndex, outIndex, + desc, signed && i == endByteIndex); + int tmp = inIndex; + inIndex = outIndex; + outIndex = tmp; + } + } + } + return inIndex; + } + + /** + * Performs a partial sort by copying data into destination offsets for each byte value at the + * specified byte offset. + * + * @param array array to partially sort. + * @param numRecords number of data records in the array. + * @param counts counts for each byte value. This routine destructively modifies this array. + * @param byteIdx the byte in a long to sort at, counting from the least significant byte. + * @param inIndex the starting index in the array where input data is located. + * @param outIndex the starting index where sorted output data should be written. + * @param desc whether this is a descending (binary-order) sort. 
+ * @param signed whether this is a signed (two's complement) sort (only applies to last byte). + */ + private static void sortAtByte( + LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + boolean desc, boolean signed) { + assert counts.length == 256; + long[] offsets = transformCountsToOffsets( + counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed); + Object baseObject = array.getBaseObject(); + long baseOffset = array.getBaseOffset() + inIndex * 8; + long maxOffset = baseOffset + numRecords * 8; + for (long offset = baseOffset; offset < maxOffset; offset += 8) { + long value = Platform.getLong(baseObject, offset); + int bucket = (int)((value >>> (byteIdx * 8)) & 0xff); + Platform.putLong(baseObject, offsets[bucket], value); + offsets[bucket] += 8; + } + } + + /** + * Computes a value histogram for each byte in the given array. + * + * @param array array to count records in. + * @param numRecords number of data records in the array. + * @param startByteIndex the first byte to compute counts for (the prior are skipped). + * @param endByteIndex the last byte to compute counts for. + * + * @return an array of eight 256-element count arrays, one for each byte starting from the least + * significant byte. If the byte does not need sorting the array will be null. + */ + private static long[][] getCounts( + LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + long[][] counts = new long[8][]; + // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. + // If all the byte values at a particular index are the same we don't need to count it. 
+ long bitwiseMax = 0; + long bitwiseMin = -1L; + long maxOffset = array.getBaseOffset() + numRecords * 8; + Object baseObject = array.getBaseObject(); + for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { + long value = Platform.getLong(baseObject, offset); + bitwiseMax |= value; + bitwiseMin &= value; + } + long bitsChanged = bitwiseMin ^ bitwiseMax; + // Compute counts for each byte index. + for (int i = startByteIndex; i <= endByteIndex; i++) { + if (((bitsChanged >>> (i * 8)) & 0xff) != 0) { + counts[i] = new long[256]; + // TODO(ekl) consider computing all the counts in one pass. + for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { + counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++; + } + } + } + return counts; + } + + /** + * Transforms counts into the proper unsafe output offsets for the sort type. + * + * @param counts counts for each byte value. This routine destructively modifies this array. + * @param numRecords number of data records in the original data array. + * @param outputOffset output offset in bytes from the base array object. + * @param bytesPerRecord size of each record (8 for plain sort, 16 for key-prefix sort). + * @param desc whether this is a descending (binary-order) sort. + * @param signed whether this is a signed (two's complement) sort. + * + * @return the input counts array. + */ + private static long[] transformCountsToOffsets( + long[] counts, int numRecords, long outputOffset, int bytesPerRecord, + boolean desc, boolean signed) { + assert counts.length == 256; + int start = signed ? 128 : 0; // output the negative records first (values 128-255). 
+ if (desc) { + int pos = numRecords; + for (int i = start; i < start + 256; i++) { + pos -= counts[i & 0xff]; + counts[i & 0xff] = outputOffset + pos * bytesPerRecord; + } + } else { + int pos = 0; + for (int i = start; i < start + 256; i++) { + long tmp = counts[i & 0xff]; + counts[i & 0xff] = outputOffset + pos * bytesPerRecord; + pos += tmp; + } + } + return counts; + } + + /** + * Specialization of sort() for key-prefix arrays. In this type of array, each record consists + * of two longs, only the second of which is sorted on. + */ + public static int sortKeyPrefixArray( + LongArray array, + int numRecords, + int startByteIndex, + int endByteIndex, + boolean desc, + boolean signed) { + assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; + assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; + assert endByteIndex > startByteIndex; + assert numRecords * 4 <= array.size(); + int inIndex = 0; + int outIndex = numRecords * 2; + if (numRecords > 0) { + long[][] counts = getKeyPrefixArrayCounts(array, numRecords, startByteIndex, endByteIndex); + for (int i = startByteIndex; i <= endByteIndex; i++) { + if (counts[i] != null) { + sortKeyPrefixArrayAtByte( + array, numRecords, counts[i], i, inIndex, outIndex, + desc, signed && i == endByteIndex); + int tmp = inIndex; + inIndex = outIndex; + outIndex = tmp; + } + } + } + return inIndex; + } + + /** + * Specialization of getCounts() for key-prefix arrays. We could probably combine this with + * getCounts with some added parameters but that seems to hurt in benchmarks. 
+ */ + private static long[][] getKeyPrefixArrayCounts( + LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + long[][] counts = new long[8][]; + long bitwiseMax = 0; + long bitwiseMin = -1L; + long limit = array.getBaseOffset() + numRecords * 16; + Object baseObject = array.getBaseObject(); + for (long offset = array.getBaseOffset(); offset < limit; offset += 16) { + long value = Platform.getLong(baseObject, offset + 8); + bitwiseMax |= value; + bitwiseMin &= value; + } + long bitsChanged = bitwiseMin ^ bitwiseMax; + for (int i = startByteIndex; i <= endByteIndex; i++) { + if (((bitsChanged >>> (i * 8)) & 0xff) != 0) { + counts[i] = new long[256]; + for (long offset = array.getBaseOffset(); offset < limit; offset += 16) { + counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++; + } + } + } + return counts; + } + + /** + * Specialization of sortAtByte() for key-prefix arrays. + */ + private static void sortKeyPrefixArrayAtByte( + LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + boolean desc, boolean signed) { + assert counts.length == 256; + long[] offsets = transformCountsToOffsets( + counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); + Object baseObject = array.getBaseObject(); + long baseOffset = array.getBaseOffset() + inIndex * 8; + long maxOffset = baseOffset + numRecords * 16; + for (long offset = baseOffset; offset < maxOffset; offset += 16) { + long key = Platform.getLong(baseObject, offset); + long prefix = Platform.getLong(baseObject, offset + 8); + int bucket = (int)((prefix >>> (byteIdx * 8)) & 0xff); + long dest = offsets[bucket]; + Platform.putLong(baseObject, dest, key); + Platform.putLong(baseObject, dest + 8, prefix); + offsets[bucket] += 16; + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 3e32dd9d63..66a77982ad 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -89,7 +89,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparator, prefixComparator, initialSize, - pageSizeBytes, inMemorySorter); + pageSizeBytes, inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. sorter.inMemSorter = null; @@ -104,9 +104,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer { RecordComparator recordComparator, PrefixComparator prefixComparator, int initialSize, - long pageSizeBytes) { + long pageSizeBytes, + boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, - taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null); + taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null, + canUseRadixSort); } private UnsafeExternalSorter( @@ -118,7 +120,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer { PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - @Nullable UnsafeInMemorySorter existingInMemorySorter) { + @Nullable UnsafeInMemorySorter existingInMemorySorter, + boolean canUseRadixSort) { super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; @@ -133,7 +136,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { if (existingInMemorySorter == null) { this.inMemSorter = new 
UnsafeInMemorySorter( - this, taskMemoryManager, recordComparator, prefixComparator, initialSize); + this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort); } else { this.inMemSorter = existingInMemorySorter; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 01eae0e8dc..5f46ef9a81 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -17,6 +17,7 @@ package org.apache.spark.util.collection.unsafe.sort; +import java.lang.Long; import java.util.Comparator; import org.apache.avro.reflect.Nullable; @@ -26,6 +27,7 @@ import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; +import org.apache.spark.util.collection.unsafe.sort.RadixSort; /** * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records @@ -74,6 +76,17 @@ public final class UnsafeInMemorySorter { private final Comparator<RecordPointerAndKeyPrefix> sortComparator; /** + * If non-null, specifies the radix sort parameters and that radix sort will be used. + */ + @Nullable + private final PrefixComparators.RadixSortSupport radixSortSupport; + + /** + * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x. + */ + private final int memoryAllocationFactor; + + /** * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. 
*/ @@ -91,27 +104,36 @@ public final class UnsafeInMemorySorter { final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, - int initialSize) { + int initialSize, + boolean canUseRadixSort) { this(consumer, memoryManager, recordComparator, prefixComparator, - consumer.allocateArray(initialSize * 2)); + consumer.allocateArray(initialSize * 2), canUseRadixSort); } public UnsafeInMemorySorter( - final MemoryConsumer consumer, + final MemoryConsumer consumer, final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, - LongArray array) { + LongArray array, + boolean canUseRadixSort) { this.consumer = consumer; this.memoryManager = memoryManager; this.initialSize = array.size(); if (recordComparator != null) { this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) { + this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator; + } else { + this.radixSortSupport = null; + } } else { this.sorter = null; this.sortComparator = null; + this.radixSortSupport = null; } + this.memoryAllocationFactor = this.radixSortSupport != null ? 
2 : 1; this.array = array; } @@ -141,11 +163,11 @@ public final class UnsafeInMemorySorter { } public long getMemoryUsage() { - return array.size() * 8L; + return array.size() * Long.BYTES; } public boolean hasSpaceForAnotherRecord() { - return pos + 2 <= array.size(); + return pos + 1 < (array.size() / memoryAllocationFactor); } public void expandPointerArray(LongArray newArray) { @@ -157,7 +179,7 @@ public final class UnsafeInMemorySorter { array.getBaseOffset(), newArray.getBaseObject(), newArray.getBaseOffset(), - array.size() * 8L); + array.size() * (Long.BYTES / memoryAllocationFactor)); consumer.freeArray(array); array = newArray; } @@ -183,18 +205,20 @@ public final class UnsafeInMemorySorter { private final int numRecords; private int position; + private int offset; private Object baseObject; private long baseOffset; private long keyPrefix; private int recordLength; - private SortedIterator(int numRecords) { + private SortedIterator(int numRecords, int offset) { this.numRecords = numRecords; this.position = 0; + this.offset = offset; } public SortedIterator clone() { - SortedIterator iter = new SortedIterator(numRecords); + SortedIterator iter = new SortedIterator(numRecords, offset); iter.position = position; iter.baseObject = baseObject; iter.baseOffset = baseOffset; @@ -216,11 +240,11 @@ public final class UnsafeInMemorySorter { @Override public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = array.get(position); + final long recordPointer = array.get(offset + position); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); - keyPrefix = array.get(position + 1); + keyPrefix = array.get(offset + position + 1); position += 2; } @@ -242,9 +266,17 @@ public final class UnsafeInMemorySorter { * {@code next()} will return the same 
mutable object. */ public SortedIterator getSortedIterator() { + int offset = 0; if (sorter != null) { - sorter.sort(array, 0, pos / 2, sortComparator); + if (this.radixSortSupport != null) { + // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they + // force a full-width sort (and we cannot radix-sort nullable long fields at all). + offset = RadixSort.sortKeyPrefixArray( + array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); + } else { + sorter.sort(array, 0, pos / 2, sortComparator); + } } - return new SortedIterator(pos / 2); + return new SortedIterator(pos / 2, offset); } } |