diff options
24 files changed, 876 insertions, 119 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java index f7a6c68be9..b36da80dbc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java @@ -42,6 +42,16 @@ final class PackedRecordPointer { */ static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215 + /** + * The index of the first byte of the partition id, counting from the least significant byte. + */ + static final int PARTITION_ID_START_BYTE_INDEX = 5; + + /** + * The index of the last byte of the partition id, counting from the least significant byte. + */ + static final int PARTITION_ID_END_BYTE_INDEX = 7; + /** Bit mask for the lower 40 bits of a long. */ private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 3c2980e442..c4041a97e8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -115,7 +115,8 @@ final class ShuffleExternalSorter extends MemoryConsumer { this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.writeMetrics = writeMetrics; - this.inMemSorter = new ShuffleInMemorySorter(this, initialSize); + this.inMemSorter = new ShuffleInMemorySorter( + this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true)); this.peakMemoryUsedBytes = getMemoryUsage(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 76b0e6a304..68630946ac 100644 --- 
a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -17,12 +17,14 @@ package org.apache.spark.shuffle.sort; +import java.lang.Long; import java.util.Comparator; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; +import org.apache.spark.util.collection.unsafe.sort.RadixSort; final class ShuffleInMemorySorter { @@ -47,16 +49,29 @@ final class ShuffleInMemorySorter { private LongArray array; /** + * Whether to use radix sort for sorting in-memory partition ids. Radix sort is much faster + * but requires additional memory to be reserved as pointers are added. + */ + private final boolean useRadixSort; + + /** + * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x. + */ + private final int memoryAllocationFactor; + + /** * The position in the pointer array where new records can be inserted. */ private int pos = 0; private int initialSize; - ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) { + ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize, boolean useRadixSort) { this.consumer = consumer; assert (initialSize > 0); this.initialSize = initialSize; + this.useRadixSort = useRadixSort; + this.memoryAllocationFactor = useRadixSort ? 
2 : 1; this.array = consumer.allocateArray(initialSize); this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } @@ -87,18 +102,18 @@ final class ShuffleInMemorySorter { array.getBaseOffset(), newArray.getBaseObject(), newArray.getBaseOffset(), - array.size() * 8L + array.size() * (Long.BYTES / memoryAllocationFactor) ); consumer.freeArray(array); array = newArray; } public boolean hasSpaceForAnotherRecord() { - return pos < array.size(); + return pos < array.size() / memoryAllocationFactor; } public long getMemoryUsage() { - return array.size() * 8L; + return array.size() * Long.BYTES; } /** @@ -125,17 +140,18 @@ final class ShuffleInMemorySorter { public static final class ShuffleSorterIterator { private final LongArray pointerArray; - private final int numRecords; + private final int limit; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - ShuffleSorterIterator(int numRecords, LongArray pointerArray) { - this.numRecords = numRecords; + ShuffleSorterIterator(int numRecords, LongArray pointerArray, int startingPosition) { + this.limit = numRecords + startingPosition; this.pointerArray = pointerArray; + this.position = startingPosition; } public boolean hasNext() { - return position < numRecords; + return position < limit; } public void loadNext() { @@ -148,7 +164,15 @@ final class ShuffleInMemorySorter { * Return an iterator over record pointers in sorted order. 
*/ public ShuffleSorterIterator getSortedIterator() { - sorter.sort(array, 0, pos, SORT_COMPARATOR); - return new ShuffleSorterIterator(pos, array); + int offset = 0; + if (useRadixSort) { + offset = RadixSort.sort( + array, pos, + PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX, + PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false); + } else { + sorter.sort(array, 0, pos, SORT_COMPARATOR); + } + return new ShuffleSorterIterator(pos, array, offset); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index c2a8f429be..21f2fde79d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -28,88 +28,92 @@ import org.apache.spark.util.Utils; public class PrefixComparators { private PrefixComparators() {} - public static final StringPrefixComparator STRING = new StringPrefixComparator(); - public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); - public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator(); - public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc(); - public static final LongPrefixComparator LONG = new LongPrefixComparator(); - public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); - public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); - public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); - - public static final class StringPrefixComparator extends PrefixComparator { - @Override - public int compare(long aPrefix, long bPrefix) { - return UnsignedLongs.compare(aPrefix, bPrefix); - } - + public static final PrefixComparator STRING = new UnsignedPrefixComparator(); + 
public static final PrefixComparator STRING_DESC = new UnsignedPrefixComparatorDesc(); + public static final PrefixComparator BINARY = new UnsignedPrefixComparator(); + public static final PrefixComparator BINARY_DESC = new UnsignedPrefixComparatorDesc(); + public static final PrefixComparator LONG = new SignedPrefixComparator(); + public static final PrefixComparator LONG_DESC = new SignedPrefixComparatorDesc(); + public static final PrefixComparator DOUBLE = new UnsignedPrefixComparator(); + public static final PrefixComparator DOUBLE_DESC = new UnsignedPrefixComparatorDesc(); + + public static final class StringPrefixComparator { public static long computePrefix(UTF8String value) { return value == null ? 0L : value.getPrefix(); } } - public static final class StringPrefixComparatorDesc extends PrefixComparator { - @Override - public int compare(long bPrefix, long aPrefix) { - return UnsignedLongs.compare(aPrefix, bPrefix); + public static final class BinaryPrefixComparator { + public static long computePrefix(byte[] bytes) { + return ByteArray.getPrefix(bytes); } } - public static final class BinaryPrefixComparator extends PrefixComparator { - @Override - public int compare(long aPrefix, long bPrefix) { - return UnsignedLongs.compare(aPrefix, bPrefix); + public static final class DoublePrefixComparator { + /** + * Converts the double into a value that compares correctly as an unsigned long. For more + * details see http://stereopsis.com/radix.html. + */ + public static long computePrefix(double value) { + // Java's doubleToLongBits already canonicalizes all NaN values to the smallest possible + // positive NaN, so there's nothing special we need to do for NaNs. + long bits = Double.doubleToLongBits(value); + // Negative floats compare backwards due to their sign-magnitude representation, so flip + // all the bits in this case. 
+ long mask = -(bits >>> 63) | 0x8000000000000000L; + return bits ^ mask; } + } - public static long computePrefix(byte[] bytes) { - return ByteArray.getPrefix(bytes); - } + /** + * Provides radix sort parameters. Comparators implementing this also are indicating that the + * ordering they define is compatible with radix sort. + */ + public static abstract class RadixSortSupport extends PrefixComparator { + /** @return Whether the sort should be descending in binary sort order. */ + public abstract boolean sortDescending(); + + /** @return Whether the sort should take into account the sign bit. */ + public abstract boolean sortSigned(); } - public static final class BinaryPrefixComparatorDesc extends PrefixComparator { + // + // Standard prefix comparator implementations + // + + public static final class UnsignedPrefixComparator extends RadixSortSupport { + @Override public final boolean sortDescending() { return false; } + @Override public final boolean sortSigned() { return false; } @Override - public int compare(long bPrefix, long aPrefix) { + public final int compare(long aPrefix, long bPrefix) { return UnsignedLongs.compare(aPrefix, bPrefix); } } - public static final class LongPrefixComparator extends PrefixComparator { + public static final class UnsignedPrefixComparatorDesc extends RadixSortSupport { + @Override public final boolean sortDescending() { return true; } + @Override public final boolean sortSigned() { return false; } @Override - public int compare(long a, long b) { - return (a < b) ? -1 : (a > b) ? 
1 : 0; + public final int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); } } - public static final class LongPrefixComparatorDesc extends PrefixComparator { + public static final class SignedPrefixComparator extends RadixSortSupport { + @Override public final boolean sortDescending() { return false; } + @Override public final boolean sortSigned() { return true; } @Override - public int compare(long b, long a) { + public final int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } } - public static final class DoublePrefixComparator extends PrefixComparator { + public static final class SignedPrefixComparatorDesc extends RadixSortSupport { + @Override public final boolean sortDescending() { return true; } + @Override public final boolean sortSigned() { return true; } @Override - public int compare(long aPrefix, long bPrefix) { - double a = Double.longBitsToDouble(aPrefix); - double b = Double.longBitsToDouble(bPrefix); - return Utils.nanSafeCompareDoubles(a, b); - } - - public static long computePrefix(double value) { - return Double.doubleToLongBits(value); - } - } - - public static final class DoublePrefixComparatorDesc extends PrefixComparator { - @Override - public int compare(long bPrefix, long aPrefix) { - double a = Double.longBitsToDouble(aPrefix); - double b = Double.longBitsToDouble(bPrefix); - return Utils.nanSafeCompareDoubles(a, b); - } - - public static long computePrefix(double value) { - return Double.doubleToLongBits(value); + public final int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 
1 : 0; } } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java new file mode 100644 index 0000000000..3357b8e474 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; + +public class RadixSort { + + /** + * Sorts a given array of longs using least-significant-digit radix sort. This routine assumes + * you have extra space at the end of the array at least equal to the number of records. The + * sort is destructive and may relocate the data positioned within the array. + * + * @param array array of long elements followed by at least that many empty slots. + * @param numRecords number of data records in the array. + * @param startByteIndex the first byte (in range [0, 7]) to sort each long by, counting from the + * least significant byte. 
+ * @param endByteIndex the last byte (in range [0, 7]) to sort each long by, counting from the + * least significant byte. Must be greater than startByteIndex. + * @param desc whether this is a descending (binary-order) sort. + * @param signed whether this is a signed (two's complement) sort. + * + * @return The starting index of the sorted data within the given array. We return this instead + * of always copying the data back to position zero for efficiency. + */ + public static int sort( + LongArray array, int numRecords, int startByteIndex, int endByteIndex, + boolean desc, boolean signed) { + assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; + assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; + assert endByteIndex > startByteIndex; + assert numRecords * 2 <= array.size(); + int inIndex = 0; + int outIndex = numRecords; + if (numRecords > 0) { + long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); + for (int i = startByteIndex; i <= endByteIndex; i++) { + if (counts[i] != null) { + sortAtByte( + array, numRecords, counts[i], i, inIndex, outIndex, + desc, signed && i == endByteIndex); + int tmp = inIndex; + inIndex = outIndex; + outIndex = tmp; + } + } + } + return inIndex; + } + + /** + * Performs a partial sort by copying data into destination offsets for each byte value at the + * specified byte offset. + * + * @param array array to partially sort. + * @param numRecords number of data records in the array. + * @param counts counts for each byte value. This routine destructively modifies this array. + * @param byteIdx the byte in a long to sort at, counting from the least significant byte. + * @param inIndex the starting index in the array where input data is located. + * @param outIndex the starting index where sorted output data should be written. + * @param desc whether this is a descending (binary-order) sort. 
+ * @param signed whether this is a signed (two's complement) sort (only applies to last byte). + */ + private static void sortAtByte( + LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + boolean desc, boolean signed) { + assert counts.length == 256; + long[] offsets = transformCountsToOffsets( + counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed); + Object baseObject = array.getBaseObject(); + long baseOffset = array.getBaseOffset() + inIndex * 8; + long maxOffset = baseOffset + numRecords * 8; + for (long offset = baseOffset; offset < maxOffset; offset += 8) { + long value = Platform.getLong(baseObject, offset); + int bucket = (int)((value >>> (byteIdx * 8)) & 0xff); + Platform.putLong(baseObject, offsets[bucket], value); + offsets[bucket] += 8; + } + } + + /** + * Computes a value histogram for each byte in the given array. + * + * @param array array to count records in. + * @param numRecords number of data records in the array. + * @param startByteIndex the first byte to compute counts for (the prior are skipped). + * @param endByteIndex the last byte to compute counts for. + * + * @return an array of eight 256-byte count arrays, one for each byte starting from the least + * significant byte. If the byte does not need sorting the array will be null. + */ + private static long[][] getCounts( + LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + long[][] counts = new long[8][]; + // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. + // If all the byte values at a particular index are the same we don't need to count it. 
+ long bitwiseMax = 0; + long bitwiseMin = -1L; + long maxOffset = array.getBaseOffset() + numRecords * 8; + Object baseObject = array.getBaseObject(); + for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { + long value = Platform.getLong(baseObject, offset); + bitwiseMax |= value; + bitwiseMin &= value; + } + long bitsChanged = bitwiseMin ^ bitwiseMax; + // Compute counts for each byte index. + for (int i = startByteIndex; i <= endByteIndex; i++) { + if (((bitsChanged >>> (i * 8)) & 0xff) != 0) { + counts[i] = new long[256]; + // TODO(ekl) consider computing all the counts in one pass. + for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { + counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++; + } + } + } + return counts; + } + + /** + * Transforms counts into the proper unsafe output offsets for the sort type. + * + * @param counts counts for each byte value. This routine destructively modifies this array. + * @param numRecords number of data records in the original data array. + * @param outputOffset output offset in bytes from the base array object. + * @param bytesPerRecord size of each record (8 for plain sort, 16 for key-prefix sort). + * @param desc whether this is a descending (binary-order) sort. + * @param signed whether this is a signed (two's complement) sort. + * + * @return the input counts array. + */ + private static long[] transformCountsToOffsets( + long[] counts, int numRecords, long outputOffset, int bytesPerRecord, + boolean desc, boolean signed) { + assert counts.length == 256; + int start = signed ? 128 : 0; // output the negative records first (values 128-255). 
+ if (desc) { + int pos = numRecords; + for (int i = start; i < start + 256; i++) { + pos -= counts[i & 0xff]; + counts[i & 0xff] = outputOffset + pos * bytesPerRecord; + } + } else { + int pos = 0; + for (int i = start; i < start + 256; i++) { + long tmp = counts[i & 0xff]; + counts[i & 0xff] = outputOffset + pos * bytesPerRecord; + pos += tmp; + } + } + return counts; + } + + /** + * Specialization of sort() for key-prefix arrays. In this type of array, each record consists + * of two longs, only the second of which is sorted on. + */ + public static int sortKeyPrefixArray( + LongArray array, + int numRecords, + int startByteIndex, + int endByteIndex, + boolean desc, + boolean signed) { + assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; + assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; + assert endByteIndex > startByteIndex; + assert numRecords * 4 <= array.size(); + int inIndex = 0; + int outIndex = numRecords * 2; + if (numRecords > 0) { + long[][] counts = getKeyPrefixArrayCounts(array, numRecords, startByteIndex, endByteIndex); + for (int i = startByteIndex; i <= endByteIndex; i++) { + if (counts[i] != null) { + sortKeyPrefixArrayAtByte( + array, numRecords, counts[i], i, inIndex, outIndex, + desc, signed && i == endByteIndex); + int tmp = inIndex; + inIndex = outIndex; + outIndex = tmp; + } + } + } + return inIndex; + } + + /** + * Specialization of getCounts() for key-prefix arrays. We could probably combine this with + * getCounts with some added parameters but that seems to hurt in benchmarks. 
+ */ + private static long[][] getKeyPrefixArrayCounts( + LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + long[][] counts = new long[8][]; + long bitwiseMax = 0; + long bitwiseMin = -1L; + long limit = array.getBaseOffset() + numRecords * 16; + Object baseObject = array.getBaseObject(); + for (long offset = array.getBaseOffset(); offset < limit; offset += 16) { + long value = Platform.getLong(baseObject, offset + 8); + bitwiseMax |= value; + bitwiseMin &= value; + } + long bitsChanged = bitwiseMin ^ bitwiseMax; + for (int i = startByteIndex; i <= endByteIndex; i++) { + if (((bitsChanged >>> (i * 8)) & 0xff) != 0) { + counts[i] = new long[256]; + for (long offset = array.getBaseOffset(); offset < limit; offset += 16) { + counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++; + } + } + } + return counts; + } + + /** + * Specialization of sortAtByte() for key-prefix arrays. + */ + private static void sortKeyPrefixArrayAtByte( + LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + boolean desc, boolean signed) { + assert counts.length == 256; + long[] offsets = transformCountsToOffsets( + counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); + Object baseObject = array.getBaseObject(); + long baseOffset = array.getBaseOffset() + inIndex * 8; + long maxOffset = baseOffset + numRecords * 16; + for (long offset = baseOffset; offset < maxOffset; offset += 16) { + long key = Platform.getLong(baseObject, offset); + long prefix = Platform.getLong(baseObject, offset + 8); + int bucket = (int)((prefix >>> (byteIdx * 8)) & 0xff); + long dest = offsets[bucket]; + Platform.putLong(baseObject, dest, key); + Platform.putLong(baseObject, dest + 8, prefix); + offsets[bucket] += 16; + } + } +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java 
b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 3e32dd9d63..66a77982ad 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -89,7 +89,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparator, prefixComparator, initialSize, - pageSizeBytes, inMemorySorter); + pageSizeBytes, inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. sorter.inMemSorter = null; @@ -104,9 +104,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer { RecordComparator recordComparator, PrefixComparator prefixComparator, int initialSize, - long pageSizeBytes) { + long pageSizeBytes, + boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, - taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null); + taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null, + canUseRadixSort); } private UnsafeExternalSorter( @@ -118,7 +120,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer { PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, - @Nullable UnsafeInMemorySorter existingInMemorySorter) { + @Nullable UnsafeInMemorySorter existingInMemorySorter, + boolean canUseRadixSort) { super(taskMemoryManager, pageSizeBytes); this.taskMemoryManager = taskMemoryManager; this.blockManager = blockManager; @@ -133,7 +136,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { if (existingInMemorySorter == null) { this.inMemSorter = new 
UnsafeInMemorySorter( - this, taskMemoryManager, recordComparator, prefixComparator, initialSize); + this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort); } else { this.inMemSorter = existingInMemorySorter; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 01eae0e8dc..5f46ef9a81 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -17,6 +17,7 @@ package org.apache.spark.util.collection.unsafe.sort; +import java.lang.Long; import java.util.Comparator; import org.apache.avro.reflect.Nullable; @@ -26,6 +27,7 @@ import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; +import org.apache.spark.util.collection.unsafe.sort.RadixSort; /** * Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records @@ -74,6 +76,17 @@ public final class UnsafeInMemorySorter { private final Comparator<RecordPointerAndKeyPrefix> sortComparator; /** + * If non-null, specifies the radix sort parameters and that radix sort will be used. + */ + @Nullable + private final PrefixComparators.RadixSortSupport radixSortSupport; + + /** + * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x. + */ + private final int memoryAllocationFactor; + + /** * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. 
*/ @@ -91,27 +104,36 @@ public final class UnsafeInMemorySorter { final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, - int initialSize) { + int initialSize, + boolean canUseRadixSort) { this(consumer, memoryManager, recordComparator, prefixComparator, - consumer.allocateArray(initialSize * 2)); + consumer.allocateArray(initialSize * 2), canUseRadixSort); } public UnsafeInMemorySorter( - final MemoryConsumer consumer, + final MemoryConsumer consumer, final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, - LongArray array) { + LongArray array, + boolean canUseRadixSort) { this.consumer = consumer; this.memoryManager = memoryManager; this.initialSize = array.size(); if (recordComparator != null) { this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) { + this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator; + } else { + this.radixSortSupport = null; + } } else { this.sorter = null; this.sortComparator = null; + this.radixSortSupport = null; } + this.memoryAllocationFactor = this.radixSortSupport != null ? 
2 : 1; this.array = array; } @@ -141,11 +163,11 @@ public final class UnsafeInMemorySorter { } public long getMemoryUsage() { - return array.size() * 8L; + return array.size() * Long.BYTES; } public boolean hasSpaceForAnotherRecord() { - return pos + 2 <= array.size(); + return pos + 1 < (array.size() / memoryAllocationFactor); } public void expandPointerArray(LongArray newArray) { @@ -157,7 +179,7 @@ public final class UnsafeInMemorySorter { array.getBaseOffset(), newArray.getBaseObject(), newArray.getBaseOffset(), - array.size() * 8L); + array.size() * (Long.BYTES / memoryAllocationFactor)); consumer.freeArray(array); array = newArray; } @@ -183,18 +205,20 @@ public final class UnsafeInMemorySorter { private final int numRecords; private int position; + private int offset; private Object baseObject; private long baseOffset; private long keyPrefix; private int recordLength; - private SortedIterator(int numRecords) { + private SortedIterator(int numRecords, int offset) { this.numRecords = numRecords; this.position = 0; + this.offset = offset; } public SortedIterator clone() { - SortedIterator iter = new SortedIterator(numRecords); + SortedIterator iter = new SortedIterator(numRecords, offset); iter.position = position; iter.baseObject = baseObject; iter.baseOffset = baseOffset; @@ -216,11 +240,11 @@ public final class UnsafeInMemorySorter { @Override public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = array.get(position); + final long recordPointer = array.get(offset + position); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); - keyPrefix = array.get(position + 1); + keyPrefix = array.get(offset + position + 1); position += 2; } @@ -242,9 +266,17 @@ public final class UnsafeInMemorySorter { * {@code next()} will return the same 
mutable object. */ public SortedIterator getSortedIterator() { + int offset = 0; if (sorter != null) { - sorter.sort(array, 0, pos / 2, sortComparator); + if (this.radixSortSupport != null) { + // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they + // force a full-width sort (and we cannot radix-sort nullable long fields at all). + offset = RadixSort.sortKeyPrefixArray( + array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); + } else { + sorter.sort(array, 0, pos / 2, sortComparator); + } } - return new SortedIterator(pos / 2); + return new SortedIterator(pos / 2, offset); } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java new file mode 100644 index 0000000000..6927d0a815 --- /dev/null +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort; + +public class ShuffleInMemoryRadixSorterSuite extends ShuffleInMemorySorterSuite { + @Override + protected boolean shouldUseRadixSort() { return true; } +} diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 4cd3600df1..43e32f073a 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.sort; +import java.lang.Long; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Random; @@ -34,6 +35,8 @@ import org.apache.spark.unsafe.memory.MemoryBlock; public class ShuffleInMemorySorterSuite { + protected boolean shouldUseRadixSort() { return false; } + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); @@ -47,7 +50,8 @@ public class ShuffleInMemorySorterSuite { @Test public void testSortingEmptyInput() { - final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter( + consumer, 100, shouldUseRadixSort()); final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); Assert.assertFalse(iter.hasNext()); } @@ -70,14 +74,16 @@ public class ShuffleInMemorySorterSuite { new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); - final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter( + consumer, 4, shouldUseRadixSort()); final 
HashPartitioner hashPartitioner = new HashPartitioner(4); // Write the records into the data page and store pointers into the sorter long position = dataPage.getBaseOffset(); for (String str : dataToSort) { if (!sorter.hasSpaceForAnotherRecord()) { - sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2)); + sorter.expandPointerArray( + consumer.allocateArray(sorter.getMemoryUsage() / Long.BYTES * 2)); } final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position); final byte[] strBytes = str.getBytes(StandardCharsets.UTF_8); @@ -114,12 +120,12 @@ public class ShuffleInMemorySorterSuite { @Test public void testSortingManyNumbers() throws Exception { - ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); + ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4, shouldUseRadixSort()); int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { if (!sorter.hasSpaceForAnotherRecord()) { - sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2)); + sorter.expandPointerArray(consumer.allocateArray(sorter.getMemoryUsage() / 8 * 2)); } numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1); sorter.insertRecord(0, numbersToSort[i]); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index fbaaa1cf49..f9dc20d8b7 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -392,7 +392,20 @@ public class UnsafeShuffleWriterSuite { } @Test - public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { + public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOff() throws Exception { + 
conf.set("spark.shuffle.sort.useRadixSort", "false"); + writeEnoughRecordsToTriggerSortBufferExpansionAndSpill(); + assertEquals(2, spillFilesCreated.size()); + } + + @Test + public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() throws Exception { + conf.set("spark.shuffle.sort.useRadixSort", "true"); + writeEnoughRecordsToTriggerSortBufferExpansionAndSpill(); + assertEquals(3, spillFilesCreated.size()); + } + + private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16); final UnsafeShuffleWriter<Object, Object> writer = createWriter(false); final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>(); @@ -400,7 +413,6 @@ public class UnsafeShuffleWriterSuite { dataToWrite.add(new Tuple2<Object, Object>(i, i)); } writer.write(dataToWrite.iterator()); - assertEquals(2, spillFilesCreated.size()); writer.stop(true); readRecordsFromFile(); assertSpillFilesWereCleanedUp(); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java new file mode 100644 index 0000000000..bb38305a07 --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +public class UnsafeExternalSorterRadixSortSuite extends UnsafeExternalSorterSuite { + @Override + protected boolean shouldUseRadixSort() { return true; } +} diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a2253d8559..60a40cc172 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -64,12 +64,7 @@ public class UnsafeExternalSorterSuite { new JavaSerializer(new SparkConf()), new SparkConf().set("spark.shuffle.spill.compress", "false")); // Use integer comparison for comparing prefixes (which are partition ids, in this case) - final PrefixComparator prefixComparator = new PrefixComparator() { - @Override - public int compare(long prefix1, long prefix2) { - return (int) prefix1 - (int) prefix2; - } - }; + final PrefixComparator prefixComparator = PrefixComparators.LONG; // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so // use a dummy comparator final RecordComparator recordComparator = new RecordComparator() { @@ -88,6 +83,7 @@ public class UnsafeExternalSorterSuite { @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; + protected boolean 
shouldUseRadixSort() { return false; } private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m"); @@ -178,7 +174,8 @@ public class UnsafeExternalSorterSuite { recordComparator, prefixComparator, /* initialSize */ 1024, - pageSizeBytes); + pageSizeBytes, + shouldUseRadixSort()); } @Test @@ -381,7 +378,8 @@ public class UnsafeExternalSorterSuite { null, null, /* initialSize */ 1024, - pageSizeBytes); + pageSizeBytes, + shouldUseRadixSort()); long[] record = new long[100]; int recordSize = record.length * 8; int n = (int) pageSizeBytes / recordSize * 3; @@ -416,7 +414,8 @@ public class UnsafeExternalSorterSuite { recordComparator, prefixComparator, 1024, - pageSizeBytes); + pageSizeBytes, + shouldUseRadixSort()); // Peak memory should be monotonically increasing. More specifically, every time // we allocate a new page it should increase by exactly the size of the page. diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java new file mode 100644 index 0000000000..ae69ededf7 --- /dev/null +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection.unsafe.sort; + +public class UnsafeInMemorySorterRadixSortSuite extends UnsafeInMemorySorterSuite { + @Override + protected boolean shouldUseRadixSort() { return true; } +} diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index f90214fffd..23f4abfed2 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.util.collection.unsafe.sort; +import java.lang.Long; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -39,6 +40,8 @@ import static org.mockito.Mockito.mock; public class UnsafeInMemorySorterSuite { + protected boolean shouldUseRadixSort() { return false; } + private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) { final byte[] strBytes = new byte[length]; Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length); @@ -54,7 +57,8 @@ public class UnsafeInMemorySorterSuite { memoryManager, mock(RecordComparator.class), mock(PrefixComparator.class), - 100); + 100, + shouldUseRadixSort()); final UnsafeSorterIterator iter = sorter.getSortedIterator(); Assert.assertFalse(iter.hasNext()); } @@ -102,19 +106,15 @@ public class UnsafeInMemorySorterSuite { // Compute key prefixes 
based on the records' partition ids final HashPartitioner hashPartitioner = new HashPartitioner(4); // Use integer comparison for comparing prefixes (which are partition ids, in this case) - final PrefixComparator prefixComparator = new PrefixComparator() { - @Override - public int compare(long prefix1, long prefix2) { - return (int) prefix1 - (int) prefix2; - } - }; + final PrefixComparator prefixComparator = PrefixComparators.LONG; UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, - recordComparator, prefixComparator, dataToSort.length); + recordComparator, prefixComparator, dataToSort.length, shouldUseRadixSort()); // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); for (int i = 0; i < dataToSort.length; i++) { if (!sorter.hasSpaceForAnotherRecord()) { - sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2 * 2)); + sorter.expandPointerArray( + consumer.allocateArray(sorter.getMemoryUsage() / Long.BYTES * 2)); } // position now points to the start of a record (which holds its length). 
final int recordLength = Platform.getInt(baseObject, position); diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index dda8bee222..b4083230b4 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -101,6 +101,8 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) + assert( + java.lang.Double.doubleToRawLongBits(nan1) != java.lang.Double.doubleToRawLongBits(nan2)) assert(nan1.isNaN) assert(nan2.isNaN) val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) @@ -110,4 +112,28 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } + test("double prefix comparator handles negative NaNs properly") { + val negativeNan: Double = java.lang.Double.longBitsToDouble(0xfff0000000000001L) + assert(negativeNan.isNaN) + assert(java.lang.Double.doubleToRawLongBits(negativeNan) < 0) + val prefix = PrefixComparators.DoublePrefixComparator.computePrefix(negativeNan) + val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) + assert(PrefixComparators.DOUBLE.compare(prefix, doubleMaxPrefix) === 1) + } + + test("double prefix comparator handles other special values properly") { + val nullValue = 0L + val nan = PrefixComparators.DoublePrefixComparator.computePrefix(Double.NaN) + val posInf = PrefixComparators.DoublePrefixComparator.computePrefix(Double.PositiveInfinity) + val negInf = 
PrefixComparators.DoublePrefixComparator.computePrefix(Double.NegativeInfinity) + val minValue = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MinValue) + val maxValue = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) + val zero = PrefixComparators.DoublePrefixComparator.computePrefix(0.0) + assert(PrefixComparators.DOUBLE.compare(nan, posInf) === 1) + assert(PrefixComparators.DOUBLE.compare(posInf, maxValue) === 1) + assert(PrefixComparators.DOUBLE.compare(maxValue, zero) === 1) + assert(PrefixComparators.DOUBLE.compare(zero, minValue) === 1) + assert(PrefixComparators.DOUBLE.compare(minValue, negInf) === 1) + assert(PrefixComparators.DOUBLE.compare(negInf, nullValue) === 1) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala new file mode 100644 index 0000000000..52428634e5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection.unsafe.sort + +import java.lang.{Long => JLong} +import java.util.{Arrays, Comparator} + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging +import org.apache.spark.unsafe.array.LongArray +import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.util.Benchmark +import org.apache.spark.util.collection.Sorter +import org.apache.spark.util.random.XORShiftRandom + +class RadixSortSuite extends SparkFunSuite with Logging { + private val N = 10000 // scale this down for more readable results + + /** + * Describes a type of sort to test, e.g. two's complement descending. Each sort type has + * a defined reference ordering as well as radix sort parameters that can be used to + * reproduce the given ordering. + */ + case class RadixSortType( + name: String, + referenceComparator: PrefixComparator, + startByteIdx: Int, endByteIdx: Int, descending: Boolean, signed: Boolean) + + val SORT_TYPES_TO_TEST = Seq( + RadixSortType("unsigned binary data asc", PrefixComparators.BINARY, 0, 7, false, false), + RadixSortType("unsigned binary data desc", PrefixComparators.BINARY_DESC, 0, 7, true, false), + RadixSortType("twos complement asc", PrefixComparators.LONG, 0, 7, false, true), + RadixSortType("twos complement desc", PrefixComparators.LONG_DESC, 0, 7, true, true), + RadixSortType( + "binary data partial", + new PrefixComparators.RadixSortSupport { + override def sortDescending = false + override def sortSigned = false + override def compare(a: Long, b: Long): Int = { + return PrefixComparators.BINARY.compare(a & 0xffffff0000L, b & 0xffffff0000L) + } + }, + 2, 4, false, false)) + + private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = { + val ref = Array.tabulate[Long](size) { i => rand } + val extended = ref ++ Array.fill[Long](size)(0) + (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) + } + + 
private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { + val ref = Array.tabulate[Long](size * 2) { i => rand } + val extended = ref ++ Array.fill[Long](size * 2)(0) + (new LongArray(MemoryBlock.fromLongArray(ref)), + new LongArray(MemoryBlock.fromLongArray(extended))) + } + + private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = { + var i = 0 + val out = new Array[Long](length) + while (i < length) { + out(i) = array.get(offset + i) + i += 1 + } + out + } + + private def toJavaComparator(p: PrefixComparator): Comparator[JLong] = { + new Comparator[JLong] { + override def compare(a: JLong, b: JLong): Int = { + p.compare(a, b) + } + override def equals(other: Any): Boolean = { + other == this + } + } + } + + private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { + new Sorter(UnsafeSortDataFormat.INSTANCE).sort( + buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { + override def compare( + r1: RecordPointerAndKeyPrefix, + r2: RecordPointerAndKeyPrefix): Int = { + refCmp.compare(r1.keyPrefix, r2.keyPrefix) + } + }) + } + + private def fuzzTest(name: String)(testFn: Long => Unit): Unit = { + test(name) { + var seed = 0L + try { + for (i <- 0 to 10) { + seed = System.nanoTime + testFn(seed) + } + } catch { + case t: Throwable => + throw new Exception("Failed with seed: " + seed, t) + } + } + } + + // Radix sort is sensitive to the value distribution at different bit indices (e.g., we may + // omit a sort on a byte if all values are equal). This generates random good test masks. 
+ def randomBitMask(rand: Random): Long = { + var tmp = ~0L + for (i <- 0 to rand.nextInt(5)) { + tmp &= rand.nextLong + } + tmp + } + + for (sortType <- SORT_TYPES_TO_TEST) { + test("radix support for " + sortType.name) { + val s = sortType.referenceComparator.asInstanceOf[PrefixComparators.RadixSortSupport] + assert(s.sortDescending() == sortType.descending) + assert(s.sortSigned() == sortType.signed) + } + + test("sort " + sortType.name) { + val rand = new XORShiftRandom(123) + val (ref, buffer) = generateTestData(N, rand.nextLong) + Arrays.sort(ref, toJavaComparator(sortType.referenceComparator)) + val outOffset = RadixSort.sort( + buffer, N, sortType.startByteIdx, sortType.endByteIdx, + sortType.descending, sortType.signed) + val result = collectToArray(buffer, outOffset, N) + assert(ref.view == result.view) + } + + test("sort key prefix " + sortType.name) { + val rand = new XORShiftRandom(123) + val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & 0xff) + referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator) + val outOffset = RadixSort.sortKeyPrefixArray( + buf2, N, sortType.startByteIdx, sortType.endByteIdx, + sortType.descending, sortType.signed) + val res1 = collectToArray(buf1, 0, N * 2) + val res2 = collectToArray(buf2, outOffset, N * 2) + assert(res1.view == res2.view) + } + + fuzzTest(s"fuzz test ${sortType.name} with random bitmasks") { seed => + val rand = new XORShiftRandom(seed) + val mask = randomBitMask(rand) + val (ref, buffer) = generateTestData(N, rand.nextLong & mask) + Arrays.sort(ref, toJavaComparator(sortType.referenceComparator)) + val outOffset = RadixSort.sort( + buffer, N, sortType.startByteIdx, sortType.endByteIdx, + sortType.descending, sortType.signed) + val result = collectToArray(buffer, outOffset, N) + assert(ref.view == result.view) + } + + fuzzTest(s"fuzz test key prefix ${sortType.name} with random bitmasks") { seed => + val rand = new XORShiftRandom(seed) + val mask = randomBitMask(rand) + val (buf1, 
buf2) = generateKeyPrefixTestData(N, rand.nextLong & mask) + referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator) + val outOffset = RadixSort.sortKeyPrefixArray( + buf2, N, sortType.startByteIdx, sortType.endByteIdx, + sortType.descending, sortType.signed) + val res1 = collectToArray(buf1, 0, N * 2) + val res2 = collectToArray(buf2, outOffset, N * 2) + assert(res1.view == res2.view) + } + } + + ignore("microbenchmarks") { + val size = 25000000 + val rand = new XORShiftRandom(123) + val benchmark = new Benchmark("radix sort " + size, size) + benchmark.addTimerCase("reference TimSort key prefix array") { timer => + val array = Array.tabulate[Long](size * 2) { i => rand.nextLong } + val buf = new LongArray(MemoryBlock.fromLongArray(array)) + timer.startTiming() + referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY) + timer.stopTiming() + } + benchmark.addTimerCase("reference Arrays.sort") { timer => + val ref = Array.tabulate[Long](size) { i => rand.nextLong } + timer.startTiming() + Arrays.sort(ref) + timer.stopTiming() + } + benchmark.addTimerCase("radix sort one byte") { timer => + val array = new Array[Long](size * 2) + var i = 0 + while (i < size) { + array(i) = rand.nextLong & 0xff + i += 1 + } + val buf = new LongArray(MemoryBlock.fromLongArray(array)) + timer.startTiming() + RadixSort.sort(buf, size, 0, 7, false, false) + timer.stopTiming() + } + benchmark.addTimerCase("radix sort two bytes") { timer => + val array = new Array[Long](size * 2) + var i = 0 + while (i < size) { + array(i) = rand.nextLong & 0xffff + i += 1 + } + val buf = new LongArray(MemoryBlock.fromLongArray(array)) + timer.startTiming() + RadixSort.sort(buf, size, 0, 7, false, false) + timer.stopTiming() + } + benchmark.addTimerCase("radix sort eight bytes") { timer => + val array = new Array[Long](size * 2) + var i = 0 + while (i < size) { + array(i) = rand.nextLong + i += 1 + } + val buf = new LongArray(MemoryBlock.fromLongArray(array)) + timer.startTiming() + 
RadixSort.sort(buf, size, 0, 7, false, false) + timer.stopTiming() + } + benchmark.addTimerCase("radix sort key prefix array") { timer => + val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong) + timer.startTiming() + RadixSort.sortKeyPrefixArray(buf2, size, 0, 7, false, false) + timer.stopTiming() + } + benchmark.run + + /** + Running benchmark: radix sort 25000000 + Java HotSpot(TM) 64-Bit Server VM 1.8.0_66-b17 on Linux 3.13.0-44-generic + Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz + + radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + reference TimSort key prefix array 15546 / 15859 1.6 621.9 1.0X + reference Arrays.sort 2416 / 2446 10.3 96.6 6.4X + radix sort one byte 133 / 137 188.4 5.3 117.2X + radix sort two bytes 255 / 258 98.2 10.2 61.1X + radix sort eight bytes 991 / 997 25.2 39.6 15.7X + radix sort key prefix array 1540 / 1563 16.2 61.6 10.1X + */ + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 7784345a7a..8d9906da7e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -59,7 +59,8 @@ public final class UnsafeExternalRowSorter { Ordering<InternalRow> ordering, PrefixComparator prefixComparator, PrefixComputer prefixComputer, - long pageSizeBytes) throws IOException { + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); @@ -72,7 +73,8 @@ public final class UnsafeExternalRowSorter { new RowComparator(ordering, schema.length()), prefixComparator, /* initialSize */ 4096, - pageSizeBytes + pageSizeBytes, + canUseRadixSort ); } diff 
--git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index e0c3b22a3c..42a8be6b1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -84,8 +84,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { case DateType | TimestampType => (Long.MinValue, s"(long) $input") case FloatType | DoubleType => - (DoublePrefixComparator.computePrefix(Double.NegativeInfinity), - s"$DoublePrefixCmp.computePrefix((double)$input)") + (0L, s"$DoublePrefixCmp.computePrefix((double)$input)") case StringType => (0L, s"$input.getPrefix()") case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)") case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS => diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 8132bba04c..b6499e35b5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -85,13 +85,15 @@ public final class UnsafeKVExternalSorter { recordComparator, prefixComparator, /* initialSize */ 4096, - pageSizeBytes); + pageSizeBytes, + keySchema.length() == 1 && SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0))); } else { // During spilling, the array in map will not be used, so we can borrow that and use it // as the underline array for in-memory sorter (it's always large enough). // Since we will not grow the array, it's fine to pass `null` as consumer. 
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - null, taskMemoryManager, recordComparator, prefixComparator, map.getArray()); + null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(), + false /* TODO(ekl) we can only radix sort if the BytesToBytes load factor is <= 0.5 */); // We cannot use the destructive iterator here because we are reusing the existing memory // pages in BytesToBytesMap to hold records during sorting. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 80255fafbe..6f9baa2d33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.unsafe.sort.RadixSort; /** * Performs (external) sorting. 
@@ -48,7 +50,10 @@ case class Sort( override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil + private val enableRadixSort = sqlContext.conf.enableRadixSort + override private[sql] lazy val metrics = Map( + "sortTime" -> SQLMetrics.createLongMetric(sparkContext, "sort time"), "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) @@ -59,6 +64,9 @@ case class Sort( val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + val canUseRadixSort = enableRadixSort && sortOrder.length == 1 && + SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression) + // The generator for prefix val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { @@ -69,7 +77,8 @@ case class Sort( val pageSize = SparkEnv.get.memoryManager.pageSizeBytes val sorter = new UnsafeExternalRowSorter( - schema, ordering, prefixComparator, prefixComputer, pageSize) + schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) + if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } @@ -139,11 +148,15 @@ case class Sort( val dataSize = metricTerm(ctx, "dataSize") val spillSize = metricTerm(ctx, "spillSize") val spillSizeBefore = ctx.freshName("spillSizeBefore") + val startTime = ctx.freshName("startTime") + val sortTime = metricTerm(ctx, "sortTime") s""" | if ($needToSort) { | $addToSorter(); - | Long $spillSizeBefore = $metrics.memoryBytesSpilled(); + | long $spillSizeBefore = $metrics.memoryBytesSpilled(); + | long $startTime = System.nanoTime(); | $sortedIterator = $sorterVariable.sort(); + | $sortTime.add(System.nanoTime() - $startTime); | $dataSize.add($sorterVariable.getPeakMemoryUsage()); 
| $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore); | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage()); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 909f124d2c..1a5ff5fcec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -66,6 +66,32 @@ object SortPrefixUtils { } /** + * Returns whether the specified SortOrder can be satisfied with a radix sort on the prefix. + */ + def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = { + sortOrder.dataType match { + // TODO(ekl) long-type is problematic because its null prefix representation collides with + // the lowest possible long value. Handle this special case outside radix sort. + case LongType if sortOrder.nullable => + false + case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType | + TimestampType | FloatType | DoubleType => + true + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + true + case _ => + false + } + } + + /** + * Returns whether fully sorting on the specified key field is possible with radix sort. + */ + def canSortFullyWithPrefix(field: StructField): Boolean = { + canSortFullyWithPrefix(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending)) + } + + /** + * Creates the prefix computer for the first field in the given schema, in ascending order. 
*/ def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 85ce388de0..a46d0e0ba7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -344,7 +344,8 @@ case class Window( null, null, 1024, - SparkEnv.get.memoryManager.pageSizeBytes) + SparkEnv.get.memoryManager.pageSizeBytes, + false) rows.foreach { r => sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index edb4c5a16f..b1de52b5f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -46,7 +46,8 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField null, null, 1024, - SparkEnv.get.memoryManager.pageSizeBytes) + SparkEnv.get.memoryManager.pageSizeBytes, + false) val partition = split.asInstanceOf[CartesianPartition] for (y <- rdd2.iterator(partition.s2, context)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6e7c1bc133..a4e82d80f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -95,6 +95,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val RADIX_SORT_ENABLED = SQLConfigBuilder("spark.sql.sort.enableRadixSort") + .internal() + .doc("When true, enable use of radix sort when possible. 
Radix sort is much faster but " + + "requires additional memory to be reserved up-front. The memory overhead may be " + + "significant when sorting very small rows (up to 50% more in this case).") + .booleanConf + .createWithDefault(true) + val AUTO_BROADCASTJOIN_THRESHOLD = SQLConfigBuilder("spark.sql.autoBroadcastJoinThreshold") .doc("Configures the maximum size in bytes for a table that will be broadcast to all worker " + "nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " + @@ -584,6 +592,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) + def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED) + def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L) |