aboutsummaryrefslogtreecommitdiff
path: root/core/src
diff options
context:
space:
mode:
authorEric Liang <ekl@databricks.com>2016-06-11 15:42:58 -0700
committerReynold Xin <rxin@databricks.com>2016-06-11 15:42:58 -0700
commitc06c58bbbb2de0c22cfc70c486d23a94c3079ba4 (patch)
tree2d7b99a05f88c5e90ad5b18898447defb53fbb20 /core/src
parent75705e8dbb51ac91ffc7012fa67f072494c13832 (diff)
downloadspark-c06c58bbbb2de0c22cfc70c486d23a94c3079ba4.tar.gz
spark-c06c58bbbb2de0c22cfc70c486d23a94c3079ba4.tar.bz2
spark-c06c58bbbb2de0c22cfc70c486d23a94c3079ba4.zip
[SPARK-14851][CORE] Support radix sort with nullable longs
## What changes were proposed in this pull request? This adds support for radix sort of nullable long fields. When a sort field is null and radix sort is enabled, we keep nulls in a separate region of the sort buffer so that radix sort does not need to deal with them. This also has performance benefits when sorting smaller integer types, since the current representation of nulls in two's complement (Long.MIN_VALUE) otherwise forces a full-width radix sort. This strategy for nulls does mean the sort is no longer stable. cc davies ## How was this patch tested? Existing randomized sort tests for correctness. I also tested some TPCDS queries and there does not seem to be any significant regression for non-null sorts. Some test queries (best of 5 runs each). Before change: scala> val start = System.nanoTime; spark.range(5000000).selectExpr("if(id > 5, cast(hash(id) as long), NULL) as h").coalesce(1).orderBy("h").collect(); (System.nanoTime - start) / 1e6 start: Long = 3190437233227987 res3: Double = 4716.471091 After change: scala> val start = System.nanoTime; spark.range(5000000).selectExpr("if(id > 5, cast(hash(id) as long), NULL) as h").coalesce(1).orderBy("h").collect(); (System.nanoTime - start) / 1e6 start: Long = 3190367870952791 res4: Double = 2981.143045 Author: Eric Liang <ekl@databricks.com> Closes #13161 from ericl/sc-2998.
Diffstat (limited to 'core/src')
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java24
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java11
-rw-r--r--core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java56
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java26
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java2
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala4
6 files changed, 83 insertions, 40 deletions
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
index 4f3f0de7b8..404361734a 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
@@ -170,9 +170,13 @@ public class RadixSort {
/**
* Specialization of sort() for key-prefix arrays. In this type of array, each record consists
* of two longs, only the second of which is sorted on.
+ *
+ * @param startIndex starting index in the array to sort from. This parameter is not supported
+ * in the plain sort() implementation.
*/
public static int sortKeyPrefixArray(
LongArray array,
+ int startIndex,
int numRecords,
int startByteIndex,
int endByteIndex,
@@ -182,10 +186,11 @@ public class RadixSort {
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
assert endByteIndex > startByteIndex;
assert numRecords * 4 <= array.size();
- int inIndex = 0;
- int outIndex = numRecords * 2;
+ int inIndex = startIndex;
+ int outIndex = startIndex + numRecords * 2;
if (numRecords > 0) {
- long[][] counts = getKeyPrefixArrayCounts(array, numRecords, startByteIndex, endByteIndex);
+ long[][] counts = getKeyPrefixArrayCounts(
+ array, startIndex, numRecords, startByteIndex, endByteIndex);
for (int i = startByteIndex; i <= endByteIndex; i++) {
if (counts[i] != null) {
sortKeyPrefixArrayAtByte(
@@ -205,13 +210,14 @@ public class RadixSort {
* getCounts with some added parameters but that seems to hurt in benchmarks.
*/
private static long[][] getKeyPrefixArrayCounts(
- LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+ LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) {
long[][] counts = new long[8][];
long bitwiseMax = 0;
long bitwiseMin = -1L;
- long limit = array.getBaseOffset() + numRecords * 16;
+ long baseOffset = array.getBaseOffset() + startIndex * 8L;
+ long limit = baseOffset + numRecords * 16L;
Object baseObject = array.getBaseObject();
- for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ for (long offset = baseOffset; offset < limit; offset += 16) {
long value = Platform.getLong(baseObject, offset + 8);
bitwiseMax |= value;
bitwiseMin &= value;
@@ -220,7 +226,7 @@ public class RadixSort {
for (int i = startByteIndex; i <= endByteIndex; i++) {
if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
counts[i] = new long[256];
- for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ for (long offset = baseOffset; offset < limit; offset += 16) {
counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
}
}
@@ -238,8 +244,8 @@ public class RadixSort {
long[] offsets = transformCountsToOffsets(
counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
Object baseObject = array.getBaseObject();
- long baseOffset = array.getBaseOffset() + inIndex * 8;
- long maxOffset = baseOffset + numRecords * 16;
+ long baseOffset = array.getBaseOffset() + inIndex * 8L;
+ long maxOffset = baseOffset + numRecords * 16L;
for (long offset = baseOffset; offset < maxOffset; offset += 16) {
long key = Platform.getLong(baseObject, offset);
long prefix = Platform.getLong(baseObject, offset + 8);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e14a23f4a6..ec15f0b59d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -369,7 +369,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
/**
* Write a record to the sorter.
*/
- public void insertRecord(Object recordBase, long recordOffset, int length, long prefix)
+ public void insertRecord(
+ Object recordBase, long recordOffset, int length, long prefix, boolean prefixIsNull)
throws IOException {
growPointerArrayIfNecessary();
@@ -384,7 +385,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
pageCursor += length;
assert(inMemSorter != null);
- inMemSorter.insertRecord(recordAddress, prefix);
+ inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
}
/**
@@ -396,7 +397,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
* record length = key length + value length + 4
*/
public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
- Object valueBase, long valueOffset, int valueLen, long prefix)
+ Object valueBase, long valueOffset, int valueLen, long prefix, boolean prefixIsNull)
throws IOException {
growPointerArrayIfNecessary();
@@ -415,7 +416,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
pageCursor += valueLen;
assert(inMemSorter != null);
- inMemSorter.insertRecord(recordAddress, prefix);
+ inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
}
/**
@@ -465,7 +466,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
private boolean loaded = false;
private int numRecords = 0;
- SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+ SpillableIterator(UnsafeSorterIterator inMemIterator) {
this.upstream = inMemIterator;
this.numRecords = inMemIterator.getNumRecords();
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index c7b070f519..78da389278 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -18,6 +18,7 @@
package org.apache.spark.util.collection.unsafe.sort;
import java.util.Comparator;
+import java.util.LinkedList;
import org.apache.avro.reflect.Nullable;
@@ -93,6 +94,14 @@ public final class UnsafeInMemorySorter {
private int pos = 0;
/**
+ * If sorting with radix sort, specifies the starting position in the sort buffer where records
+ * with non-null prefixes are kept. Positions [0..nullBoundaryPos) will contain null-prefixed
+ * records, and positions [nullBoundaryPos..pos) non-null prefixed records. This lets us avoid
+ * radix sorting over null values.
+ */
+ private int nullBoundaryPos = 0;
+
+ /*
* How many records could be inserted, because part of the array should be left for sorting.
*/
private int usableCapacity = 0;
@@ -160,6 +169,7 @@ public final class UnsafeInMemorySorter {
usableCapacity = getUsableCapacity();
}
pos = 0;
+ nullBoundaryPos = 0;
}
/**
@@ -206,14 +216,27 @@ public final class UnsafeInMemorySorter {
* @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
* @param keyPrefix a user-defined key prefix
*/
- public void insertRecord(long recordPointer, long keyPrefix) {
+ public void insertRecord(long recordPointer, long keyPrefix, boolean prefixIsNull) {
if (!hasSpaceForAnotherRecord()) {
throw new IllegalStateException("There is no space for new record");
}
- array.set(pos, recordPointer);
- pos++;
- array.set(pos, keyPrefix);
- pos++;
+ if (prefixIsNull && radixSortSupport != null) {
+ // Swap forward a non-null record to make room for this one at the beginning of the array.
+ array.set(pos, array.get(nullBoundaryPos));
+ pos++;
+ array.set(pos, array.get(nullBoundaryPos + 1));
+ pos++;
+ // Place this record in the vacated position.
+ array.set(nullBoundaryPos, recordPointer);
+ nullBoundaryPos++;
+ array.set(nullBoundaryPos, keyPrefix);
+ nullBoundaryPos++;
+ } else {
+ array.set(pos, recordPointer);
+ pos++;
+ array.set(pos, keyPrefix);
+ pos++;
+ }
}
public final class SortedIterator extends UnsafeSorterIterator implements Cloneable {
@@ -280,15 +303,14 @@ public final class UnsafeInMemorySorter {
* Return an iterator over record pointers in sorted order. For efficiency, all calls to
* {@code next()} will return the same mutable object.
*/
- public SortedIterator getSortedIterator() {
+ public UnsafeSorterIterator getSortedIterator() {
int offset = 0;
long start = System.nanoTime();
if (sortComparator != null) {
if (this.radixSortSupport != null) {
- // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they
- // force a full-width sort (and we cannot radix-sort nullable long fields at all).
offset = RadixSort.sortKeyPrefixArray(
- array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
+ array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7,
+ radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
} else {
MemoryBlock unused = new MemoryBlock(
array.getBaseObject(),
@@ -301,6 +323,20 @@ public final class UnsafeInMemorySorter {
}
}
totalSortTimeNanos += System.nanoTime() - start;
- return new SortedIterator(pos / 2, offset);
+ if (nullBoundaryPos > 0) {
+ assert radixSortSupport != null : "Nulls are only stored separately with radix sort";
+ LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
+ if (radixSortSupport.sortDescending()) {
+ // Nulls are smaller than non-nulls
+ queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+ queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+ } else {
+ queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+ queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+ }
+ return new UnsafeExternalSorter.ChainedIterator(queue);
+ } else {
+ return new SortedIterator(pos / 2, offset);
+ }
}
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 2cae4beb4c..bce958c3dc 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -156,14 +156,14 @@ public class UnsafeExternalSorterSuite {
private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
final int[] arr = new int[]{ value };
- sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value);
+ sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value, false);
}
private static void insertRecord(
UnsafeExternalSorter sorter,
int[] record,
long prefix) throws IOException {
- sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix);
+ sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix, false);
}
private UnsafeExternalSorter newSorter() throws IOException {
@@ -206,13 +206,13 @@ public class UnsafeExternalSorterSuite {
@Test
public void testSortingEmptyArrays() throws Exception {
final UnsafeExternalSorter sorter = newSorter();
- sorter.insertRecord(null, 0, 0, 0);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
+ sorter.insertRecord(null, 0, 0, 0, false);
sorter.spill();
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
sorter.spill();
- sorter.insertRecord(null, 0, 0, 0);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
+ sorter.insertRecord(null, 0, 0, 0, false);
UnsafeSorterIterator iter = sorter.getSortedIterator();
@@ -232,7 +232,7 @@ public class UnsafeExternalSorterSuite {
long prevSortTime = sorter.getSortTimeNanos();
assertEquals(prevSortTime, 0);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
sorter.spill();
assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
prevSortTime = sorter.getSortTimeNanos();
@@ -240,7 +240,7 @@ public class UnsafeExternalSorterSuite {
sorter.spill(); // no sort needed
assertEquals(sorter.getSortTimeNanos(), prevSortTime);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
UnsafeSorterIterator iter = sorter.getSortedIterator();
assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
}
@@ -280,7 +280,7 @@ public class UnsafeExternalSorterSuite {
final UnsafeExternalSorter sorter = newSorter();
byte[] record = new byte[16];
while (sorter.getNumberOfAllocatedPages() < 2) {
- sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0);
+ sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0, false);
}
sorter.cleanupResources();
assertSpillFilesWereCleanedUp();
@@ -340,7 +340,7 @@ public class UnsafeExternalSorterSuite {
int n = (int) pageSizeBytes / recordSize * 3;
for (int i = 0; i < n; i++) {
record[0] = (long) i;
- sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
}
assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
UnsafeExternalSorter.SpillableIterator iter =
@@ -372,7 +372,7 @@ public class UnsafeExternalSorterSuite {
int n = (int) pageSizeBytes / recordSize * 3;
for (int i = 0; i < n; i++) {
record[0] = (long) i;
- sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
}
assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
UnsafeExternalSorter.SpillableIterator iter =
@@ -406,7 +406,7 @@ public class UnsafeExternalSorterSuite {
int batch = n / 4;
for (int i = 0; i < n; i++) {
record[0] = (long) i;
- sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
if (i % batch == batch - 1) {
sorter.spill();
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 383c5b3b08..bd89085aa9 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -120,7 +120,7 @@ public class UnsafeInMemorySorterSuite {
final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
final String str = getStringFromDataPage(baseObject, position + 4, recordLength);
final int partitionId = hashPartitioner.getPartition(str);
- sorter.insertRecord(address, partitionId);
+ sorter.insertRecord(address, partitionId, false);
position += 4 + recordLength;
}
final UnsafeSorterIterator iter = sorter.getSortedIterator();
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index 1d26d4a830..2c13806410 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -152,7 +152,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & 0xff)
referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
val outOffset = RadixSort.sortKeyPrefixArray(
- buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+ buf2, 0, N, sortType.startByteIdx, sortType.endByteIdx,
sortType.descending, sortType.signed)
val res1 = collectToArray(buf1, 0, N * 2)
val res2 = collectToArray(buf2, outOffset, N * 2)
@@ -177,7 +177,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & mask)
referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
val outOffset = RadixSort.sortKeyPrefixArray(
- buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+ buf2, 0, N, sortType.startByteIdx, sortType.endByteIdx,
sortType.descending, sortType.signed)
val res1 = collectToArray(buf1, 0, N * 2)
val res2 = collectToArray(buf2, outOffset, N * 2)