-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java | 10
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java | 3
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java | 44
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java | 114
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java | 253
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java | 13
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java | 58
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java | 23
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java | 16
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 16
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java | 23
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java | 17
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java | 23
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java | 18
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala | 26
-rw-r--r--  core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala | 264
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala | 3
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala | 26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 10
24 files changed, 876 insertions, 119 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
index f7a6c68be9..b36da80dbc 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
@@ -42,6 +42,16 @@ final class PackedRecordPointer {
*/
static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
+ /**
+ * The index of the first byte of the partition id, counting from the least significant byte.
+ */
+ static final int PARTITION_ID_START_BYTE_INDEX = 5;
+
+ /**
+ * The index of the last byte of the partition id, counting from the least significant byte.
+ */
+ static final int PARTITION_ID_END_BYTE_INDEX = 7;
+
/** Bit mask for the lower 40 bits of a long. */
private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
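
The two new constants encode where the partition id lives in the packed layout documented in this class ([24 bit partition id][13 bit page number][27 bit page offset]): bits 40-63, i.e. bytes 5 through 7 counting from the least significant byte. Restricting radix sort to that byte range means only three counting passes are needed to order records by partition. A minimal sketch of the arithmetic (illustrative values, not part of the patch):

    // The partition id occupies the top 24 bits of the packed long,
    // which are exactly bytes 5..7 from the least significant byte.
    long partitionId = 12345;                  // must be <= MAXIMUM_PARTITION_ID
    long recordPointer = 0x0000001234567890L;  // lower 40 bits only
    long packed = (partitionId << 40) | recordPointer;
    // Sorting on bytes 5..7 alone orders records by partition id, since
    // those bytes hold the id and nothing else:
    assert ((packed >>> 40) & 0xffffffL) == partitionId;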
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 3c2980e442..c4041a97e8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -115,7 +115,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.writeMetrics = writeMetrics;
- this.inMemSorter = new ShuffleInMemorySorter(this, initialSize);
+ this.inMemSorter = new ShuffleInMemorySorter(
+ this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true));
this.peakMemoryUsedBytes = getMemoryUsage();
}
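
The new sort path is enabled by default through the spark.shuffle.sort.useRadixSort flag read here. A short sketch of how a job could opt back into the TimSort path (standard SparkConf usage; nothing below is added by the patch):

    import org.apache.spark.SparkConf;

    // Falls back to TimSort over packed pointers and avoids the 2x
    // pointer-array reservation that radix sort needs for scratch space.
    SparkConf conf = new SparkConf().set("spark.shuffle.sort.useRadixSort", "false");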
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index 76b0e6a304..68630946ac 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -17,12 +17,14 @@
package org.apache.spark.shuffle.sort;
+import java.lang.Long;
import java.util.Comparator;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.util.collection.unsafe.sort.RadixSort;
final class ShuffleInMemorySorter {
@@ -47,16 +49,29 @@ final class ShuffleInMemorySorter {
private LongArray array;
/**
+ * Whether to use radix sort for sorting in-memory partition ids. Radix sort is much faster
+ * but requires additional memory to be reserved as pointers are added.
+ */
+ private final boolean useRadixSort;
+
+ /**
+ * Set to 2 when radix sort is used, to reserve extra memory for sorting; otherwise 1.
+ */
+ private final int memoryAllocationFactor;
+
+ /**
* The position in the pointer array where new records can be inserted.
*/
private int pos = 0;
private int initialSize;
- ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
+ ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize, boolean useRadixSort) {
this.consumer = consumer;
assert (initialSize > 0);
this.initialSize = initialSize;
+ this.useRadixSort = useRadixSort;
+ this.memoryAllocationFactor = useRadixSort ? 2 : 1;
this.array = consumer.allocateArray(initialSize);
this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
}
@@ -87,18 +102,18 @@ final class ShuffleInMemorySorter {
array.getBaseOffset(),
newArray.getBaseObject(),
newArray.getBaseOffset(),
- array.size() * 8L
+ array.size() * (Long.BYTES / memoryAllocationFactor)
);
consumer.freeArray(array);
array = newArray;
}
public boolean hasSpaceForAnotherRecord() {
- return pos < array.size();
+ return pos < array.size() / memoryAllocationFactor;
}
public long getMemoryUsage() {
- return array.size() * 8L;
+ return array.size() * Long.BYTES;
}
/**
@@ -125,17 +140,18 @@ final class ShuffleInMemorySorter {
public static final class ShuffleSorterIterator {
private final LongArray pointerArray;
- private final int numRecords;
+ private final int limit;
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
private int position = 0;
- ShuffleSorterIterator(int numRecords, LongArray pointerArray) {
- this.numRecords = numRecords;
+ ShuffleSorterIterator(int numRecords, LongArray pointerArray, int startingPosition) {
+ this.limit = numRecords + startingPosition;
this.pointerArray = pointerArray;
+ this.position = startingPosition;
}
public boolean hasNext() {
- return position < numRecords;
+ return position < limit;
}
public void loadNext() {
@@ -148,7 +164,15 @@ final class ShuffleInMemorySorter {
* Return an iterator over record pointers in sorted order.
*/
public ShuffleSorterIterator getSortedIterator() {
- sorter.sort(array, 0, pos, SORT_COMPARATOR);
- return new ShuffleSorterIterator(pos, array);
+ int offset = 0;
+ if (useRadixSort) {
+ offset = RadixSort.sort(
+ array, pos,
+ PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
+ PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
+ } else {
+ sorter.sort(array, 0, pos, SORT_COMPARATOR);
+ }
+ return new ShuffleSorterIterator(pos, array, offset);
}
}
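
When radix sort is enabled, only the first half of the pointer array accepts records (hasSpaceForAnotherRecord() divides by memoryAllocationFactor); the second half is scratch space for the out-of-place counting passes, and RadixSort.sort() returns whichever half ends up holding the sorted run, which becomes the iterator's startingPosition. A small sketch of the capacity rule (illustrative numbers):

    // With useRadixSort == true, an 8-slot array accepts at most 4 records.
    int arraySize = 8;
    int memoryAllocationFactor = 2;                                // 2 with radix sort
    boolean spaceAtPos3 = 3 < arraySize / memoryAllocationFactor;  // true
    boolean spaceAtPos4 = 4 < arraySize / memoryAllocationFactor;  // false:
    // the upper 4 slots stay free as radix sort scratch space.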
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index c2a8f429be..21f2fde79d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -28,88 +28,92 @@ import org.apache.spark.util.Utils;
public class PrefixComparators {
private PrefixComparators() {}
- public static final StringPrefixComparator STRING = new StringPrefixComparator();
- public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
- public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator();
- public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc();
- public static final LongPrefixComparator LONG = new LongPrefixComparator();
- public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
- public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
- public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc();
-
- public static final class StringPrefixComparator extends PrefixComparator {
- @Override
- public int compare(long aPrefix, long bPrefix) {
- return UnsignedLongs.compare(aPrefix, bPrefix);
- }
-
+ public static final PrefixComparator STRING = new UnsignedPrefixComparator();
+ public static final PrefixComparator STRING_DESC = new UnsignedPrefixComparatorDesc();
+ public static final PrefixComparator BINARY = new UnsignedPrefixComparator();
+ public static final PrefixComparator BINARY_DESC = new UnsignedPrefixComparatorDesc();
+ public static final PrefixComparator LONG = new SignedPrefixComparator();
+ public static final PrefixComparator LONG_DESC = new SignedPrefixComparatorDesc();
+ public static final PrefixComparator DOUBLE = new UnsignedPrefixComparator();
+ public static final PrefixComparator DOUBLE_DESC = new UnsignedPrefixComparatorDesc();
+
+ public static final class StringPrefixComparator {
public static long computePrefix(UTF8String value) {
return value == null ? 0L : value.getPrefix();
}
}
- public static final class StringPrefixComparatorDesc extends PrefixComparator {
- @Override
- public int compare(long bPrefix, long aPrefix) {
- return UnsignedLongs.compare(aPrefix, bPrefix);
+ public static final class BinaryPrefixComparator {
+ public static long computePrefix(byte[] bytes) {
+ return ByteArray.getPrefix(bytes);
}
}
- public static final class BinaryPrefixComparator extends PrefixComparator {
- @Override
- public int compare(long aPrefix, long bPrefix) {
- return UnsignedLongs.compare(aPrefix, bPrefix);
+ public static final class DoublePrefixComparator {
+ /**
+ * Converts the double into a value that compares correctly as an unsigned long. For more
+ * details see http://stereopsis.com/radix.html.
+ */
+ public static long computePrefix(double value) {
+ // Java's doubleToLongBits already canonicalizes all NaN values to the smallest possible
+ // positive NaN, so there's nothing special we need to do for NaNs.
+ long bits = Double.doubleToLongBits(value);
+ // Negative floats compare backwards due to their sign-magnitude representation, so flip
+ // all the bits in this case.
+ long mask = -(bits >>> 63) | 0x8000000000000000L;
+ return bits ^ mask;
}
+ }
- public static long computePrefix(byte[] bytes) {
- return ByteArray.getPrefix(bytes);
- }
+ /**
+ * Provides radix sort parameters. Comparators implementing this interface also indicate that
+ * the ordering they define is compatible with radix sort.
+ */
+ public static abstract class RadixSortSupport extends PrefixComparator {
+ /** @return Whether the sort should be descending in binary sort order. */
+ public abstract boolean sortDescending();
+
+ /** @return Whether the sort should take into account the sign bit. */
+ public abstract boolean sortSigned();
}
- public static final class BinaryPrefixComparatorDesc extends PrefixComparator {
+ //
+ // Standard prefix comparator implementations
+ //
+
+ public static final class UnsignedPrefixComparator extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return false; }
+ @Override public final boolean sortSigned() { return false; }
@Override
- public int compare(long bPrefix, long aPrefix) {
+ public final int compare(long aPrefix, long bPrefix) {
return UnsignedLongs.compare(aPrefix, bPrefix);
}
}
- public static final class LongPrefixComparator extends PrefixComparator {
+ public static final class UnsignedPrefixComparatorDesc extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return true; }
+ @Override public final boolean sortSigned() { return false; }
@Override
- public int compare(long a, long b) {
- return (a < b) ? -1 : (a > b) ? 1 : 0;
+ public final int compare(long bPrefix, long aPrefix) {
+ return UnsignedLongs.compare(aPrefix, bPrefix);
}
}
- public static final class LongPrefixComparatorDesc extends PrefixComparator {
+ public static final class SignedPrefixComparator extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return false; }
+ @Override public final boolean sortSigned() { return true; }
@Override
- public int compare(long b, long a) {
+ public final int compare(long a, long b) {
return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}
- public static final class DoublePrefixComparator extends PrefixComparator {
+ public static final class SignedPrefixComparatorDesc extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return true; }
+ @Override public final boolean sortSigned() { return true; }
@Override
- public int compare(long aPrefix, long bPrefix) {
- double a = Double.longBitsToDouble(aPrefix);
- double b = Double.longBitsToDouble(bPrefix);
- return Utils.nanSafeCompareDoubles(a, b);
- }
-
- public static long computePrefix(double value) {
- return Double.doubleToLongBits(value);
- }
- }
-
- public static final class DoublePrefixComparatorDesc extends PrefixComparator {
- @Override
- public int compare(long bPrefix, long aPrefix) {
- double a = Double.longBitsToDouble(aPrefix);
- double b = Double.longBitsToDouble(bPrefix);
- return Utils.nanSafeCompareDoubles(a, b);
- }
-
- public static long computePrefix(double value) {
- return Double.doubleToLongBits(value);
+ public final int compare(long b, long a) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}
}
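
The rewritten computePrefix() maps doubles onto longs whose unsigned order matches the doubles' numeric order: non-negative values get their sign bit flipped, negative values get all bits flipped (undoing the reversed ordering of sign-magnitude representation). That is what lets DOUBLE reuse UnsignedPrefixComparator and qualify for radix sort. A self-contained check of the transform, reusing the same Guava helper the comparators use (the class name is illustrative):

    import com.google.common.primitives.UnsignedLongs;

    public class DoublePrefixCheck {
      // Same transform as DoublePrefixComparator.computePrefix() above.
      static long prefix(double value) {
        long bits = Double.doubleToLongBits(value);
        long mask = -(bits >>> 63) | 0x8000000000000000L;
        return bits ^ mask;
      }

      public static void main(String[] args) {
        double[] ascending = {Double.NEGATIVE_INFINITY, -1.5, 0.0, 2.5,
                              Double.POSITIVE_INFINITY, Double.NaN};
        for (int i = 0; i + 1 < ascending.length; i++) {
          // Unsigned comparison of prefixes preserves numeric order,
          // with NaN sorting above positive infinity.
          assert UnsignedLongs.compare(prefix(ascending[i]), prefix(ascending[i + 1])) < 0;
        }
      }
    }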
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
new file mode 100644
index 0000000000..3357b8e474
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+
+public class RadixSort {
+
+ /**
+ * Sorts a given array of longs using least-significant-digit radix sort. This routine assumes
+ * you have extra space at the end of the array at least equal to the number of records. The
+ * sort is destructive and may relocate the data within the array.
+ *
+ * @param array array of long elements followed by at least that many empty slots.
+ * @param numRecords number of data records in the array.
+ * @param startByteIndex the first byte (in range [0, 7]) to sort each long by, counting from the
+ * least significant byte.
+ * @param endByteIndex the last byte (in range [0, 7]) to sort each long by, counting from the
+ * least significant byte. Must be greater than startByteIndex.
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort.
+ *
+ * @return The starting index of the sorted data within the given array. We return this instead
+ * of always copying the data back to position zero for efficiency.
+ */
+ public static int sort(
+ LongArray array, int numRecords, int startByteIndex, int endByteIndex,
+ boolean desc, boolean signed) {
+ assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should be >= 0";
+ assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should be <= 7";
+ assert endByteIndex > startByteIndex;
+ assert numRecords * 2 <= array.size();
+ int inIndex = 0;
+ int outIndex = numRecords;
+ if (numRecords > 0) {
+ long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (counts[i] != null) {
+ sortAtByte(
+ array, numRecords, counts[i], i, inIndex, outIndex,
+ desc, signed && i == endByteIndex);
+ int tmp = inIndex;
+ inIndex = outIndex;
+ outIndex = tmp;
+ }
+ }
+ }
+ return inIndex;
+ }
+
+ /**
+ * Performs a partial sort by copying data into destination offsets for each byte value at the
+ * specified byte offset.
+ *
+ * @param array array to partially sort.
+ * @param numRecords number of data records in the array.
+ * @param counts counts for each byte value. This routine destructively modifies this array.
+ * @param byteIdx the byte in a long to sort at, counting from the least significant byte.
+ * @param inIndex the starting index in the array where input data is located.
+ * @param outIndex the starting index where sorted output data should be written.
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort (only applies to last byte).
+ */
+ private static void sortAtByte(
+ LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ long[] offsets = transformCountsToOffsets(
+ counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed);
+ Object baseObject = array.getBaseObject();
+ long baseOffset = array.getBaseOffset() + inIndex * 8;
+ long maxOffset = baseOffset + numRecords * 8;
+ for (long offset = baseOffset; offset < maxOffset; offset += 8) {
+ long value = Platform.getLong(baseObject, offset);
+ int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
+ Platform.putLong(baseObject, offsets[bucket], value);
+ offsets[bucket] += 8;
+ }
+ }
+
+ /**
+ * Computes a value histogram for each byte in the given array.
+ *
+ * @param array array to count records in.
+ * @param numRecords number of data records in the array.
+ * @param startByteIndex the first byte to compute counts for (the prior are skipped).
+ * @param endByteIndex the last byte to compute counts for.
+ *
+ * @return an array of eight 256-element count arrays, one for each byte starting from the least
+ * significant byte. If the byte does not need sorting the array will be null.
+ */
+ private static long[][] getCounts(
+ LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+ long[][] counts = new long[8][];
+ // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
+ // If all the byte values at a particular index are the same we don't need to count it.
+ long bitwiseMax = 0;
+ long bitwiseMin = -1L;
+ long maxOffset = array.getBaseOffset() + numRecords * 8;
+ Object baseObject = array.getBaseObject();
+ for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
+ long value = Platform.getLong(baseObject, offset);
+ bitwiseMax |= value;
+ bitwiseMin &= value;
+ }
+ long bitsChanged = bitwiseMin ^ bitwiseMax;
+ // Compute counts for each byte index.
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
+ counts[i] = new long[256];
+ // TODO(ekl) consider computing all the counts in one pass.
+ for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
+ counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++;
+ }
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Transforms counts into the proper unsafe output offsets for the sort type.
+ *
+ * @param counts counts for each byte value. This routine destructively modifies this array.
+ * @param numRecords number of data records in the original data array.
+ * @param outputOffset output offset in bytes from the base array object.
+ * @param bytesPerRecord size of each record (8 for plain sort, 16 for key-prefix sort).
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort.
+ *
+ * @return the input counts array.
+ */
+ private static long[] transformCountsToOffsets(
+ long[] counts, int numRecords, long outputOffset, int bytesPerRecord,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ int start = signed ? 128 : 0; // output the negative records first (values 128-255).
+ if (desc) {
+ int pos = numRecords;
+ for (int i = start; i < start + 256; i++) {
+ pos -= counts[i & 0xff];
+ counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
+ }
+ } else {
+ int pos = 0;
+ for (int i = start; i < start + 256; i++) {
+ long tmp = counts[i & 0xff];
+ counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
+ pos += tmp;
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Specialization of sort() for key-prefix arrays. In this type of array, each record consists
+ * of two longs, only the second of which is sorted on.
+ */
+ public static int sortKeyPrefixArray(
+ LongArray array,
+ int numRecords,
+ int startByteIndex,
+ int endByteIndex,
+ boolean desc,
+ boolean signed) {
+ assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should be >= 0";
+ assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should be <= 7";
+ assert endByteIndex > startByteIndex;
+ assert numRecords * 4 <= array.size();
+ int inIndex = 0;
+ int outIndex = numRecords * 2;
+ if (numRecords > 0) {
+ long[][] counts = getKeyPrefixArrayCounts(array, numRecords, startByteIndex, endByteIndex);
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (counts[i] != null) {
+ sortKeyPrefixArrayAtByte(
+ array, numRecords, counts[i], i, inIndex, outIndex,
+ desc, signed && i == endByteIndex);
+ int tmp = inIndex;
+ inIndex = outIndex;
+ outIndex = tmp;
+ }
+ }
+ }
+ return inIndex;
+ }
+
+ /**
+ * Specialization of getCounts() for key-prefix arrays. We could probably combine this with
+ * getCounts() with some added parameters, but that appeared to hurt performance in benchmarks.
+ */
+ private static long[][] getKeyPrefixArrayCounts(
+ LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+ long[][] counts = new long[8][];
+ long bitwiseMax = 0;
+ long bitwiseMin = -1L;
+ long limit = array.getBaseOffset() + numRecords * 16;
+ Object baseObject = array.getBaseObject();
+ for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ long value = Platform.getLong(baseObject, offset + 8);
+ bitwiseMax |= value;
+ bitwiseMin &= value;
+ }
+ long bitsChanged = bitwiseMin ^ bitwiseMax;
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
+ counts[i] = new long[256];
+ for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
+ }
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Specialization of sortAtByte() for key-prefix arrays.
+ */
+ private static void sortKeyPrefixArrayAtByte(
+ LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ long[] offsets = transformCountsToOffsets(
+ counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
+ Object baseObject = array.getBaseObject();
+ long baseOffset = array.getBaseOffset() + inIndex * 8;
+ long maxOffset = baseOffset + numRecords * 16;
+ for (long offset = baseOffset; offset < maxOffset; offset += 16) {
+ long key = Platform.getLong(baseObject, offset);
+ long prefix = Platform.getLong(baseObject, offset + 8);
+ int bucket = (int)((prefix >>> (byteIdx * 8)) & 0xff);
+ long dest = offsets[bucket];
+ Platform.putLong(baseObject, dest, key);
+ Platform.putLong(baseObject, dest + 8, prefix);
+ offsets[bucket] += 16;
+ }
+ }
+}
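
End-to-end, sort() needs scratch space equal to the input and returns the index of whichever half holds the final sorted run (each byte pass ping-pongs between the two halves, and passes over constant bytes are skipped, so the result may land in either half). A hedged usage sketch using the same LongArray/MemoryBlock wrappers the new test suite uses (the class name is illustrative):

    import org.apache.spark.unsafe.array.LongArray;
    import org.apache.spark.unsafe.memory.MemoryBlock;

    public class RadixSortDemo {
      public static void main(String[] args) {
        // Three records plus three empty scratch slots, as sort() requires.
        long[] backing = {3L, 1L, 2L, 0L, 0L, 0L};
        LongArray array = new LongArray(MemoryBlock.fromLongArray(backing));
        // Unsigned ascending sort over all eight bytes of each value.
        int offset = RadixSort.sort(array, 3, 0, 7, false, false);
        for (int i = 0; i < 3; i++) {
          System.out.println(array.get(offset + i));  // prints 1, 2, 3
        }
      }
    }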
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 3e32dd9d63..66a77982ad 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -89,7 +89,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
UnsafeInMemorySorter inMemorySorter) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
- pageSizeBytes, inMemorySorter);
+ pageSizeBytes, inMemorySorter, false /* ignored */);
sorter.spill(Long.MAX_VALUE, sorter);
// The external sorter will be used to insert records, in-memory sorter is not needed.
sorter.inMemSorter = null;
@@ -104,9 +104,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
- long pageSizeBytes) {
+ long pageSizeBytes,
+ boolean canUseRadixSort) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
- taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
+ taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null,
+ canUseRadixSort);
}
private UnsafeExternalSorter(
@@ -118,7 +120,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
- @Nullable UnsafeInMemorySorter existingInMemorySorter) {
+ @Nullable UnsafeInMemorySorter existingInMemorySorter,
+ boolean canUseRadixSort) {
super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
@@ -133,7 +136,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
if (existingInMemorySorter == null) {
this.inMemSorter = new UnsafeInMemorySorter(
- this, taskMemoryManager, recordComparator, prefixComparator, initialSize);
+ this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort);
} else {
this.inMemSorter = existingInMemorySorter;
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 01eae0e8dc..5f46ef9a81 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection.unsafe.sort;
+import java.lang.Long;
import java.util.Comparator;
import org.apache.avro.reflect.Nullable;
@@ -26,6 +27,7 @@ import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.util.collection.unsafe.sort.RadixSort;
/**
* Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
@@ -74,6 +76,17 @@ public final class UnsafeInMemorySorter {
private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
/**
+ * If non-null, specifies the radix sort parameters and that radix sort will be used.
+ */
+ @Nullable
+ private final PrefixComparators.RadixSortSupport radixSortSupport;
+
+ /**
+ * Set to 2 when radix sort is used, to reserve extra memory for sorting; otherwise 1.
+ */
+ private final int memoryAllocationFactor;
+
+ /**
* Within this buffer, position {@code 2 * i} holds a pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*/
@@ -91,27 +104,36 @@ public final class UnsafeInMemorySorter {
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
- int initialSize) {
+ int initialSize,
+ boolean canUseRadixSort) {
this(consumer, memoryManager, recordComparator, prefixComparator,
- consumer.allocateArray(initialSize * 2));
+ consumer.allocateArray(initialSize * 2), canUseRadixSort);
}
public UnsafeInMemorySorter(
- final MemoryConsumer consumer,
+ final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
- LongArray array) {
+ LongArray array,
+ boolean canUseRadixSort) {
this.consumer = consumer;
this.memoryManager = memoryManager;
this.initialSize = array.size();
if (recordComparator != null) {
this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) {
+ this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator;
+ } else {
+ this.radixSortSupport = null;
+ }
} else {
this.sorter = null;
this.sortComparator = null;
+ this.radixSortSupport = null;
}
+ this.memoryAllocationFactor = this.radixSortSupport != null ? 2 : 1;
this.array = array;
}
@@ -141,11 +163,11 @@ public final class UnsafeInMemorySorter {
}
public long getMemoryUsage() {
- return array.size() * 8L;
+ return array.size() * Long.BYTES;
}
public boolean hasSpaceForAnotherRecord() {
- return pos + 2 <= array.size();
+ return pos + 1 < (array.size() / memoryAllocationFactor);
}
public void expandPointerArray(LongArray newArray) {
@@ -157,7 +179,7 @@ public final class UnsafeInMemorySorter {
array.getBaseOffset(),
newArray.getBaseObject(),
newArray.getBaseOffset(),
- array.size() * 8L);
+ array.size() * (Long.BYTES / memoryAllocationFactor));
consumer.freeArray(array);
array = newArray;
}
@@ -183,18 +205,20 @@ public final class UnsafeInMemorySorter {
private final int numRecords;
private int position;
+ private int offset;
private Object baseObject;
private long baseOffset;
private long keyPrefix;
private int recordLength;
- private SortedIterator(int numRecords) {
+ private SortedIterator(int numRecords, int offset) {
this.numRecords = numRecords;
this.position = 0;
+ this.offset = offset;
}
public SortedIterator clone() {
- SortedIterator iter = new SortedIterator(numRecords);
+ SortedIterator iter = new SortedIterator(numRecords, offset);
iter.position = position;
iter.baseObject = baseObject;
iter.baseOffset = baseOffset;
@@ -216,11 +240,11 @@ public final class UnsafeInMemorySorter {
@Override
public void loadNext() {
// This pointer points to a 4-byte record length, followed by the record's bytes
- final long recordPointer = array.get(position);
+ final long recordPointer = array.get(offset + position);
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
recordLength = Platform.getInt(baseObject, baseOffset - 4);
- keyPrefix = array.get(position + 1);
+ keyPrefix = array.get(offset + position + 1);
position += 2;
}
@@ -242,9 +266,17 @@ public final class UnsafeInMemorySorter {
* {@code next()} will return the same mutable object.
*/
public SortedIterator getSortedIterator() {
+ int offset = 0;
if (sorter != null) {
- sorter.sort(array, 0, pos / 2, sortComparator);
+ if (this.radixSortSupport != null) {
+ // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they
+ // force a full-width sort (and we cannot radix-sort nullable long fields at all).
+ offset = RadixSort.sortKeyPrefixArray(
+ array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
+ } else {
+ sorter.sort(array, 0, pos / 2, sortComparator);
+ }
}
- return new SortedIterator(pos / 2);
+ return new SortedIterator(pos / 2, offset);
}
}
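
For key-prefix arrays the records are (pointer, prefix) pairs laid out in consecutive longs; sortKeyPrefixArray() moves 16-byte pairs while keying only on the prefix word, and the returned offset is threaded into SortedIterator so reads become array.get(offset + position). A small sketch of the pair layout (pointer and prefix values are illustrative):

    // Two (pointer, prefix) pairs plus equal-sized scratch space.
    long[] backing = {100L, 9L, 200L, 4L, 0L, 0L, 0L, 0L};
    LongArray array = new LongArray(MemoryBlock.fromLongArray(backing));
    int offset = RadixSort.sortKeyPrefixArray(array, 2, 0, 7, false, false);
    // The pair with the smaller prefix now comes first, pointer kept alongside:
    // array.get(offset)     == 200    array.get(offset + 1) == 4
    // array.get(offset + 2) == 100    array.get(offset + 3) == 9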
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java
new file mode 100644
index 0000000000..6927d0a815
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+public class ShuffleInMemoryRadixSorterSuite extends ShuffleInMemorySorterSuite {
+ @Override
+ protected boolean shouldUseRadixSort() { return true; }
+}
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 4cd3600df1..43e32f073a 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.shuffle.sort;
+import java.lang.Long;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Random;
@@ -34,6 +35,8 @@ import org.apache.spark.unsafe.memory.MemoryBlock;
public class ShuffleInMemorySorterSuite {
+ protected boolean shouldUseRadixSort() { return false; }
+
final TestMemoryManager memoryManager =
new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"));
final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
@@ -47,7 +50,8 @@ public class ShuffleInMemorySorterSuite {
@Test
public void testSortingEmptyInput() {
- final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100);
+ final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(
+ consumer, 100, shouldUseRadixSort());
final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
Assert.assertFalse(iter.hasNext());
}
@@ -70,14 +74,16 @@ public class ShuffleInMemorySorterSuite {
new TaskMemoryManager(new TestMemoryManager(conf), 0);
final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
final Object baseObject = dataPage.getBaseObject();
- final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4);
+ final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(
+ consumer, 4, shouldUseRadixSort());
final HashPartitioner hashPartitioner = new HashPartitioner(4);
// Write the records into the data page and store pointers into the sorter
long position = dataPage.getBaseOffset();
for (String str : dataToSort) {
if (!sorter.hasSpaceForAnotherRecord()) {
- sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2));
+ sorter.expandPointerArray(
+ consumer.allocateArray(sorter.getMemoryUsage() / Long.BYTES * 2));
}
final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
final byte[] strBytes = str.getBytes(StandardCharsets.UTF_8);
@@ -114,12 +120,12 @@ public class ShuffleInMemorySorterSuite {
@Test
public void testSortingManyNumbers() throws Exception {
- ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4);
+ ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4, shouldUseRadixSort());
int[] numbersToSort = new int[128000];
Random random = new Random(16);
for (int i = 0; i < numbersToSort.length; i++) {
if (!sorter.hasSpaceForAnotherRecord()) {
- sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2));
+ sorter.expandPointerArray(consumer.allocateArray(sorter.getMemoryUsage() / 8 * 2));
}
numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
sorter.insertRecord(0, numbersToSort[i]);
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index fbaaa1cf49..f9dc20d8b7 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -392,7 +392,20 @@ public class UnsafeShuffleWriterSuite {
}
@Test
- public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
+ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOff() throws Exception {
+ conf.set("spark.shuffle.sort.useRadixSort", "false");
+ writeEnoughRecordsToTriggerSortBufferExpansionAndSpill();
+ assertEquals(2, spillFilesCreated.size());
+ }
+
+ @Test
+ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() throws Exception {
+ conf.set("spark.shuffle.sort.useRadixSort", "true");
+ writeEnoughRecordsToTriggerSortBufferExpansionAndSpill();
+ assertEquals(3, spillFilesCreated.size());
+ }
+
+ private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16);
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
@@ -400,7 +413,6 @@ public class UnsafeShuffleWriterSuite {
dataToWrite.add(new Tuple2<Object, Object>(i, i));
}
writer.write(dataToWrite.iterator());
- assertEquals(2, spillFilesCreated.size());
writer.stop(true);
readRecordsFromFile();
assertSpillFilesWereCleanedUp();
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java
new file mode 100644
index 0000000000..bb38305a07
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+public class UnsafeExternalSorterRadixSortSuite extends UnsafeExternalSorterSuite {
+ @Override
+ protected boolean shouldUseRadixSort() { return true; }
+}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index a2253d8559..60a40cc172 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -64,12 +64,7 @@ public class UnsafeExternalSorterSuite {
new JavaSerializer(new SparkConf()),
new SparkConf().set("spark.shuffle.spill.compress", "false"));
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
- final PrefixComparator prefixComparator = new PrefixComparator() {
- @Override
- public int compare(long prefix1, long prefix2) {
- return (int) prefix1 - (int) prefix2;
- }
- };
+ final PrefixComparator prefixComparator = PrefixComparators.LONG;
// Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
// use a dummy comparator
final RecordComparator recordComparator = new RecordComparator() {
@@ -88,6 +83,7 @@ public class UnsafeExternalSorterSuite {
@Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+ protected boolean shouldUseRadixSort() { return false; }
private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
@@ -178,7 +174,8 @@ public class UnsafeExternalSorterSuite {
recordComparator,
prefixComparator,
/* initialSize */ 1024,
- pageSizeBytes);
+ pageSizeBytes,
+ shouldUseRadixSort());
}
@Test
@@ -381,7 +378,8 @@ public class UnsafeExternalSorterSuite {
null,
null,
/* initialSize */ 1024,
- pageSizeBytes);
+ pageSizeBytes,
+ shouldUseRadixSort());
long[] record = new long[100];
int recordSize = record.length * 8;
int n = (int) pageSizeBytes / recordSize * 3;
@@ -416,7 +414,8 @@ public class UnsafeExternalSorterSuite {
recordComparator,
prefixComparator,
1024,
- pageSizeBytes);
+ pageSizeBytes,
+ shouldUseRadixSort());
// Peak memory should be monotonically increasing. More specifically, every time
// we allocate a new page it should increase by exactly the size of the page.
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java
new file mode 100644
index 0000000000..ae69ededf7
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+public class UnsafeInMemorySorterRadixSortSuite extends UnsafeInMemorySorterSuite {
+ @Override
+ protected boolean shouldUseRadixSort() { return true; }
+}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index f90214fffd..23f4abfed2 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection.unsafe.sort;
+import java.lang.Long;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
@@ -39,6 +40,8 @@ import static org.mockito.Mockito.mock;
public class UnsafeInMemorySorterSuite {
+ protected boolean shouldUseRadixSort() { return false; }
+
private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) {
final byte[] strBytes = new byte[length];
Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length);
@@ -54,7 +57,8 @@ public class UnsafeInMemorySorterSuite {
memoryManager,
mock(RecordComparator.class),
mock(PrefixComparator.class),
- 100);
+ 100,
+ shouldUseRadixSort());
final UnsafeSorterIterator iter = sorter.getSortedIterator();
Assert.assertFalse(iter.hasNext());
}
@@ -102,19 +106,15 @@ public class UnsafeInMemorySorterSuite {
// Compute key prefixes based on the records' partition ids
final HashPartitioner hashPartitioner = new HashPartitioner(4);
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
- final PrefixComparator prefixComparator = new PrefixComparator() {
- @Override
- public int compare(long prefix1, long prefix2) {
- return (int) prefix1 - (int) prefix2;
- }
- };
+ final PrefixComparator prefixComparator = PrefixComparators.LONG;
UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager,
- recordComparator, prefixComparator, dataToSort.length);
+ recordComparator, prefixComparator, dataToSort.length, shouldUseRadixSort());
// Given a page of records, insert those records into the sorter one-by-one:
position = dataPage.getBaseOffset();
for (int i = 0; i < dataToSort.length; i++) {
if (!sorter.hasSpaceForAnotherRecord()) {
- sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2 * 2));
+ sorter.expandPointerArray(
+ consumer.allocateArray(sorter.getMemoryUsage() / Long.BYTES * 2));
}
// position now points to the start of a record (which holds its length).
final int recordLength = Platform.getInt(baseObject, position);
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index dda8bee222..b4083230b4 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -101,6 +101,8 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
test("double prefix comparator handles NaNs properly") {
val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
+ assert(
+ java.lang.Double.doubleToRawLongBits(nan1) != java.lang.Double.doubleToRawLongBits(nan2))
assert(nan1.isNaN)
assert(nan2.isNaN)
val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1)
@@ -110,4 +112,28 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
}
+ test("double prefix comparator handles negative NaNs properly") {
+ val negativeNan: Double = java.lang.Double.longBitsToDouble(0xfff0000000000001L)
+ assert(negativeNan.isNaN)
+ assert(java.lang.Double.doubleToRawLongBits(negativeNan) < 0)
+ val prefix = PrefixComparators.DoublePrefixComparator.computePrefix(negativeNan)
+ val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
+ assert(PrefixComparators.DOUBLE.compare(prefix, doubleMaxPrefix) === 1)
+ }
+
+ test("double prefix comparator handles other special values properly") {
+ val nullValue = 0L
+ val nan = PrefixComparators.DoublePrefixComparator.computePrefix(Double.NaN)
+ val posInf = PrefixComparators.DoublePrefixComparator.computePrefix(Double.PositiveInfinity)
+ val negInf = PrefixComparators.DoublePrefixComparator.computePrefix(Double.NegativeInfinity)
+ val minValue = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MinValue)
+ val maxValue = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
+ val zero = PrefixComparators.DoublePrefixComparator.computePrefix(0.0)
+ assert(PrefixComparators.DOUBLE.compare(nan, posInf) === 1)
+ assert(PrefixComparators.DOUBLE.compare(posInf, maxValue) === 1)
+ assert(PrefixComparators.DOUBLE.compare(maxValue, zero) === 1)
+ assert(PrefixComparators.DOUBLE.compare(zero, minValue) === 1)
+ assert(PrefixComparators.DOUBLE.compare(minValue, negInf) === 1)
+ assert(PrefixComparators.DOUBLE.compare(negInf, nullValue) === 1)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
new file mode 100644
index 0000000000..52428634e5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -0,0 +1,264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort
+
+import java.lang.{Long => JLong}
+import java.util.{Arrays, Comparator}
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.unsafe.array.LongArray
+import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.util.Benchmark
+import org.apache.spark.util.collection.Sorter
+import org.apache.spark.util.random.XORShiftRandom
+
+class RadixSortSuite extends SparkFunSuite with Logging {
+ private val N = 10000 // scale this down for more readable results
+
+ /**
+ * Describes a type of sort to test, e.g. two's complement descending. Each sort type has
+ * a defined reference ordering as well as radix sort parameters that can be used to
+ * reproduce the given ordering.
+ */
+ case class RadixSortType(
+ name: String,
+ referenceComparator: PrefixComparator,
+ startByteIdx: Int, endByteIdx: Int, descending: Boolean, signed: Boolean)
+
+ val SORT_TYPES_TO_TEST = Seq(
+ RadixSortType("unsigned binary data asc", PrefixComparators.BINARY, 0, 7, false, false),
+ RadixSortType("unsigned binary data desc", PrefixComparators.BINARY_DESC, 0, 7, true, false),
+ RadixSortType("twos complement asc", PrefixComparators.LONG, 0, 7, false, true),
+ RadixSortType("twos complement desc", PrefixComparators.LONG_DESC, 0, 7, true, true),
+ RadixSortType(
+ "binary data partial",
+ new PrefixComparators.RadixSortSupport {
+ override def sortDescending = false
+ override def sortSigned = false
+ override def compare(a: Long, b: Long): Int = {
+ return PrefixComparators.BINARY.compare(a & 0xffffff0000L, b & 0xffffff0000L)
+ }
+ },
+ 2, 4, false, false))
+
+ private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = {
+ val ref = Array.tabulate[Long](size) { i => rand }
+ val extended = ref ++ Array.fill[Long](size)(0)
+ (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
+ }
+
+ private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {
+ val ref = Array.tabulate[Long](size * 2) { i => rand }
+ val extended = ref ++ Array.fill[Long](size * 2)(0)
+ (new LongArray(MemoryBlock.fromLongArray(ref)),
+ new LongArray(MemoryBlock.fromLongArray(extended)))
+ }
+
+ private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = {
+ var i = 0
+ val out = new Array[Long](length)
+ while (i < length) {
+ out(i) = array.get(offset + i)
+ i += 1
+ }
+ out
+ }
+
+ private def toJavaComparator(p: PrefixComparator): Comparator[JLong] = {
+ new Comparator[JLong] {
+ override def compare(a: JLong, b: JLong): Int = {
+ p.compare(a, b)
+ }
+ override def equals(other: Any): Boolean = {
+ other == this
+ }
+ }
+ }
+
+ private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
+ new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+ buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
+ override def compare(
+ r1: RecordPointerAndKeyPrefix,
+ r2: RecordPointerAndKeyPrefix): Int = {
+ refCmp.compare(r1.keyPrefix, r2.keyPrefix)
+ }
+ })
+ }
+
+ private def fuzzTest(name: String)(testFn: Long => Unit): Unit = {
+ test(name) {
+ var seed = 0L
+ try {
+ for (i <- 0 to 10) {
+ seed = System.nanoTime
+ testFn(seed)
+ }
+ } catch {
+ case t: Throwable =>
+ throw new Exception("Failed with seed: " + seed, t)
+ }
+ }
+ }
+
+ // Radix sort is sensitive to the value distribution at different bit indices (e.g., we may
+ // omit a sort on a byte if all values are equal). This generates random bitmasks that give
+ // good test coverage of those cases.
+ def randomBitMask(rand: Random): Long = {
+ var tmp = ~0L
+ for (i <- 0 to rand.nextInt(5)) {
+ tmp &= rand.nextLong
+ }
+ tmp
+ }
+
+ for (sortType <- SORT_TYPES_TO_TEST) {
+ test("radix support for " + sortType.name) {
+ val s = sortType.referenceComparator.asInstanceOf[PrefixComparators.RadixSortSupport]
+ assert(s.sortDescending() == sortType.descending)
+ assert(s.sortSigned() == sortType.signed)
+ }
+
+ test("sort " + sortType.name) {
+ val rand = new XORShiftRandom(123)
+ val (ref, buffer) = generateTestData(N, rand.nextLong)
+ Arrays.sort(ref, toJavaComparator(sortType.referenceComparator))
+ val outOffset = RadixSort.sort(
+ buffer, N, sortType.startByteIdx, sortType.endByteIdx,
+ sortType.descending, sortType.signed)
+ val result = collectToArray(buffer, outOffset, N)
+ assert(ref.view == result.view)
+ }
+
+ test("sort key prefix " + sortType.name) {
+ val rand = new XORShiftRandom(123)
+ val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & 0xff)
+ referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
+ val outOffset = RadixSort.sortKeyPrefixArray(
+ buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+ sortType.descending, sortType.signed)
+ val res1 = collectToArray(buf1, 0, N * 2)
+ val res2 = collectToArray(buf2, outOffset, N * 2)
+ assert(res1.view == res2.view)
+ }
+
+ fuzzTest(s"fuzz test ${sortType.name} with random bitmasks") { seed =>
+ val rand = new XORShiftRandom(seed)
+ val mask = randomBitMask(rand)
+ val (ref, buffer) = generateTestData(N, rand.nextLong & mask)
+ Arrays.sort(ref, toJavaComparator(sortType.referenceComparator))
+ val outOffset = RadixSort.sort(
+ buffer, N, sortType.startByteIdx, sortType.endByteIdx,
+ sortType.descending, sortType.signed)
+ val result = collectToArray(buffer, outOffset, N)
+ assert(ref.view == result.view)
+ }
+
+ fuzzTest(s"fuzz test key prefix ${sortType.name} with random bitmasks") { seed =>
+ val rand = new XORShiftRandom(seed)
+ val mask = randomBitMask(rand)
+ val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & mask)
+ referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
+ val outOffset = RadixSort.sortKeyPrefixArray(
+ buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+ sortType.descending, sortType.signed)
+ val res1 = collectToArray(buf1, 0, N * 2)
+ val res2 = collectToArray(buf2, outOffset, N * 2)
+ assert(res1.view == res2.view)
+ }
+ }
+
+ ignore("microbenchmarks") {
+ val size = 25000000
+ val rand = new XORShiftRandom(123)
+ val benchmark = new Benchmark("radix sort " + size, size)
+ benchmark.addTimerCase("reference TimSort key prefix array") { timer =>
+ val array = Array.tabulate[Long](size * 2) { i => rand.nextLong }
+ val buf = new LongArray(MemoryBlock.fromLongArray(array))
+ timer.startTiming()
+ referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY)
+ timer.stopTiming()
+ }
+ benchmark.addTimerCase("reference Arrays.sort") { timer =>
+ val ref = Array.tabulate[Long](size) { i => rand.nextLong }
+ timer.startTiming()
+ Arrays.sort(ref)
+ timer.stopTiming()
+ }
+ benchmark.addTimerCase("radix sort one byte") { timer =>
+ val array = new Array[Long](size * 2)
+ var i = 0
+ while (i < size) {
+ array(i) = rand.nextLong & 0xff
+ i += 1
+ }
+ val buf = new LongArray(MemoryBlock.fromLongArray(array))
+ timer.startTiming()
+ RadixSort.sort(buf, size, 0, 7, false, false)
+ timer.stopTiming()
+ }
+ benchmark.addTimerCase("radix sort two bytes") { timer =>
+ val array = new Array[Long](size * 2)
+ var i = 0
+ while (i < size) {
+ array(i) = rand.nextLong & 0xffff
+ i += 1
+ }
+ val buf = new LongArray(MemoryBlock.fromLongArray(array))
+ timer.startTiming()
+ RadixSort.sort(buf, size, 0, 7, false, false)
+ timer.stopTiming()
+ }
+ benchmark.addTimerCase("radix sort eight bytes") { timer =>
+ val array = new Array[Long](size * 2)
+ var i = 0
+ while (i < size) {
+ array(i) = rand.nextLong
+ i += 1
+ }
+ val buf = new LongArray(MemoryBlock.fromLongArray(array))
+ timer.startTiming()
+ RadixSort.sort(buf, size, 0, 7, false, false)
+ timer.stopTiming()
+ }
+ benchmark.addTimerCase("radix sort key prefix array") { timer =>
+ val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong)
+ timer.startTiming()
+ RadixSort.sortKeyPrefixArray(buf2, size, 0, 7, false, false)
+ timer.stopTiming()
+ }
+ benchmark.run
+
+ /**
+ Running benchmark: radix sort 25000000
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_66-b17 on Linux 3.13.0-44-generic
+ Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
+
+ radix sort 25000000:               Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+ ------------------------------------------------------------------------------------------
+ reference TimSort key prefix array     15546 / 15859          1.6         621.9       1.0X
+ reference Arrays.sort                   2416 / 2446          10.3          96.6       6.4X
+ radix sort one byte                      133 / 137          188.4           5.3     117.2X
+ radix sort two bytes                     255 / 258           98.2          10.2      61.1X
+ radix sort eight bytes                   991 / 997           25.2          39.6      15.7X
+ radix sort key prefix array             1540 / 1563          16.2          61.6      10.1X
+ */
+ }
+}
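
Editorial aside on the mask-based fuzzing above: randomBitMask ANDs between one and five extra random words into an all-ones mask, so each bit position survives with probability between 1/2 and 1/32. That skew is deliberate: it yields test data in which whole bytes are constant, which is exactly the input shape that triggers RadixSort's skip-a-byte optimization. A standalone sketch of the resulting bit density (the helper is copied from the suite; the density check is mine):

    import org.apache.spark.util.random.XORShiftRandom

    // Each bit is 1 only if it is 1 in all (k + 1) ANDed random longs,
    // i.e. with probability 2^-(k+1) for k = rand.nextInt(5), so masks are sparse.
    def randomBitMask(rand: XORShiftRandom): Long = {
      var tmp = ~0L
      for (i <- 0 to rand.nextInt(5)) {
        tmp &= rand.nextLong
      }
      tmp
    }

    val rand = new XORShiftRandom(42)
    val avgSetBits =
      Array.fill(1000)(randomBitMask(rand)).map(m => java.lang.Long.bitCount(m)).sum / 1000.0
    println(f"average set bits per 64-bit mask: $avgSetBits%.1f")  // roughly 12 of 64
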
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 7784345a7a..8d9906da7e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -59,7 +59,8 @@ public final class UnsafeExternalRowSorter {
Ordering<InternalRow> ordering,
PrefixComparator prefixComparator,
PrefixComputer prefixComputer,
- long pageSizeBytes) throws IOException {
+ long pageSizeBytes,
+ boolean canUseRadixSort) throws IOException {
this.schema = schema;
this.prefixComputer = prefixComputer;
final SparkEnv sparkEnv = SparkEnv.get();
@@ -72,7 +73,8 @@ public final class UnsafeExternalRowSorter {
new RowComparator(ordering, schema.length()),
prefixComparator,
/* initialSize */ 4096,
- pageSizeBytes
+ pageSizeBytes,
+ canUseRadixSort
);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index e0c3b22a3c..42a8be6b1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -84,8 +84,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
case DateType | TimestampType =>
(Long.MinValue, s"(long) $input")
case FloatType | DoubleType =>
- (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
- s"$DoublePrefixCmp.computePrefix((double)$input)")
+ (0L, s"$DoublePrefixCmp.computePrefix((double)$input)")
case StringType => (0L, s"$input.getPrefix()")
case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)")
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
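
Aside on the null-sentinel change above: with radix sort in play, float/double prefixes are ordered as unsigned longs, and (as I read PrefixComparators in this patch) computePrefix maps IEEE 754 bit patterns into that unsigned order, with even Double.NegativeInfinity encoding above zero. The literal 0L therefore places NULLs strictly below every real value without needing a tie-break. A small property sketch under those assumptions:

    import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator

    // Sketch: the encoding should be order-preserving under unsigned comparison,
    // and 0L should sit strictly below every encoded double (so NULLs sort first).
    val xs = Seq(Double.NegativeInfinity, -1.5, -0.0, 0.0, 2.5, Double.PositiveInfinity)
    val encoded = xs.map(x => DoublePrefixComparator.computePrefix(x))
    def unsignedLt(a: Long, b: Long): Boolean = java.lang.Long.compareUnsigned(a, b) < 0
    assert(encoded.zip(encoded.tail).forall { case (a, b) => unsignedLt(a, b) })
    assert(encoded.forall(p => unsignedLt(0L, p)))
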
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 8132bba04c..b6499e35b5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -85,13 +85,15 @@ public final class UnsafeKVExternalSorter {
recordComparator,
prefixComparator,
/* initialSize */ 4096,
- pageSizeBytes);
+ pageSizeBytes,
+ keySchema.length() == 1 && SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0)));
} else {
// During spilling, the array in the map will not be used, so we can borrow it and use it
// as the underlying array for the in-memory sorter (it's always large enough).
// Since we will not grow the array, it's fine to pass `null` as consumer.
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
- null, taskMemoryManager, recordComparator, prefixComparator, map.getArray());
+ null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(),
+ false /* TODO(ekl) we can only radix sort if the BytesToBytesMap load factor is <= 0.5 */);
// We cannot use the destructive iterator here because we are reusing the existing memory
// pages in BytesToBytesMap to hold records during sorting.
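
Aside on the load-factor TODO above: the in-memory sorter stores two longs per record (pointer and key prefix), and radix sort additionally needs an equal-sized destination region inside the same array to ping-pong its counting passes. The borrowed BytesToBytesMap array can therefore only back a radix sort when at most half of it holds live entries. A rough capacity check (a hypothetical helper, mirroring how I read the usable-capacity math in UnsafeInMemorySorter):

    // Hypothetical sketch: 2 longs per record, doubled again for the radix
    // sort's scratch half. TimSort, by contrast, needs only ~1.5x headroom.
    def radixSortFits(numRecords: Long, arrayLengthInLongs: Long): Boolean =
      numRecords * 2 * 2 <= arrayLengthInLongs

    assert(radixSortFits(1000, 4096))   // 4000 longs needed: records plus scratch fit
    assert(!radixSortFits(1500, 4096))  // 6000 longs needed: fall back to comparison sort
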
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 80255fafbe..6f9baa2d33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -25,6 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.unsafe.sort.RadixSort
/**
* Performs (external) sorting.
@@ -48,7 +50,10 @@ case class Sort(
override def requiredChildDistribution: Seq[Distribution] =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+ private val enableRadixSort = sqlContext.conf.enableRadixSort
+
override private[sql] lazy val metrics = Map(
+ "sortTime" -> SQLMetrics.createLongMetric(sparkContext, "sort time"),
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
@@ -59,6 +64,9 @@ case class Sort(
val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+ val canUseRadixSort = enableRadixSort && sortOrder.length == 1 &&
+ SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)
+
// The generator for prefix
val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
@@ -69,7 +77,8 @@ case class Sort(
val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val sorter = new UnsafeExternalRowSorter(
- schema, ordering, prefixComparator, prefixComputer, pageSize)
+ schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)
+
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
@@ -139,11 +148,15 @@ case class Sort(
val dataSize = metricTerm(ctx, "dataSize")
val spillSize = metricTerm(ctx, "spillSize")
val spillSizeBefore = ctx.freshName("spillSizeBefore")
+ val startTime = ctx.freshName("startTime")
+ val sortTime = metricTerm(ctx, "sortTime")
s"""
| if ($needToSort) {
| $addToSorter();
- | Long $spillSizeBefore = $metrics.memoryBytesSpilled();
+ | long $spillSizeBefore = $metrics.memoryBytesSpilled();
+ | long $startTime = System.nanoTime();
| $sortedIterator = $sorterVariable.sort();
+ | $sortTime.add(System.nanoTime() - $startTime);
| $dataSize.add($sorterVariable.getPeakMemoryUsage());
| $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
| $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 909f124d2c..1a5ff5fcec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -66,6 +66,32 @@ object SortPrefixUtils {
}
/**
+ * Returns whether the specified SortOrder can be satisfied with a radix sort on the prefix.
+ */
+ def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = {
+ sortOrder.dataType match {
+ // TODO(ekl) LongType is problematic because its null prefix representation collides with
+ // the lowest possible long value. Handle this special case outside radix sort.
+ case LongType if sortOrder.nullable =>
+ false
+ case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType |
+ TimestampType | FloatType | DoubleType =>
+ true
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ true
+ case _ =>
+ false
+ }
+ }
+
+ /**
+ * Returns whether fully sorting on the specified key field is possible with radix sort.
+ */
+ def canSortFullyWithPrefix(field: StructField): Boolean = {
+ canSortFullyWithPrefix(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+ }
+
+ /**
* Creates the prefix computer for the first field in the given schema, in ascending order.
*/
def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
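
Aside on the nullable-LongType carve-out above: for long-typed keys the prefix is the value itself, and (per the SortPrefix hunk earlier in this patch) the NULL sentinel prefix is Long.MinValue. TimSort recovers from the collision by falling back to the full record comparator on prefix ties; radix sort has no such tie-break, so a nullable long column cannot be "fully" sorted by prefix alone. A minimal sketch of the collision:

    // The NULL sentinel prefix for long keys equals the prefix of an actual
    // Long.MinValue value, so the two are indistinguishable to radix sort.
    val nullSentinelPrefix = Long.MinValue   // emitted when the key is NULL
    val minLongValuePrefix = Long.MinValue   // (long) key, for key == Long.MinValue
    assert(nullSentinelPrefix == minLongValuePrefix)
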
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 85ce388de0..a46d0e0ba7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -344,7 +344,8 @@ case class Window(
null,
null,
1024,
- SparkEnv.get.memoryManager.pageSizeBytes)
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ false)
rows.foreach { r =>
sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index edb4c5a16f..b1de52b5f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -46,7 +46,8 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
null,
null,
1024,
- SparkEnv.get.memoryManager.pageSizeBytes)
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ false)
val partition = split.asInstanceOf[CartesianPartition]
for (y <- rdd2.iterator(partition.s2, context)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6e7c1bc133..a4e82d80f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -95,6 +95,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val RADIX_SORT_ENABLED = SQLConfigBuilder("spark.sql.sort.enableRadixSort")
+ .internal()
+ .doc("When true, enable use of radix sort when possible. Radix sort is much faster but " +
+ "requires additional memory to be reserved up-front. The memory overhead may be " +
+ "significant when sorting very small rows (up to 50% more in this case).")
+ .booleanConf
+ .createWithDefault(true)
+
val AUTO_BROADCASTJOIN_THRESHOLD = SQLConfigBuilder("spark.sql.autoBroadcastJoinThreshold")
.doc("Configures the maximum size in bytes for a table that will be broadcast to all worker " +
"nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " +
@@ -584,6 +592,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
+ def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED)
+
def defaultSizeInBytes: Long =
getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L)
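
Since the new flag is marked internal() it will not appear in the documented configuration, but it remains settable, which is handy for A/B-ing sort performance. A usage sketch (hypothetical session code against a Spark 2.x SQLContext):

    // Sketch: toggle radix sort off for one run to compare sort timings.
    sqlContext.setConf("spark.sql.sort.enableRadixSort", "false")
    sqlContext.range(10000000L).selectExpr("id * 48271 % 997 AS k").sort("k").count()
    sqlContext.setConf("spark.sql.sort.enableRadixSort", "true")
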