aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2016-04-21 16:48:51 -0700
committerDavies Liu <davies.liu@gmail.com>2016-04-21 16:48:51 -0700
commite2b5647ab92eb478b3f7b36a0ce6faf83e24c0e5 (patch)
tree4ba34de912caf16f572a42c9cc033aeb2bb9bb4a
parent6d1e4c4a65541cbf78284005de1776dc49efa9f4 (diff)
downloadspark-e2b5647ab92eb478b3f7b36a0ce6faf83e24c0e5.tar.gz
spark-e2b5647ab92eb478b3f7b36a0ce6faf83e24c0e5.tar.bz2
spark-e2b5647ab92eb478b3f7b36a0ce6faf83e24c0e5.zip
[SPARK-14724] Use radix sort for shuffles and sort operator when possible
## What changes were proposed in this pull request? Spark currently uses TimSort for all in-memory sorts, including sorts done for shuffle. One low-hanging fruit is to use radix sort when possible (e.g. sorting by integer keys). This PR adds a radix sort implementation to the unsafe sort package and switches shuffles and sorts to use it when possible. The current implementation does not have special support for null values, so we cannot radix-sort `LongType`. I will address this in a follow-up PR. ## How was this patch tested? Unit tests, enabling radix sort on existing tests. Microbenchmark results: ``` Running benchmark: radix sort 25000000 Java HotSpot(TM) 64-Bit Server VM 1.8.0_66-b17 on Linux 3.13.0-44-generic Intel(R) Core(TM) i7-4600U CPU 2.10GHz radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- reference TimSort key prefix array 15546 / 15859 1.6 621.9 1.0X reference Arrays.sort 2416 / 2446 10.3 96.6 6.4X radix sort one byte 133 / 137 188.4 5.3 117.2X radix sort two bytes 255 / 258 98.2 10.2 61.1X radix sort eight bytes 991 / 997 25.2 39.6 15.7X radix sort key prefix array 1540 / 1563 16.2 61.6 10.1X ``` I also ran a mix of the supported TPCDS queries and compared TimSort vs RadixSort metrics. The overall benchmark ran ~10% faster with radix sort on. In the breakdown below, the radix-enabled sort phases averaged about 20x faster than TimSort, however sorting is only a small fraction of the overall runtime. About half of the TPCDS queries were able to take advantage of radix sort. ``` TPCDS on master: 2499s real time, 8185s executor - 1171s in TimSort, avg 267 MB/s (note the /s accounting is weird here since dataSize counts the record sizes too) TPCDS with radix enabled: 2294s real time, 7391s executor - 596s in TimSort, avg 254 MB/s - 26s in radix sort, avg 4.2 GB/s ``` cc davies rxin Author: Eric Liang <ekl@databricks.com> Closes #12490 from ericl/sort-benchmark.
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java10
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java3
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java44
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java114
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java253
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java13
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java58
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java23
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java16
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java16
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java23
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java17
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java23
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java18
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala264
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala3
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala17
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala26
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala10
24 files changed, 876 insertions, 119 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
index f7a6c68be9..b36da80dbc 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
@@ -42,6 +42,16 @@ final class PackedRecordPointer {
*/
static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1; // 16777215
+ /**
+ * The index of the first byte of the partition id, counting from the least significant byte.
+ */
+ static final int PARTITION_ID_START_BYTE_INDEX = 5;
+
+ /**
+ * The index of the last byte of the partition id, counting from the least significant byte.
+ */
+ static final int PARTITION_ID_END_BYTE_INDEX = 7;
+
/** Bit mask for the lower 40 bits of a long. */
private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 3c2980e442..c4041a97e8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -115,7 +115,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
this.numElementsForSpillThreshold =
conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
this.writeMetrics = writeMetrics;
- this.inMemSorter = new ShuffleInMemorySorter(this, initialSize);
+ this.inMemSorter = new ShuffleInMemorySorter(
+ this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true));
this.peakMemoryUsedBytes = getMemoryUsage();
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
index 76b0e6a304..68630946ac 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -17,12 +17,14 @@
package org.apache.spark.shuffle.sort;
+import java.lang.Long;
import java.util.Comparator;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.util.collection.unsafe.sort.RadixSort;
final class ShuffleInMemorySorter {
@@ -47,16 +49,29 @@ final class ShuffleInMemorySorter {
private LongArray array;
/**
+ * Whether to use radix sort for sorting in-memory partition ids. Radix sort is much faster
+ * but requires additional memory to be reserved memory as pointers are added.
+ */
+ private final boolean useRadixSort;
+
+ /**
+ * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x.
+ */
+ private final int memoryAllocationFactor;
+
+ /**
* The position in the pointer array where new records can be inserted.
*/
private int pos = 0;
private int initialSize;
- ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) {
+ ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize, boolean useRadixSort) {
this.consumer = consumer;
assert (initialSize > 0);
this.initialSize = initialSize;
+ this.useRadixSort = useRadixSort;
+ this.memoryAllocationFactor = useRadixSort ? 2 : 1;
this.array = consumer.allocateArray(initialSize);
this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE);
}
@@ -87,18 +102,18 @@ final class ShuffleInMemorySorter {
array.getBaseOffset(),
newArray.getBaseObject(),
newArray.getBaseOffset(),
- array.size() * 8L
+ array.size() * (Long.BYTES / memoryAllocationFactor)
);
consumer.freeArray(array);
array = newArray;
}
public boolean hasSpaceForAnotherRecord() {
- return pos < array.size();
+ return pos < array.size() / memoryAllocationFactor;
}
public long getMemoryUsage() {
- return array.size() * 8L;
+ return array.size() * Long.BYTES;
}
/**
@@ -125,17 +140,18 @@ final class ShuffleInMemorySorter {
public static final class ShuffleSorterIterator {
private final LongArray pointerArray;
- private final int numRecords;
+ private final int limit;
final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
private int position = 0;
- ShuffleSorterIterator(int numRecords, LongArray pointerArray) {
- this.numRecords = numRecords;
+ ShuffleSorterIterator(int numRecords, LongArray pointerArray, int startingPosition) {
+ this.limit = numRecords + startingPosition;
this.pointerArray = pointerArray;
+ this.position = startingPosition;
}
public boolean hasNext() {
- return position < numRecords;
+ return position < limit;
}
public void loadNext() {
@@ -148,7 +164,15 @@ final class ShuffleInMemorySorter {
* Return an iterator over record pointers in sorted order.
*/
public ShuffleSorterIterator getSortedIterator() {
- sorter.sort(array, 0, pos, SORT_COMPARATOR);
- return new ShuffleSorterIterator(pos, array);
+ int offset = 0;
+ if (useRadixSort) {
+ offset = RadixSort.sort(
+ array, pos,
+ PackedRecordPointer.PARTITION_ID_START_BYTE_INDEX,
+ PackedRecordPointer.PARTITION_ID_END_BYTE_INDEX, false, false);
+ } else {
+ sorter.sort(array, 0, pos, SORT_COMPARATOR);
+ }
+ return new ShuffleSorterIterator(pos, array, offset);
}
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index c2a8f429be..21f2fde79d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -28,88 +28,92 @@ import org.apache.spark.util.Utils;
public class PrefixComparators {
private PrefixComparators() {}
- public static final StringPrefixComparator STRING = new StringPrefixComparator();
- public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc();
- public static final BinaryPrefixComparator BINARY = new BinaryPrefixComparator();
- public static final BinaryPrefixComparatorDesc BINARY_DESC = new BinaryPrefixComparatorDesc();
- public static final LongPrefixComparator LONG = new LongPrefixComparator();
- public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc();
- public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
- public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc();
-
- public static final class StringPrefixComparator extends PrefixComparator {
- @Override
- public int compare(long aPrefix, long bPrefix) {
- return UnsignedLongs.compare(aPrefix, bPrefix);
- }
-
+ public static final PrefixComparator STRING = new UnsignedPrefixComparator();
+ public static final PrefixComparator STRING_DESC = new UnsignedPrefixComparatorDesc();
+ public static final PrefixComparator BINARY = new UnsignedPrefixComparator();
+ public static final PrefixComparator BINARY_DESC = new UnsignedPrefixComparatorDesc();
+ public static final PrefixComparator LONG = new SignedPrefixComparator();
+ public static final PrefixComparator LONG_DESC = new SignedPrefixComparatorDesc();
+ public static final PrefixComparator DOUBLE = new UnsignedPrefixComparator();
+ public static final PrefixComparator DOUBLE_DESC = new UnsignedPrefixComparatorDesc();
+
+ public static final class StringPrefixComparator {
public static long computePrefix(UTF8String value) {
return value == null ? 0L : value.getPrefix();
}
}
- public static final class StringPrefixComparatorDesc extends PrefixComparator {
- @Override
- public int compare(long bPrefix, long aPrefix) {
- return UnsignedLongs.compare(aPrefix, bPrefix);
+ public static final class BinaryPrefixComparator {
+ public static long computePrefix(byte[] bytes) {
+ return ByteArray.getPrefix(bytes);
}
}
- public static final class BinaryPrefixComparator extends PrefixComparator {
- @Override
- public int compare(long aPrefix, long bPrefix) {
- return UnsignedLongs.compare(aPrefix, bPrefix);
+ public static final class DoublePrefixComparator {
+ /**
+ * Converts the double into a value that compares correctly as an unsigned long. For more
+ * details see http://stereopsis.com/radix.html.
+ */
+ public static long computePrefix(double value) {
+ // Java's doubleToLongBits already canonicalizes all NaN values to the smallest possible
+ // positive NaN, so there's nothing special we need to do for NaNs.
+ long bits = Double.doubleToLongBits(value);
+ // Negative floats compare backwards due to their sign-magnitude representation, so flip
+ // all the bits in this case.
+ long mask = -(bits >>> 63) | 0x8000000000000000L;
+ return bits ^ mask;
}
+ }
- public static long computePrefix(byte[] bytes) {
- return ByteArray.getPrefix(bytes);
- }
+ /**
+ * Provides radix sort parameters. Comparators implementing this also are indicating that the
+ * ordering they define is compatible with radix sort.
+ */
+ public static abstract class RadixSortSupport extends PrefixComparator {
+ /** @return Whether the sort should be descending in binary sort order. */
+ public abstract boolean sortDescending();
+
+ /** @return Whether the sort should take into account the sign bit. */
+ public abstract boolean sortSigned();
}
- public static final class BinaryPrefixComparatorDesc extends PrefixComparator {
+ //
+ // Standard prefix comparator implementations
+ //
+
+ public static final class UnsignedPrefixComparator extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return false; }
+ @Override public final boolean sortSigned() { return false; }
@Override
- public int compare(long bPrefix, long aPrefix) {
+ public final int compare(long aPrefix, long bPrefix) {
return UnsignedLongs.compare(aPrefix, bPrefix);
}
}
- public static final class LongPrefixComparator extends PrefixComparator {
+ public static final class UnsignedPrefixComparatorDesc extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return true; }
+ @Override public final boolean sortSigned() { return false; }
@Override
- public int compare(long a, long b) {
- return (a < b) ? -1 : (a > b) ? 1 : 0;
+ public final int compare(long bPrefix, long aPrefix) {
+ return UnsignedLongs.compare(aPrefix, bPrefix);
}
}
- public static final class LongPrefixComparatorDesc extends PrefixComparator {
+ public static final class SignedPrefixComparator extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return false; }
+ @Override public final boolean sortSigned() { return true; }
@Override
- public int compare(long b, long a) {
+ public final int compare(long a, long b) {
return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}
- public static final class DoublePrefixComparator extends PrefixComparator {
+ public static final class SignedPrefixComparatorDesc extends RadixSortSupport {
+ @Override public final boolean sortDescending() { return true; }
+ @Override public final boolean sortSigned() { return true; }
@Override
- public int compare(long aPrefix, long bPrefix) {
- double a = Double.longBitsToDouble(aPrefix);
- double b = Double.longBitsToDouble(bPrefix);
- return Utils.nanSafeCompareDoubles(a, b);
- }
-
- public static long computePrefix(double value) {
- return Double.doubleToLongBits(value);
- }
- }
-
- public static final class DoublePrefixComparatorDesc extends PrefixComparator {
- @Override
- public int compare(long bPrefix, long aPrefix) {
- double a = Double.longBitsToDouble(aPrefix);
- double b = Double.longBitsToDouble(bPrefix);
- return Utils.nanSafeCompareDoubles(a, b);
- }
-
- public static long computePrefix(double value) {
- return Double.doubleToLongBits(value);
+ public final int compare(long b, long a) {
+ return (a < b) ? -1 : (a > b) ? 1 : 0;
}
}
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
new file mode 100644
index 0000000000..3357b8e474
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.LongArray;
+
+public class RadixSort {
+
+ /**
+ * Sorts a given array of longs using least-significant-digit radix sort. This routine assumes
+ * you have extra space at the end of the array at least equal to the number of records. The
+ * sort is destructive and may relocate the data positioned within the array.
+ *
+ * @param array array of long elements followed by at least that many empty slots.
+ * @param numRecords number of data records in the array.
+ * @param startByteIndex the first byte (in range [0, 7]) to sort each long by, counting from the
+ * least significant byte.
+ * @param endByteIndex the last byte (in range [0, 7]) to sort each long by, counting from the
+ * least significant byte. Must be greater than startByteIndex.
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort.
+ *
+ * @return The starting index of the sorted data within the given array. We return this instead
+ * of always copying the data back to position zero for efficiency.
+ */
+ public static int sort(
+ LongArray array, int numRecords, int startByteIndex, int endByteIndex,
+ boolean desc, boolean signed) {
+ assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
+ assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
+ assert endByteIndex > startByteIndex;
+ assert numRecords * 2 <= array.size();
+ int inIndex = 0;
+ int outIndex = numRecords;
+ if (numRecords > 0) {
+ long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (counts[i] != null) {
+ sortAtByte(
+ array, numRecords, counts[i], i, inIndex, outIndex,
+ desc, signed && i == endByteIndex);
+ int tmp = inIndex;
+ inIndex = outIndex;
+ outIndex = tmp;
+ }
+ }
+ }
+ return inIndex;
+ }
+
+ /**
+ * Performs a partial sort by copying data into destination offsets for each byte value at the
+ * specified byte offset.
+ *
+ * @param array array to partially sort.
+ * @param numRecords number of data records in the array.
+ * @param counts counts for each byte value. This routine destructively modifies this array.
+ * @param byteIdx the byte in a long to sort at, counting from the least significant byte.
+ * @param inIndex the starting index in the array where input data is located.
+ * @param outIndex the starting index where sorted output data should be written.
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort (only applies to last byte).
+ */
+ private static void sortAtByte(
+ LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ long[] offsets = transformCountsToOffsets(
+ counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed);
+ Object baseObject = array.getBaseObject();
+ long baseOffset = array.getBaseOffset() + inIndex * 8;
+ long maxOffset = baseOffset + numRecords * 8;
+ for (long offset = baseOffset; offset < maxOffset; offset += 8) {
+ long value = Platform.getLong(baseObject, offset);
+ int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
+ Platform.putLong(baseObject, offsets[bucket], value);
+ offsets[bucket] += 8;
+ }
+ }
+
+ /**
+ * Computes a value histogram for each byte in the given array.
+ *
+ * @param array array to count records in.
+ * @param numRecords number of data records in the array.
+ * @param startByteIndex the first byte to compute counts for (the prior are skipped).
+ * @param endByteIndex the last byte to compute counts for.
+ *
+ * @return an array of eight 256-byte count arrays, one for each byte starting from the least
+ * significant byte. If the byte does not need sorting the array will be null.
+ */
+ private static long[][] getCounts(
+ LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+ long[][] counts = new long[8][];
+ // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
+ // If all the byte values at a particular index are the same we don't need to count it.
+ long bitwiseMax = 0;
+ long bitwiseMin = -1L;
+ long maxOffset = array.getBaseOffset() + numRecords * 8;
+ Object baseObject = array.getBaseObject();
+ for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
+ long value = Platform.getLong(baseObject, offset);
+ bitwiseMax |= value;
+ bitwiseMin &= value;
+ }
+ long bitsChanged = bitwiseMin ^ bitwiseMax;
+ // Compute counts for each byte index.
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
+ counts[i] = new long[256];
+ // TODO(ekl) consider computing all the counts in one pass.
+ for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
+ counts[i][(int)((Platform.getLong(baseObject, offset) >>> (i * 8)) & 0xff)]++;
+ }
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Transforms counts into the proper unsafe output offsets for the sort type.
+ *
+ * @param counts counts for each byte value. This routine destructively modifies this array.
+ * @param numRecords number of data records in the original data array.
+ * @param outputOffset output offset in bytes from the base array object.
+ * @param bytesPerRecord size of each record (8 for plain sort, 16 for key-prefix sort).
+ * @param desc whether this is a descending (binary-order) sort.
+ * @param signed whether this is a signed (two's complement) sort.
+ *
+ * @return the input counts array.
+ */
+ private static long[] transformCountsToOffsets(
+ long[] counts, int numRecords, long outputOffset, int bytesPerRecord,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ int start = signed ? 128 : 0; // output the negative records first (values 129-255).
+ if (desc) {
+ int pos = numRecords;
+ for (int i = start; i < start + 256; i++) {
+ pos -= counts[i & 0xff];
+ counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
+ }
+ } else {
+ int pos = 0;
+ for (int i = start; i < start + 256; i++) {
+ long tmp = counts[i & 0xff];
+ counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
+ pos += tmp;
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Specialization of sort() for key-prefix arrays. In this type of array, each record consists
+ * of two longs, only the second of which is sorted on.
+ */
+ public static int sortKeyPrefixArray(
+ LongArray array,
+ int numRecords,
+ int startByteIndex,
+ int endByteIndex,
+ boolean desc,
+ boolean signed) {
+ assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
+ assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
+ assert endByteIndex > startByteIndex;
+ assert numRecords * 4 <= array.size();
+ int inIndex = 0;
+ int outIndex = numRecords * 2;
+ if (numRecords > 0) {
+ long[][] counts = getKeyPrefixArrayCounts(array, numRecords, startByteIndex, endByteIndex);
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (counts[i] != null) {
+ sortKeyPrefixArrayAtByte(
+ array, numRecords, counts[i], i, inIndex, outIndex,
+ desc, signed && i == endByteIndex);
+ int tmp = inIndex;
+ inIndex = outIndex;
+ outIndex = tmp;
+ }
+ }
+ }
+ return inIndex;
+ }
+
+ /**
+ * Specialization of getCounts() for key-prefix arrays. We could probably combine this with
+ * getCounts with some added parameters but that seems to hurt in benchmarks.
+ */
+ private static long[][] getKeyPrefixArrayCounts(
+ LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+ long[][] counts = new long[8][];
+ long bitwiseMax = 0;
+ long bitwiseMin = -1L;
+ long limit = array.getBaseOffset() + numRecords * 16;
+ Object baseObject = array.getBaseObject();
+ for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ long value = Platform.getLong(baseObject, offset + 8);
+ bitwiseMax |= value;
+ bitwiseMin &= value;
+ }
+ long bitsChanged = bitwiseMin ^ bitwiseMax;
+ for (int i = startByteIndex; i <= endByteIndex; i++) {
+ if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
+ counts[i] = new long[256];
+ for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
+ }
+ }
+ }
+ return counts;
+ }
+
+ /**
+ * Specialization of sortAtByte() for key-prefix arrays.
+ */
+ private static void sortKeyPrefixArrayAtByte(
+ LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
+ boolean desc, boolean signed) {
+ assert counts.length == 256;
+ long[] offsets = transformCountsToOffsets(
+ counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
+ Object baseObject = array.getBaseObject();
+ long baseOffset = array.getBaseOffset() + inIndex * 8;
+ long maxOffset = baseOffset + numRecords * 16;
+ for (long offset = baseOffset; offset < maxOffset; offset += 16) {
+ long key = Platform.getLong(baseObject, offset);
+ long prefix = Platform.getLong(baseObject, offset + 8);
+ int bucket = (int)((prefix >>> (byteIdx * 8)) & 0xff);
+ long dest = offsets[bucket];
+ Platform.putLong(baseObject, dest, key);
+ Platform.putLong(baseObject, dest + 8, prefix);
+ offsets[bucket] += 16;
+ }
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 3e32dd9d63..66a77982ad 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -89,7 +89,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
UnsafeInMemorySorter inMemorySorter) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
- pageSizeBytes, inMemorySorter);
+ pageSizeBytes, inMemorySorter, false /* ignored */);
sorter.spill(Long.MAX_VALUE, sorter);
// The external sorter will be used to insert records, in-memory sorter is not needed.
sorter.inMemSorter = null;
@@ -104,9 +104,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
RecordComparator recordComparator,
PrefixComparator prefixComparator,
int initialSize,
- long pageSizeBytes) {
+ long pageSizeBytes,
+ boolean canUseRadixSort) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
- taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null);
+ taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null,
+ canUseRadixSort);
}
private UnsafeExternalSorter(
@@ -118,7 +120,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
- @Nullable UnsafeInMemorySorter existingInMemorySorter) {
+ @Nullable UnsafeInMemorySorter existingInMemorySorter,
+ boolean canUseRadixSort) {
super(taskMemoryManager, pageSizeBytes);
this.taskMemoryManager = taskMemoryManager;
this.blockManager = blockManager;
@@ -133,7 +136,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
if (existingInMemorySorter == null) {
this.inMemSorter = new UnsafeInMemorySorter(
- this, taskMemoryManager, recordComparator, prefixComparator, initialSize);
+ this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort);
} else {
this.inMemSorter = existingInMemorySorter;
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 01eae0e8dc..5f46ef9a81 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection.unsafe.sort;
+import java.lang.Long;
import java.util.Comparator;
import org.apache.avro.reflect.Nullable;
@@ -26,6 +27,7 @@ import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;
import org.apache.spark.util.collection.Sorter;
+import org.apache.spark.util.collection.unsafe.sort.RadixSort;
/**
* Sorts records using an AlphaSort-style key-prefix sort. This sort stores pointers to records
@@ -74,6 +76,17 @@ public final class UnsafeInMemorySorter {
private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
/**
+ * If non-null, specifies the radix sort parameters and that radix sort will be used.
+ */
+ @Nullable
+ private final PrefixComparators.RadixSortSupport radixSortSupport;
+
+ /**
+ * Set to 2x for radix sort to reserve extra memory for sorting, otherwise 1x.
+ */
+ private final int memoryAllocationFactor;
+
+ /**
* Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
*/
@@ -91,27 +104,36 @@ public final class UnsafeInMemorySorter {
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
- int initialSize) {
+ int initialSize,
+ boolean canUseRadixSort) {
this(consumer, memoryManager, recordComparator, prefixComparator,
- consumer.allocateArray(initialSize * 2));
+ consumer.allocateArray(initialSize * 2), canUseRadixSort);
}
public UnsafeInMemorySorter(
- final MemoryConsumer consumer,
+ final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final PrefixComparator prefixComparator,
- LongArray array) {
+ LongArray array,
+ boolean canUseRadixSort) {
this.consumer = consumer;
this.memoryManager = memoryManager;
this.initialSize = array.size();
if (recordComparator != null) {
this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE);
this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
+ if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) {
+ this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator;
+ } else {
+ this.radixSortSupport = null;
+ }
} else {
this.sorter = null;
this.sortComparator = null;
+ this.radixSortSupport = null;
}
+ this.memoryAllocationFactor = this.radixSortSupport != null ? 2 : 1;
this.array = array;
}
@@ -141,11 +163,11 @@ public final class UnsafeInMemorySorter {
}
public long getMemoryUsage() {
- return array.size() * 8L;
+ return array.size() * Long.BYTES;
}
public boolean hasSpaceForAnotherRecord() {
- return pos + 2 <= array.size();
+ return pos + 1 < (array.size() / memoryAllocationFactor);
}
public void expandPointerArray(LongArray newArray) {
@@ -157,7 +179,7 @@ public final class UnsafeInMemorySorter {
array.getBaseOffset(),
newArray.getBaseObject(),
newArray.getBaseOffset(),
- array.size() * 8L);
+ array.size() * (Long.BYTES / memoryAllocationFactor));
consumer.freeArray(array);
array = newArray;
}
@@ -183,18 +205,20 @@ public final class UnsafeInMemorySorter {
private final int numRecords;
private int position;
+ private int offset;
private Object baseObject;
private long baseOffset;
private long keyPrefix;
private int recordLength;
- private SortedIterator(int numRecords) {
+ private SortedIterator(int numRecords, int offset) {
this.numRecords = numRecords;
this.position = 0;
+ this.offset = offset;
}
public SortedIterator clone() {
- SortedIterator iter = new SortedIterator(numRecords);
+ SortedIterator iter = new SortedIterator(numRecords, offset);
iter.position = position;
iter.baseObject = baseObject;
iter.baseOffset = baseOffset;
@@ -216,11 +240,11 @@ public final class UnsafeInMemorySorter {
@Override
public void loadNext() {
// This pointer points to a 4-byte record length, followed by the record's bytes
- final long recordPointer = array.get(position);
+ final long recordPointer = array.get(offset + position);
baseObject = memoryManager.getPage(recordPointer);
baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
recordLength = Platform.getInt(baseObject, baseOffset - 4);
- keyPrefix = array.get(position + 1);
+ keyPrefix = array.get(offset + position + 1);
position += 2;
}
@@ -242,9 +266,17 @@ public final class UnsafeInMemorySorter {
* {@code next()} will return the same mutable object.
*/
public SortedIterator getSortedIterator() {
+ int offset = 0;
if (sorter != null) {
- sorter.sort(array, 0, pos / 2, sortComparator);
+ if (this.radixSortSupport != null) {
+ // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they
+ // force a full-width sort (and we cannot radix-sort nullable long fields at all).
+ offset = RadixSort.sortKeyPrefixArray(
+ array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
+ } else {
+ sorter.sort(array, 0, pos / 2, sortComparator);
+ }
}
- return new SortedIterator(pos / 2);
+ return new SortedIterator(pos / 2, offset);
}
}
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java
new file mode 100644
index 0000000000..6927d0a815
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemoryRadixSorterSuite.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+public class ShuffleInMemoryRadixSorterSuite extends ShuffleInMemorySorterSuite {
+ @Override
+ protected boolean shouldUseRadixSort() { return true; }
+}
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
index 4cd3600df1..43e32f073a 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.shuffle.sort;
+import java.lang.Long;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Random;
@@ -34,6 +35,8 @@ import org.apache.spark.unsafe.memory.MemoryBlock;
public class ShuffleInMemorySorterSuite {
+ protected boolean shouldUseRadixSort() { return false; }
+
final TestMemoryManager memoryManager =
new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"));
final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
@@ -47,7 +50,8 @@ public class ShuffleInMemorySorterSuite {
@Test
public void testSortingEmptyInput() {
- final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100);
+ final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(
+ consumer, 100, shouldUseRadixSort());
final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
Assert.assertFalse(iter.hasNext());
}
@@ -70,14 +74,16 @@ public class ShuffleInMemorySorterSuite {
new TaskMemoryManager(new TestMemoryManager(conf), 0);
final MemoryBlock dataPage = memoryManager.allocatePage(2048, null);
final Object baseObject = dataPage.getBaseObject();
- final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4);
+ final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(
+ consumer, 4, shouldUseRadixSort());
final HashPartitioner hashPartitioner = new HashPartitioner(4);
// Write the records into the data page and store pointers into the sorter
long position = dataPage.getBaseOffset();
for (String str : dataToSort) {
if (!sorter.hasSpaceForAnotherRecord()) {
- sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2));
+ sorter.expandPointerArray(
+ consumer.allocateArray(sorter.getMemoryUsage() / Long.BYTES * 2));
}
final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
final byte[] strBytes = str.getBytes(StandardCharsets.UTF_8);
@@ -114,12 +120,12 @@ public class ShuffleInMemorySorterSuite {
@Test
public void testSortingManyNumbers() throws Exception {
- ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4);
+ ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4, shouldUseRadixSort());
int[] numbersToSort = new int[128000];
Random random = new Random(16);
for (int i = 0; i < numbersToSort.length; i++) {
if (!sorter.hasSpaceForAnotherRecord()) {
- sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2));
+ sorter.expandPointerArray(consumer.allocateArray(sorter.getMemoryUsage() / 8 * 2));
}
numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
sorter.insertRecord(0, numbersToSort[i]);
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index fbaaa1cf49..f9dc20d8b7 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -392,7 +392,20 @@ public class UnsafeShuffleWriterSuite {
}
@Test
- public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
+ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOff() throws Exception {
+ conf.set("spark.shuffle.sort.useRadixSort", "false");
+ writeEnoughRecordsToTriggerSortBufferExpansionAndSpill();
+ assertEquals(2, spillFilesCreated.size());
+ }
+
+ @Test
+ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() throws Exception {
+ conf.set("spark.shuffle.sort.useRadixSort", "true");
+ writeEnoughRecordsToTriggerSortBufferExpansionAndSpill();
+ assertEquals(3, spillFilesCreated.size());
+ }
+
+ private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16);
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
@@ -400,7 +413,6 @@ public class UnsafeShuffleWriterSuite {
dataToWrite.add(new Tuple2<Object, Object>(i, i));
}
writer.write(dataToWrite.iterator());
- assertEquals(2, spillFilesCreated.size());
writer.stop(true);
readRecordsFromFile();
assertSpillFilesWereCleanedUp();
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java
new file mode 100644
index 0000000000..bb38305a07
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterRadixSortSuite.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+public class UnsafeExternalSorterRadixSortSuite extends UnsafeExternalSorterSuite {
+ @Override
+ protected boolean shouldUseRadixSort() { return true; }
+}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index a2253d8559..60a40cc172 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -64,12 +64,7 @@ public class UnsafeExternalSorterSuite {
new JavaSerializer(new SparkConf()),
new SparkConf().set("spark.shuffle.spill.compress", "false"));
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
- final PrefixComparator prefixComparator = new PrefixComparator() {
- @Override
- public int compare(long prefix1, long prefix2) {
- return (int) prefix1 - (int) prefix2;
- }
- };
+ final PrefixComparator prefixComparator = PrefixComparators.LONG;
// Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
// use a dummy comparator
final RecordComparator recordComparator = new RecordComparator() {
@@ -88,6 +83,7 @@ public class UnsafeExternalSorterSuite {
@Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
@Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+ protected boolean shouldUseRadixSort() { return false; }
private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
@@ -178,7 +174,8 @@ public class UnsafeExternalSorterSuite {
recordComparator,
prefixComparator,
/* initialSize */ 1024,
- pageSizeBytes);
+ pageSizeBytes,
+ shouldUseRadixSort());
}
@Test
@@ -381,7 +378,8 @@ public class UnsafeExternalSorterSuite {
null,
null,
/* initialSize */ 1024,
- pageSizeBytes);
+ pageSizeBytes,
+ shouldUseRadixSort());
long[] record = new long[100];
int recordSize = record.length * 8;
int n = (int) pageSizeBytes / recordSize * 3;
@@ -416,7 +414,8 @@ public class UnsafeExternalSorterSuite {
recordComparator,
prefixComparator,
1024,
- pageSizeBytes);
+ pageSizeBytes,
+ shouldUseRadixSort());
// Peak memory should be monotonically increasing. More specifically, every time
// we allocate a new page it should increase by exactly the size of the page.
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java
new file mode 100644
index 0000000000..ae69ededf7
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterRadixSortSuite.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort;
+
+public class UnsafeInMemorySorterRadixSortSuite extends UnsafeInMemorySorterSuite {
+ @Override
+ protected boolean shouldUseRadixSort() { return true; }
+}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index f90214fffd..23f4abfed2 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection.unsafe.sort;
+import java.lang.Long;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
@@ -39,6 +40,8 @@ import static org.mockito.Mockito.mock;
public class UnsafeInMemorySorterSuite {
+ protected boolean shouldUseRadixSort() { return false; }
+
private static String getStringFromDataPage(Object baseObject, long baseOffset, int length) {
final byte[] strBytes = new byte[length];
Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, length);
@@ -54,7 +57,8 @@ public class UnsafeInMemorySorterSuite {
memoryManager,
mock(RecordComparator.class),
mock(PrefixComparator.class),
- 100);
+ 100,
+ shouldUseRadixSort());
final UnsafeSorterIterator iter = sorter.getSortedIterator();
Assert.assertFalse(iter.hasNext());
}
@@ -102,19 +106,15 @@ public class UnsafeInMemorySorterSuite {
// Compute key prefixes based on the records' partition ids
final HashPartitioner hashPartitioner = new HashPartitioner(4);
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
- final PrefixComparator prefixComparator = new PrefixComparator() {
- @Override
- public int compare(long prefix1, long prefix2) {
- return (int) prefix1 - (int) prefix2;
- }
- };
+ final PrefixComparator prefixComparator = PrefixComparators.LONG;
UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager,
- recordComparator, prefixComparator, dataToSort.length);
+ recordComparator, prefixComparator, dataToSort.length, shouldUseRadixSort());
// Given a page of records, insert those records into the sorter one-by-one:
position = dataPage.getBaseOffset();
for (int i = 0; i < dataToSort.length; i++) {
if (!sorter.hasSpaceForAnotherRecord()) {
- sorter.expandPointerArray(consumer.allocateArray(sorter.numRecords() * 2 * 2));
+ sorter.expandPointerArray(
+ consumer.allocateArray(sorter.getMemoryUsage() / Long.BYTES * 2));
}
// position now points to the start of a record (which holds its length).
final int recordLength = Platform.getInt(baseObject, position);
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index dda8bee222..b4083230b4 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -101,6 +101,8 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
test("double prefix comparator handles NaNs properly") {
val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
+ assert(
+ java.lang.Double.doubleToRawLongBits(nan1) != java.lang.Double.doubleToRawLongBits(nan2))
assert(nan1.isNaN)
assert(nan2.isNaN)
val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1)
@@ -110,4 +112,28 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
}
+ test("double prefix comparator handles negative NaNs properly") {
+ val negativeNan: Double = java.lang.Double.longBitsToDouble(0xfff0000000000001L)
+ assert(negativeNan.isNaN)
+ assert(java.lang.Double.doubleToRawLongBits(negativeNan) < 0)
+ val prefix = PrefixComparators.DoublePrefixComparator.computePrefix(negativeNan)
+ val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
+ assert(PrefixComparators.DOUBLE.compare(prefix, doubleMaxPrefix) === 1)
+ }
+
+ test("double prefix comparator handles other special values properly") {
+ val nullValue = 0L
+ val nan = PrefixComparators.DoublePrefixComparator.computePrefix(Double.NaN)
+ val posInf = PrefixComparators.DoublePrefixComparator.computePrefix(Double.PositiveInfinity)
+ val negInf = PrefixComparators.DoublePrefixComparator.computePrefix(Double.NegativeInfinity)
+ val minValue = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MinValue)
+ val maxValue = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue)
+ val zero = PrefixComparators.DoublePrefixComparator.computePrefix(0.0)
+ assert(PrefixComparators.DOUBLE.compare(nan, posInf) === 1)
+ assert(PrefixComparators.DOUBLE.compare(posInf, maxValue) === 1)
+ assert(PrefixComparators.DOUBLE.compare(maxValue, zero) === 1)
+ assert(PrefixComparators.DOUBLE.compare(zero, minValue) === 1)
+ assert(PrefixComparators.DOUBLE.compare(minValue, negInf) === 1)
+ assert(PrefixComparators.DOUBLE.compare(negInf, nullValue) === 1)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
new file mode 100644
index 0000000000..52428634e5
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -0,0 +1,264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection.unsafe.sort
+
+import java.lang.{Long => JLong}
+import java.util.{Arrays, Comparator}
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.unsafe.array.LongArray
+import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.util.Benchmark
+import org.apache.spark.util.collection.Sorter
+import org.apache.spark.util.random.XORShiftRandom
+
+class RadixSortSuite extends SparkFunSuite with Logging {
+ private val N = 10000 // scale this down for more readable results
+
+ /**
+ * Describes a type of sort to test, e.g. two's complement descending. Each sort type has
+ * a defined reference ordering as well as radix sort parameters that can be used to
+ * reproduce the given ordering.
+ */
+ case class RadixSortType(
+ name: String,
+ referenceComparator: PrefixComparator,
+ startByteIdx: Int, endByteIdx: Int, descending: Boolean, signed: Boolean)
+
+ val SORT_TYPES_TO_TEST = Seq(
+ RadixSortType("unsigned binary data asc", PrefixComparators.BINARY, 0, 7, false, false),
+ RadixSortType("unsigned binary data desc", PrefixComparators.BINARY_DESC, 0, 7, true, false),
+ RadixSortType("twos complement asc", PrefixComparators.LONG, 0, 7, false, true),
+ RadixSortType("twos complement desc", PrefixComparators.LONG_DESC, 0, 7, true, true),
+ RadixSortType(
+ "binary data partial",
+ new PrefixComparators.RadixSortSupport {
+ override def sortDescending = false
+ override def sortSigned = false
+ override def compare(a: Long, b: Long): Int = {
+ return PrefixComparators.BINARY.compare(a & 0xffffff0000L, b & 0xffffff0000L)
+ }
+ },
+ 2, 4, false, false))
+
+ private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = {
+ val ref = Array.tabulate[Long](size) { i => rand }
+ val extended = ref ++ Array.fill[Long](size)(0)
+ (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
+ }
+
+ private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {
+ val ref = Array.tabulate[Long](size * 2) { i => rand }
+ val extended = ref ++ Array.fill[Long](size * 2)(0)
+ (new LongArray(MemoryBlock.fromLongArray(ref)),
+ new LongArray(MemoryBlock.fromLongArray(extended)))
+ }
+
+ private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = {
+ var i = 0
+ val out = new Array[Long](length)
+ while (i < length) {
+ out(i) = array.get(offset + i)
+ i += 1
+ }
+ out
+ }
+
+ private def toJavaComparator(p: PrefixComparator): Comparator[JLong] = {
+ new Comparator[JLong] {
+ override def compare(a: JLong, b: JLong): Int = {
+ p.compare(a, b)
+ }
+ override def equals(other: Any): Boolean = {
+ other == this
+ }
+ }
+ }
+
+ private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
+ new Sorter(UnsafeSortDataFormat.INSTANCE).sort(
+ buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
+ override def compare(
+ r1: RecordPointerAndKeyPrefix,
+ r2: RecordPointerAndKeyPrefix): Int = {
+ refCmp.compare(r1.keyPrefix, r2.keyPrefix)
+ }
+ })
+ }
+
+ private def fuzzTest(name: String)(testFn: Long => Unit): Unit = {
+ test(name) {
+ var seed = 0L
+ try {
+ for (i <- 0 to 10) {
+ seed = System.nanoTime
+ testFn(seed)
+ }
+ } catch {
+ case t: Throwable =>
+ throw new Exception("Failed with seed: " + seed, t)
+ }
+ }
+ }
+
+ // Radix sort is sensitive to the value distribution at different bit indices (e.g., we may
+ // omit a sort on a byte if all values are equal). This generates random good test masks.
+ def randomBitMask(rand: Random): Long = {
+ var tmp = ~0L
+ for (i <- 0 to rand.nextInt(5)) {
+ tmp &= rand.nextLong
+ }
+ tmp
+ }
+
+ for (sortType <- SORT_TYPES_TO_TEST) {
+ test("radix support for " + sortType.name) {
+ val s = sortType.referenceComparator.asInstanceOf[PrefixComparators.RadixSortSupport]
+ assert(s.sortDescending() == sortType.descending)
+ assert(s.sortSigned() == sortType.signed)
+ }
+
+ test("sort " + sortType.name) {
+ val rand = new XORShiftRandom(123)
+ val (ref, buffer) = generateTestData(N, rand.nextLong)
+ Arrays.sort(ref, toJavaComparator(sortType.referenceComparator))
+ val outOffset = RadixSort.sort(
+ buffer, N, sortType.startByteIdx, sortType.endByteIdx,
+ sortType.descending, sortType.signed)
+ val result = collectToArray(buffer, outOffset, N)
+ assert(ref.view == result.view)
+ }
+
+ test("sort key prefix " + sortType.name) {
+ val rand = new XORShiftRandom(123)
+ val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & 0xff)
+ referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
+ val outOffset = RadixSort.sortKeyPrefixArray(
+ buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+ sortType.descending, sortType.signed)
+ val res1 = collectToArray(buf1, 0, N * 2)
+ val res2 = collectToArray(buf2, outOffset, N * 2)
+ assert(res1.view == res2.view)
+ }
+
+ fuzzTest(s"fuzz test ${sortType.name} with random bitmasks") { seed =>
+ val rand = new XORShiftRandom(seed)
+ val mask = randomBitMask(rand)
+ val (ref, buffer) = generateTestData(N, rand.nextLong & mask)
+ Arrays.sort(ref, toJavaComparator(sortType.referenceComparator))
+ val outOffset = RadixSort.sort(
+ buffer, N, sortType.startByteIdx, sortType.endByteIdx,
+ sortType.descending, sortType.signed)
+ val result = collectToArray(buffer, outOffset, N)
+ assert(ref.view == result.view)
+ }
+
+ fuzzTest(s"fuzz test key prefix ${sortType.name} with random bitmasks") { seed =>
+ val rand = new XORShiftRandom(seed)
+ val mask = randomBitMask(rand)
+ val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & mask)
+ referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
+ val outOffset = RadixSort.sortKeyPrefixArray(
+ buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+ sortType.descending, sortType.signed)
+ val res1 = collectToArray(buf1, 0, N * 2)
+ val res2 = collectToArray(buf2, outOffset, N * 2)
+ assert(res1.view == res2.view)
+ }
+ }
+
+ ignore("microbenchmarks") {
+ val size = 25000000
+ val rand = new XORShiftRandom(123)
+ val benchmark = new Benchmark("radix sort " + size, size)
+ benchmark.addTimerCase("reference TimSort key prefix array") { timer =>
+ val array = Array.tabulate[Long](size * 2) { i => rand.nextLong }
+ val buf = new LongArray(MemoryBlock.fromLongArray(array))
+ timer.startTiming()
+ referenceKeyPrefixSort(buf, 0, size, PrefixComparators.BINARY)
+ timer.stopTiming()
+ }
+ benchmark.addTimerCase("reference Arrays.sort") { timer =>
+ val ref = Array.tabulate[Long](size) { i => rand.nextLong }
+ timer.startTiming()
+ Arrays.sort(ref)
+ timer.stopTiming()
+ }
+ benchmark.addTimerCase("radix sort one byte") { timer =>
+ val array = new Array[Long](size * 2)
+ var i = 0
+ while (i < size) {
+ array(i) = rand.nextLong & 0xff
+ i += 1
+ }
+ val buf = new LongArray(MemoryBlock.fromLongArray(array))
+ timer.startTiming()
+ RadixSort.sort(buf, size, 0, 7, false, false)
+ timer.stopTiming()
+ }
+ benchmark.addTimerCase("radix sort two bytes") { timer =>
+ val array = new Array[Long](size * 2)
+ var i = 0
+ while (i < size) {
+ array(i) = rand.nextLong & 0xffff
+ i += 1
+ }
+ val buf = new LongArray(MemoryBlock.fromLongArray(array))
+ timer.startTiming()
+ RadixSort.sort(buf, size, 0, 7, false, false)
+ timer.stopTiming()
+ }
+ benchmark.addTimerCase("radix sort eight bytes") { timer =>
+ val array = new Array[Long](size * 2)
+ var i = 0
+ while (i < size) {
+ array(i) = rand.nextLong
+ i += 1
+ }
+ val buf = new LongArray(MemoryBlock.fromLongArray(array))
+ timer.startTiming()
+ RadixSort.sort(buf, size, 0, 7, false, false)
+ timer.stopTiming()
+ }
+ benchmark.addTimerCase("radix sort key prefix array") { timer =>
+ val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong)
+ timer.startTiming()
+ RadixSort.sortKeyPrefixArray(buf2, size, 0, 7, false, false)
+ timer.stopTiming()
+ }
+ benchmark.run
+
+ /**
+ Running benchmark: radix sort 25000000
+ Java HotSpot(TM) 64-Bit Server VM 1.8.0_66-b17 on Linux 3.13.0-44-generic
+ Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz
+
+ radix sort 25000000: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ reference TimSort key prefix array 15546 / 15859 1.6 621.9 1.0X
+ reference Arrays.sort 2416 / 2446 10.3 96.6 6.4X
+ radix sort one byte 133 / 137 188.4 5.3 117.2X
+ radix sort two bytes 255 / 258 98.2 10.2 61.1X
+ radix sort eight bytes 991 / 997 25.2 39.6 15.7X
+ radix sort key prefix array 1540 / 1563 16.2 61.6 10.1X
+ */
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 7784345a7a..8d9906da7e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -59,7 +59,8 @@ public final class UnsafeExternalRowSorter {
Ordering<InternalRow> ordering,
PrefixComparator prefixComparator,
PrefixComputer prefixComputer,
- long pageSizeBytes) throws IOException {
+ long pageSizeBytes,
+ boolean canUseRadixSort) throws IOException {
this.schema = schema;
this.prefixComputer = prefixComputer;
final SparkEnv sparkEnv = SparkEnv.get();
@@ -72,7 +73,8 @@ public final class UnsafeExternalRowSorter {
new RowComparator(ordering, schema.length()),
prefixComparator,
/* initialSize */ 4096,
- pageSizeBytes
+ pageSizeBytes,
+ canUseRadixSort
);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index e0c3b22a3c..42a8be6b1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -84,8 +84,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
case DateType | TimestampType =>
(Long.MinValue, s"(long) $input")
case FloatType | DoubleType =>
- (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
- s"$DoublePrefixCmp.computePrefix((double)$input)")
+ (0L, s"$DoublePrefixCmp.computePrefix((double)$input)")
case StringType => (0L, s"$input.getPrefix()")
case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)")
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index 8132bba04c..b6499e35b5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -85,13 +85,15 @@ public final class UnsafeKVExternalSorter {
recordComparator,
prefixComparator,
/* initialSize */ 4096,
- pageSizeBytes);
+ pageSizeBytes,
+ keySchema.length() == 1 && SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0)));
} else {
// During spilling, the array in map will not be used, so we can borrow that and use it
// as the underline array for in-memory sorter (it's always large enough).
// Since we will not grow the array, it's fine to pass `null` as consumer.
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
- null, taskMemoryManager, recordComparator, prefixComparator, map.getArray());
+ null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(),
+ false /* TODO(ekl) we can only radix sort if the BytesToBytes load factor is <= 0.5 */);
// We cannot use the destructive iterator here because we are reusing the existing memory
// pages in BytesToBytesMap to hold records during sorting.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 80255fafbe..6f9baa2d33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -25,6 +25,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.unsafe.sort.RadixSort;
/**
* Performs (external) sorting.
@@ -48,7 +50,10 @@ case class Sort(
override def requiredChildDistribution: Seq[Distribution] =
if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+ private val enableRadixSort = sqlContext.conf.enableRadixSort
+
override private[sql] lazy val metrics = Map(
+ "sortTime" -> SQLMetrics.createLongMetric(sparkContext, "sort time"),
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
@@ -59,6 +64,9 @@ case class Sort(
val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+ val canUseRadixSort = enableRadixSort && sortOrder.length == 1 &&
+ SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)
+
// The generator for prefix
val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
@@ -69,7 +77,8 @@ case class Sort(
val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val sorter = new UnsafeExternalRowSorter(
- schema, ordering, prefixComparator, prefixComputer, pageSize)
+ schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)
+
if (testSpillFrequency > 0) {
sorter.setTestSpillFrequency(testSpillFrequency)
}
@@ -139,11 +148,15 @@ case class Sort(
val dataSize = metricTerm(ctx, "dataSize")
val spillSize = metricTerm(ctx, "spillSize")
val spillSizeBefore = ctx.freshName("spillSizeBefore")
+ val startTime = ctx.freshName("startTime")
+ val sortTime = metricTerm(ctx, "sortTime")
s"""
| if ($needToSort) {
| $addToSorter();
- | Long $spillSizeBefore = $metrics.memoryBytesSpilled();
+ | long $spillSizeBefore = $metrics.memoryBytesSpilled();
+ | long $startTime = System.nanoTime();
| $sortedIterator = $sorterVariable.sort();
+ | $sortTime.add(System.nanoTime() - $startTime);
| $dataSize.add($sorterVariable.getPeakMemoryUsage());
| $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
| $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 909f124d2c..1a5ff5fcec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -66,6 +66,32 @@ object SortPrefixUtils {
}
/**
+ * Returns whether the specified SortOrder can be satisfied with a radix sort on the prefix.
+ */
+ def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = {
+ sortOrder.dataType match {
+ // TODO(ekl) long-type is problematic because it's null prefix representation collides with
+ // the lowest possible long value. Handle this special case outside radix sort.
+ case LongType if sortOrder.nullable =>
+ false
+ case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType |
+ TimestampType | FloatType | DoubleType =>
+ true
+ case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+ true
+ case _ =>
+ false
+ }
+ }
+
+ /**
+ * Returns whether the fully sorting on the specified key field is possible with radix sort.
+ */
+ def canSortFullyWithPrefix(field: StructField): Boolean = {
+ canSortFullyWithPrefix(SortOrder(BoundReference(0, field.dataType, field.nullable), Ascending))
+ }
+
+ /**
* Creates the prefix computer for the first field in the given schema, in ascending order.
*/
def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
index 85ce388de0..a46d0e0ba7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala
@@ -344,7 +344,8 @@ case class Window(
null,
null,
1024,
- SparkEnv.get.memoryManager.pageSizeBytes)
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ false)
rows.foreach { r =>
sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index edb4c5a16f..b1de52b5f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -46,7 +46,8 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
null,
null,
1024,
- SparkEnv.get.memoryManager.pageSizeBytes)
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ false)
val partition = split.asInstanceOf[CartesianPartition]
for (y <- rdd2.iterator(partition.s2, context)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6e7c1bc133..a4e82d80f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -95,6 +95,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val RADIX_SORT_ENABLED = SQLConfigBuilder("spark.sql.sort.enableRadixSort")
+ .internal()
+ .doc("When true, enable use of radix sort when possible. Radix sort is much faster but " +
+ "requires additional memory to be reserved up-front. The memory overhead may be " +
+ "significant when sorting very small rows (up to 50% more in this case).")
+ .booleanConf
+ .createWithDefault(true)
+
val AUTO_BROADCASTJOIN_THRESHOLD = SQLConfigBuilder("spark.sql.autoBroadcastJoinThreshold")
.doc("Configures the maximum size in bytes for a table that will be broadcast to all worker " +
"nodes when performing a join. By setting this value to -1 broadcasting can be disabled. " +
@@ -584,6 +592,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
+ def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED)
+
def defaultSizeInBytes: Long =
getConf(DEFAULT_SIZE_IN_BYTES, autoBroadcastJoinThreshold + 1L)