about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java 24
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java 11
-rw-r--r-- core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java 56
-rw-r--r-- core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java 26
-rw-r--r-- core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java 2
-rw-r--r-- core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala 4
-rw-r--r-- sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java 20
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala 40
-rw-r--r-- sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java 11
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala 12
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala 32
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala 2
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala 11
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala 2
15 files changed, 178 insertions, 79 deletions
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
index 4f3f0de7b8..404361734a 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
@@ -170,9 +170,13 @@ public class RadixSort {
/**
* Specialization of sort() for key-prefix arrays. In this type of array, each record consists
* of two longs, only the second of which is sorted on.
+ *
+ * @param startIndex index of the first long in the array to sort from (a long offset, not a
+ * record count). This parameter is not supported in the plain sort() implementation.
*/
public static int sortKeyPrefixArray(
LongArray array,
+ int startIndex,
int numRecords,
int startByteIndex,
int endByteIndex,
@@ -182,10 +186,11 @@ public class RadixSort {
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
assert endByteIndex > startByteIndex;
assert numRecords * 4 <= array.size();
- int inIndex = 0;
- int outIndex = numRecords * 2;
+ int inIndex = startIndex;
+ int outIndex = startIndex + numRecords * 2;
if (numRecords > 0) {
- long[][] counts = getKeyPrefixArrayCounts(array, numRecords, startByteIndex, endByteIndex);
+ long[][] counts = getKeyPrefixArrayCounts(
+ array, startIndex, numRecords, startByteIndex, endByteIndex);
for (int i = startByteIndex; i <= endByteIndex; i++) {
if (counts[i] != null) {
sortKeyPrefixArrayAtByte(
@@ -205,13 +210,14 @@ public class RadixSort {
* getCounts with some added parameters but that seems to hurt in benchmarks.
*/
private static long[][] getKeyPrefixArrayCounts(
- LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+ LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) {
long[][] counts = new long[8][];
long bitwiseMax = 0;
long bitwiseMin = -1L;
- long limit = array.getBaseOffset() + numRecords * 16;
+ long baseOffset = array.getBaseOffset() + startIndex * 8L;
+ long limit = baseOffset + numRecords * 16L;
Object baseObject = array.getBaseObject();
- for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ for (long offset = baseOffset; offset < limit; offset += 16) {
long value = Platform.getLong(baseObject, offset + 8);
bitwiseMax |= value;
bitwiseMin &= value;
@@ -220,7 +226,7 @@ public class RadixSort {
for (int i = startByteIndex; i <= endByteIndex; i++) {
if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
counts[i] = new long[256];
- for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+ for (long offset = baseOffset; offset < limit; offset += 16) {
counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
}
}
@@ -238,8 +244,8 @@ public class RadixSort {
long[] offsets = transformCountsToOffsets(
counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
Object baseObject = array.getBaseObject();
- long baseOffset = array.getBaseOffset() + inIndex * 8;
- long maxOffset = baseOffset + numRecords * 16;
+ long baseOffset = array.getBaseOffset() + inIndex * 8L;
+ long maxOffset = baseOffset + numRecords * 16L;
for (long offset = baseOffset; offset < maxOffset; offset += 16) {
long key = Platform.getLong(baseObject, offset);
long prefix = Platform.getLong(baseObject, offset + 8);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e14a23f4a6..ec15f0b59d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -369,7 +369,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
/**
* Write a record to the sorter.
*/
- public void insertRecord(Object recordBase, long recordOffset, int length, long prefix)
+ public void insertRecord(
+ Object recordBase, long recordOffset, int length, long prefix, boolean prefixIsNull)
throws IOException {
growPointerArrayIfNecessary();
@@ -384,7 +385,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
pageCursor += length;
assert(inMemSorter != null);
- inMemSorter.insertRecord(recordAddress, prefix);
+ inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
}
/**
@@ -396,7 +397,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
* record length = key length + value length + 4
*/
public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
- Object valueBase, long valueOffset, int valueLen, long prefix)
+ Object valueBase, long valueOffset, int valueLen, long prefix, boolean prefixIsNull)
throws IOException {
growPointerArrayIfNecessary();
@@ -415,7 +416,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
pageCursor += valueLen;
assert(inMemSorter != null);
- inMemSorter.insertRecord(recordAddress, prefix);
+ inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
}
/**
@@ -465,7 +466,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
private boolean loaded = false;
private int numRecords = 0;
- SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+ SpillableIterator(UnsafeSorterIterator inMemIterator) {
this.upstream = inMemIterator;
this.numRecords = inMemIterator.getNumRecords();
}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index c7b070f519..78da389278 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -18,6 +18,7 @@
package org.apache.spark.util.collection.unsafe.sort;
import java.util.Comparator;
+import java.util.LinkedList;
import org.apache.avro.reflect.Nullable;
@@ -93,6 +94,14 @@ public final class UnsafeInMemorySorter {
private int pos = 0;
/**
+ * If sorting with radix sort, specifies the starting position in the sort buffer where records
+ * with non-null prefixes are kept. Positions [0..nullBoundaryPos) will contain null-prefixed
+ * records, and positions [nullBoundaryPos..pos) non-null prefixed records. This lets us avoid
+ * radix sorting over null values.
+ */
+ private int nullBoundaryPos = 0;
+
+ /**
* How many records could be inserted, because part of the array should be left for sorting.
*/
private int usableCapacity = 0;
@@ -160,6 +169,7 @@ public final class UnsafeInMemorySorter {
usableCapacity = getUsableCapacity();
}
pos = 0;
+ nullBoundaryPos = 0;
}
/**
@@ -206,14 +216,27 @@ public final class UnsafeInMemorySorter {
* @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
* @param keyPrefix a user-defined key prefix
*/
- public void insertRecord(long recordPointer, long keyPrefix) {
+ public void insertRecord(long recordPointer, long keyPrefix, boolean prefixIsNull) {
if (!hasSpaceForAnotherRecord()) {
throw new IllegalStateException("There is no space for new record");
}
- array.set(pos, recordPointer);
- pos++;
- array.set(pos, keyPrefix);
- pos++;
+ if (prefixIsNull && radixSortSupport != null) {
+ // Swap forward a non-null record to make room for this one at the beginning of the array.
+ array.set(pos, array.get(nullBoundaryPos));
+ pos++;
+ array.set(pos, array.get(nullBoundaryPos + 1));
+ pos++;
+ // Place this record in the vacated position.
+ array.set(nullBoundaryPos, recordPointer);
+ nullBoundaryPos++;
+ array.set(nullBoundaryPos, keyPrefix);
+ nullBoundaryPos++;
+ } else {
+ array.set(pos, recordPointer);
+ pos++;
+ array.set(pos, keyPrefix);
+ pos++;
+ }
}
public final class SortedIterator extends UnsafeSorterIterator implements Cloneable {
@@ -280,15 +303,14 @@ public final class UnsafeInMemorySorter {
* Return an iterator over record pointers in sorted order. For efficiency, all calls to
* {@code next()} will return the same mutable object.
*/
- public SortedIterator getSortedIterator() {
+ public UnsafeSorterIterator getSortedIterator() {
int offset = 0;
long start = System.nanoTime();
if (sortComparator != null) {
if (this.radixSortSupport != null) {
- // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they
- // force a full-width sort (and we cannot radix-sort nullable long fields at all).
offset = RadixSort.sortKeyPrefixArray(
- array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
+ array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7,
+ radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
} else {
MemoryBlock unused = new MemoryBlock(
array.getBaseObject(),
@@ -301,6 +323,20 @@ public final class UnsafeInMemorySorter {
}
}
totalSortTimeNanos += System.nanoTime() - start;
- return new SortedIterator(pos / 2, offset);
+ if (nullBoundaryPos > 0) {
+ assert radixSortSupport != null : "Nulls are only stored separately with radix sort";
+ LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
+ if (radixSortSupport.sortDescending()) {
+ // Nulls are smaller than non-nulls, so in a descending sort they come last.
+ queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+ queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+ } else {
+ queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+ queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+ }
+ return new UnsafeExternalSorter.ChainedIterator(queue);
+ } else {
+ return new SortedIterator(pos / 2, offset);
+ }
}
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 2cae4beb4c..bce958c3dc 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -156,14 +156,14 @@ public class UnsafeExternalSorterSuite {
private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
final int[] arr = new int[]{ value };
- sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value);
+ sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value, false);
}
private static void insertRecord(
UnsafeExternalSorter sorter,
int[] record,
long prefix) throws IOException {
- sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix);
+ sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix, false);
}
private UnsafeExternalSorter newSorter() throws IOException {
@@ -206,13 +206,13 @@ public class UnsafeExternalSorterSuite {
@Test
public void testSortingEmptyArrays() throws Exception {
final UnsafeExternalSorter sorter = newSorter();
- sorter.insertRecord(null, 0, 0, 0);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
+ sorter.insertRecord(null, 0, 0, 0, false);
sorter.spill();
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
sorter.spill();
- sorter.insertRecord(null, 0, 0, 0);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
+ sorter.insertRecord(null, 0, 0, 0, false);
UnsafeSorterIterator iter = sorter.getSortedIterator();
@@ -232,7 +232,7 @@ public class UnsafeExternalSorterSuite {
long prevSortTime = sorter.getSortTimeNanos();
assertEquals(prevSortTime, 0);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
sorter.spill();
assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
prevSortTime = sorter.getSortTimeNanos();
@@ -240,7 +240,7 @@ public class UnsafeExternalSorterSuite {
sorter.spill(); // no sort needed
assertEquals(sorter.getSortTimeNanos(), prevSortTime);
- sorter.insertRecord(null, 0, 0, 0);
+ sorter.insertRecord(null, 0, 0, 0, false);
UnsafeSorterIterator iter = sorter.getSortedIterator();
assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
}
@@ -280,7 +280,7 @@ public class UnsafeExternalSorterSuite {
final UnsafeExternalSorter sorter = newSorter();
byte[] record = new byte[16];
while (sorter.getNumberOfAllocatedPages() < 2) {
- sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0);
+ sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0, false);
}
sorter.cleanupResources();
assertSpillFilesWereCleanedUp();
@@ -340,7 +340,7 @@ public class UnsafeExternalSorterSuite {
int n = (int) pageSizeBytes / recordSize * 3;
for (int i = 0; i < n; i++) {
record[0] = (long) i;
- sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
}
assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
UnsafeExternalSorter.SpillableIterator iter =
@@ -372,7 +372,7 @@ public class UnsafeExternalSorterSuite {
int n = (int) pageSizeBytes / recordSize * 3;
for (int i = 0; i < n; i++) {
record[0] = (long) i;
- sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
}
assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
UnsafeExternalSorter.SpillableIterator iter =
@@ -406,7 +406,7 @@ public class UnsafeExternalSorterSuite {
int batch = n / 4;
for (int i = 0; i < n; i++) {
record[0] = (long) i;
- sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+ sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
if (i % batch == batch - 1) {
sorter.spill();
}
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 383c5b3b08..bd89085aa9 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -120,7 +120,7 @@ public class UnsafeInMemorySorterSuite {
final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
final String str = getStringFromDataPage(baseObject, position + 4, recordLength);
final int partitionId = hashPartitioner.getPartition(str);
- sorter.insertRecord(address, partitionId);
+ sorter.insertRecord(address, partitionId, false);
position += 4 + recordLength;
}
final UnsafeSorterIterator iter = sorter.getSortedIterator();
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index 1d26d4a830..2c13806410 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -152,7 +152,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & 0xff)
referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
val outOffset = RadixSort.sortKeyPrefixArray(
- buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+ buf2, 0, N, sortType.startByteIdx, sortType.endByteIdx,
sortType.descending, sortType.signed)
val res1 = collectToArray(buf1, 0, N * 2)
val res2 = collectToArray(buf2, outOffset, N * 2)
@@ -177,7 +177,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & mask)
referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
val outOffset = RadixSort.sortKeyPrefixArray(
- buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+ buf2, 0, N, sortType.startByteIdx, sortType.endByteIdx,
sortType.descending, sortType.signed)
val res1 = collectToArray(buf1, 0, N * 2)
val res2 = collectToArray(buf2, outOffset, N * 2)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 37fbad47c1..ad76bf5a0a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -51,7 +51,20 @@ public final class UnsafeExternalRowSorter {
private final UnsafeExternalSorter sorter;
public abstract static class PrefixComputer {
- abstract long computePrefix(InternalRow row);
+
+ public static class Prefix {
+ /** Key prefix value, or the null prefix value if isNull = true. */
+ long value;
+
+ /** Whether the key is null. */
+ boolean isNull;
+ }
+
+ /**
+ * Computes prefix for the given row. For efficiency, the returned object may be reused in
+ * further calls to a given PrefixComputer.
+ */
+ abstract Prefix computePrefix(InternalRow row);
}
public UnsafeExternalRowSorter(
@@ -88,12 +101,13 @@ public final class UnsafeExternalRowSorter {
}
public void insertRow(UnsafeRow row) throws IOException {
- final long prefix = prefixComputer.computePrefix(row);
+ final PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
row.getBaseObject(),
row.getBaseOffset(),
row.getSizeInBytes(),
- prefix
+ prefix.value,
+ prefix.isNull
);
numRowsInserted++;
if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 42a8be6b1b..de779ed370 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -64,10 +64,21 @@ case class SortOrder(child: Expression, direction: SortDirection)
}
/**
- * An expression to generate a 64-bit long prefix used in sorting.
+ * An expression to generate a 64-bit long prefix used in sorting. If the sort must operate over
+ * null keys as well, this.nullValue can be used in place of emitted null prefixes in the sort.
*/
case class SortPrefix(child: SortOrder) extends UnaryExpression {
+ val nullValue = child.child.dataType match {
+ case BooleanType | DateType | TimestampType | _: IntegralType =>
+ Long.MinValue
+ case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+ Long.MinValue
+ case _: DecimalType =>
+ DoublePrefixComparator.computePrefix(Double.NegativeInfinity)
+ case _ => 0L
+ }
+
override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -75,20 +86,19 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
val input = childCode.value
val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName
val DoublePrefixCmp = classOf[DoublePrefixComparator].getName
-
- val (nullValue: Long, prefixCode: String) = child.child.dataType match {
+ val prefixCode = child.child.dataType match {
case BooleanType =>
- (Long.MinValue, s"$input ? 1L : 0L")
+ s"$input ? 1L : 0L"
case _: IntegralType =>
- (Long.MinValue, s"(long) $input")
+ s"(long) $input"
case DateType | TimestampType =>
- (Long.MinValue, s"(long) $input")
+ s"(long) $input"
case FloatType | DoubleType =>
- (0L, s"$DoublePrefixCmp.computePrefix((double)$input)")
- case StringType => (0L, s"$input.getPrefix()")
- case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)")
+ s"$DoublePrefixCmp.computePrefix((double)$input)"
+ case StringType => s"$input.getPrefix()"
+ case BinaryType => s"$BinaryPrefixCmp.computePrefix($input)"
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
- val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
+ if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
s"$input.toUnscaledLong()"
} else {
// reduce the scale to fit in a long
@@ -96,17 +106,15 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
val s = p - (dt.precision - dt.scale)
s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
}
- (Long.MinValue, prefix)
case dt: DecimalType =>
- (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
- s"$DoublePrefixCmp.computePrefix($input.toDouble())")
- case _ => (0L, "0L")
+ s"$DoublePrefixCmp.computePrefix($input.toDouble())"
+ case _ => "0L"
}
ev.copy(code = childCode.code +
s"""
- |long ${ev.value} = ${nullValue}L;
- |boolean ${ev.isNull} = false;
+ |long ${ev.value} = 0L;
+ |boolean ${ev.isNull} = ${childCode.isNull};
|if (!${childCode.isNull}) {
| ${ev.value} = $prefixCode;
|}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index bb823cd07b..99fe51db68 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -118,9 +118,10 @@ public final class UnsafeKVExternalSorter {
// Compute prefix
row.pointTo(baseObject, baseOffset, loc.getKeyLength());
- final long prefix = prefixComputer.computePrefix(row);
+ final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix =
+ prefixComputer.computePrefix(row);
- inMemSorter.insertRecord(address, prefix);
+ inMemSorter.insertRecord(address, prefix.value, prefix.isNull);
}
sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
@@ -146,10 +147,12 @@ public final class UnsafeKVExternalSorter {
* sorted runs, and then reallocates memory to hold the new record.
*/
public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException {
- final long prefix = prefixComputer.computePrefix(key);
+ final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix =
+ prefixComputer.computePrefix(key);
sorter.insertKVRecord(
key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
- value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
+ value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(),
+ prefix.value, prefix.isNull);
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 66a16ac576..6db7f45cfd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -68,10 +68,16 @@ case class SortExec(
SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)
// The generator for prefix
- val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+ val prefixExpr = SortPrefix(boundSortExpression)
+ val prefixProjection = UnsafeProjection.create(Seq(prefixExpr))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = {
- prefixProjection.apply(row).getLong(0)
+ private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+ override def computePrefix(row: InternalRow):
+ UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+ val prefix = prefixProjection.apply(row)
+ result.isNull = prefix.isNullAt(0)
+ result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0)
+ result
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 1a5ff5fcec..940467e74d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -33,6 +33,11 @@ object SortPrefixUtils {
override def compare(prefix1: Long, prefix2: Long): Int = 0
}
+ /**
+ * Dummy sort prefix result to use for empty rows.
+ */
+ private val emptyPrefix = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+
def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
sortOrder.dataType match {
case StringType =>
@@ -70,10 +75,6 @@ object SortPrefixUtils {
*/
def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = {
sortOrder.dataType match {
- // TODO(ekl) long-type is problematic because it's null prefix representation collides with
- // the lowest possible long value. Handle this special case outside radix sort.
- case LongType if sortOrder.nullable =>
- false
case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType |
TimestampType | FloatType | DoubleType =>
true
@@ -97,16 +98,29 @@ object SortPrefixUtils {
def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
if (schema.nonEmpty) {
val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
- val prefixProjection = UnsafeProjection.create(
- SortPrefix(SortOrder(boundReference, Ascending)))
+ val prefixExpr = SortPrefix(SortOrder(boundReference, Ascending))
+ val prefixProjection = UnsafeProjection.create(prefixExpr)
new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = {
- prefixProjection.apply(row).getLong(0)
+ private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+ override def computePrefix(row: InternalRow):
+ UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+ val prefix = prefixProjection.apply(row)
+ if (prefix.isNullAt(0)) {
+ result.isNull = true
+ result.value = prefixExpr.nullValue
+ } else {
+ result.isNull = false
+ result.value = prefix.getLong(0)
+ }
+ result
}
}
} else {
new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = 0
+ override def computePrefix(row: InternalRow):
+ UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+ emptyPrefix
+ }
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
index 97bbab65af..1b9634cfc0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
@@ -347,13 +347,13 @@ case class WindowExec(
SparkEnv.get.memoryManager.pageSizeBytes,
false)
rows.foreach { r =>
- sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0)
+ sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false)
}
rows.clear()
}
} else {
sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
- nextRow.getSizeInBytes, 0)
+ nextRow.getSizeInBytes, 0, false)
}
fetchNextRow()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 88f78a7a73..d870d91edc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -53,7 +53,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
val partition = split.asInstanceOf[CartesianPartition]
for (y <- rdd2.iterator(partition.s2, context)) {
- sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
+ sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false)
}
// Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index c3acf29c2d..ba3fa3732d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -54,6 +54,17 @@ class SortSuite extends SparkPlanTest with SharedSQLContext {
sortAnswers = false)
}
+ test("sorting all nulls") {
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF().selectExpr("NULL as a"),
+ (child: SparkPlan) =>
+ GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)),
+ (child: SparkPlan) =>
+ GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)),
+ sortAnswers = false
+ )
+ }
+
test("sort followed by limit") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
index 9964b7373f..50ae26a3ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
@@ -110,7 +110,7 @@ class SortBenchmark extends BenchmarkBase {
benchmark.addTimerCase("radix sort key prefix array") { timer =>
val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong)
timer.startTiming()
- RadixSort.sortKeyPrefixArray(buf2, size, 0, 7, false, false)
+ RadixSort.sortKeyPrefixArray(buf2, 0, size, 0, 7, false, false)
timer.stopTiming()
}
benchmark.run()