15 files changed, 178 insertions, 79 deletions
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
index 4f3f0de7b8..404361734a 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java
@@ -170,9 +170,13 @@ public class RadixSort {
   /**
    * Specialization of sort() for key-prefix arrays. In this type of array, each record consists
    * of two longs, only the second of which is sorted on.
+   *
+   * @param startIndex starting index in the array to sort from. This parameter is not supported
+   * in the plain sort() implementation.
    */
   public static int sortKeyPrefixArray(
       LongArray array,
+      int startIndex,
       int numRecords,
       int startByteIndex,
       int endByteIndex,
@@ -182,10 +186,11 @@ public class RadixSort {
     assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
     assert endByteIndex > startByteIndex;
     assert numRecords * 4 <= array.size();
-    int inIndex = 0;
-    int outIndex = numRecords * 2;
+    int inIndex = startIndex;
+    int outIndex = startIndex + numRecords * 2;
     if (numRecords > 0) {
-      long[][] counts = getKeyPrefixArrayCounts(array, numRecords, startByteIndex, endByteIndex);
+      long[][] counts = getKeyPrefixArrayCounts(
+        array, startIndex, numRecords, startByteIndex, endByteIndex);
       for (int i = startByteIndex; i <= endByteIndex; i++) {
         if (counts[i] != null) {
           sortKeyPrefixArrayAtByte(
@@ -205,13 +210,14 @@ public class RadixSort {
    * getCounts with some added parameters but that seems to hurt in benchmarks.
    */
   private static long[][] getKeyPrefixArrayCounts(
-      LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
+      LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) {
     long[][] counts = new long[8][];
     long bitwiseMax = 0;
     long bitwiseMin = -1L;
-    long limit = array.getBaseOffset() + numRecords * 16;
+    long baseOffset = array.getBaseOffset() + startIndex * 8L;
+    long limit = baseOffset + numRecords * 16L;
     Object baseObject = array.getBaseObject();
-    for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+    for (long offset = baseOffset; offset < limit; offset += 16) {
       long value = Platform.getLong(baseObject, offset + 8);
       bitwiseMax |= value;
       bitwiseMin &= value;
@@ -220,7 +226,7 @@
     for (int i = startByteIndex; i <= endByteIndex; i++) {
       if (((bitsChanged >>> (i * 8)) & 0xff) != 0) {
         counts[i] = new long[256];
-        for (long offset = array.getBaseOffset(); offset < limit; offset += 16) {
+        for (long offset = baseOffset; offset < limit; offset += 16) {
           counts[i][(int)((Platform.getLong(baseObject, offset + 8) >>> (i * 8)) & 0xff)]++;
         }
       }
@@ -238,8 +244,8 @@
     long[] offsets = transformCountsToOffsets(
       counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
     Object baseObject = array.getBaseObject();
-    long baseOffset = array.getBaseOffset() + inIndex * 8;
-    long maxOffset = baseOffset + numRecords * 16;
+    long baseOffset = array.getBaseOffset() + inIndex * 8L;
+    long maxOffset = baseOffset + numRecords * 16L;
     for (long offset = baseOffset; offset < maxOffset; offset += 16) {
       long key = Platform.getLong(baseObject, offset);
       long prefix = Platform.getLong(baseObject, offset + 8);
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e14a23f4a6..ec15f0b59d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -369,7 +369,8 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
   /**
    * Write a record to the sorter.
    */
-  public void insertRecord(Object recordBase, long recordOffset, int length, long prefix)
+  public void insertRecord(
+      Object recordBase, long recordOffset, int length, long prefix, boolean prefixIsNull)
     throws IOException {
 
     growPointerArrayIfNecessary();
@@ -384,7 +385,7 @@
     Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
     pageCursor += length;
     assert(inMemSorter != null);
-    inMemSorter.insertRecord(recordAddress, prefix);
+    inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
   }
 
   /**
@@ -396,7 +397,7 @@
    * record length = key length + value length + 4
    */
   public void insertKVRecord(Object keyBase, long keyOffset, int keyLen,
-      Object valueBase, long valueOffset, int valueLen, long prefix)
+      Object valueBase, long valueOffset, int valueLen, long prefix, boolean prefixIsNull)
     throws IOException {
 
     growPointerArrayIfNecessary();
@@ -415,7 +416,7 @@
     pageCursor += valueLen;
 
     assert(inMemSorter != null);
-    inMemSorter.insertRecord(recordAddress, prefix);
+    inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull);
   }
 
   /**
@@ -465,7 +466,7 @@
     private boolean loaded = false;
     private int numRecords = 0;
 
-    SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) {
+    SpillableIterator(UnsafeSorterIterator inMemIterator) {
       this.upstream = inMemIterator;
       this.numRecords = inMemIterator.getNumRecords();
     }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index c7b070f519..78da389278 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -18,6 +18,7 @@
 package org.apache.spark.util.collection.unsafe.sort;
 
 import java.util.Comparator;
+import java.util.LinkedList;
 
 import org.apache.avro.reflect.Nullable;
 
@@ -93,6 +94,14 @@
   private int pos = 0;
 
   /**
+   * If sorting with radix sort, specifies the starting position in the sort buffer where records
+   * with non-null prefixes are kept. Positions [0..nullBoundaryPos) will contain null-prefixed
+   * records, and positions [nullBoundaryPos..pos) non-null prefixed records. This lets us avoid
+   * radix sorting over null values.
+   */
+  private int nullBoundaryPos = 0;
+
+  /*
    * How many records could be inserted, because part of the array should be left for sorting.
    */
   private int usableCapacity = 0;
@@ -160,6 +169,7 @@ public final class UnsafeInMemorySorter {
       usableCapacity = getUsableCapacity();
     }
     pos = 0;
+    nullBoundaryPos = 0;
   }
 
   /**
@@ -206,14 +216,27 @@
    * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
    * @param keyPrefix a user-defined key prefix
    */
-  public void insertRecord(long recordPointer, long keyPrefix) {
+  public void insertRecord(long recordPointer, long keyPrefix, boolean prefixIsNull) {
     if (!hasSpaceForAnotherRecord()) {
       throw new IllegalStateException("There is no space for new record");
     }
-    array.set(pos, recordPointer);
-    pos++;
-    array.set(pos, keyPrefix);
-    pos++;
+    if (prefixIsNull && radixSortSupport != null) {
+      // Swap forward a non-null record to make room for this one at the beginning of the array.
+      array.set(pos, array.get(nullBoundaryPos));
+      pos++;
+      array.set(pos, array.get(nullBoundaryPos + 1));
+      pos++;
+      // Place this record in the vacated position.
+      array.set(nullBoundaryPos, recordPointer);
+      nullBoundaryPos++;
+      array.set(nullBoundaryPos, keyPrefix);
+      nullBoundaryPos++;
+    } else {
+      array.set(pos, recordPointer);
+      pos++;
+      array.set(pos, keyPrefix);
+      pos++;
+    }
   }
 
   public final class SortedIterator extends UnsafeSorterIterator implements Cloneable {
@@ -280,15 +303,14 @@
    * Return an iterator over record pointers in sorted order. For efficiency, all calls to
    * {@code next()} will return the same mutable object.
    */
-  public SortedIterator getSortedIterator() {
+  public UnsafeSorterIterator getSortedIterator() {
     int offset = 0;
     long start = System.nanoTime();
     if (sortComparator != null) {
       if (this.radixSortSupport != null) {
-        // TODO(ekl) we should handle NULL values before radix sort for efficiency, since they
-        // force a full-width sort (and we cannot radix-sort nullable long fields at all).
         offset = RadixSort.sortKeyPrefixArray(
-          array, pos / 2, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
+          array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7,
+          radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
       } else {
         MemoryBlock unused = new MemoryBlock(
           array.getBaseObject(),
@@ -301,6 +323,20 @@
       }
     }
     totalSortTimeNanos += System.nanoTime() - start;
-    return new SortedIterator(pos / 2, offset);
+    if (nullBoundaryPos > 0) {
+      assert radixSortSupport != null : "Nulls are only stored separately with radix sort";
+      LinkedList<UnsafeSorterIterator> queue = new LinkedList<>();
+      if (radixSortSupport.sortDescending()) {
+        // Nulls are smaller than non-nulls
+        queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+        queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+      } else {
+        queue.add(new SortedIterator(nullBoundaryPos / 2, 0));
+        queue.add(new SortedIterator((pos - nullBoundaryPos) / 2, offset));
+      }
+      return new UnsafeExternalSorter.ChainedIterator(queue);
+    } else {
+      return new SortedIterator(pos / 2, offset);
+    }
   }
 }
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 2cae4beb4c..bce958c3dc 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -156,14 +156,14 @@ public class UnsafeExternalSorterSuite {
 
   private static void insertNumber(UnsafeExternalSorter sorter, int value) throws Exception {
     final int[] arr = new int[]{ value };
-    sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value);
+    sorter.insertRecord(arr, Platform.INT_ARRAY_OFFSET, 4, value, false);
   }
 
   private static void insertRecord(
       UnsafeExternalSorter sorter,
       int[] record,
       long prefix) throws IOException {
-    sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix);
+    sorter.insertRecord(record, Platform.INT_ARRAY_OFFSET, record.length * 4, prefix, false);
   }
 
   private UnsafeExternalSorter newSorter() throws IOException {
@@ -206,13 +206,13 @@
   @Test
   public void testSortingEmptyArrays() throws Exception {
     final UnsafeExternalSorter sorter = newSorter();
-    sorter.insertRecord(null, 0, 0, 0);
-    sorter.insertRecord(null, 0, 0, 0);
+    sorter.insertRecord(null, 0, 0, 0, false);
+    sorter.insertRecord(null, 0, 0, 0, false);
     sorter.spill();
-    sorter.insertRecord(null, 0, 0, 0);
+    sorter.insertRecord(null, 0, 0, 0, false);
     sorter.spill();
-    sorter.insertRecord(null, 0, 0, 0);
-    sorter.insertRecord(null, 0, 0, 0);
+    sorter.insertRecord(null, 0, 0, 0, false);
+    sorter.insertRecord(null, 0, 0, 0, false);
 
     UnsafeSorterIterator iter = sorter.getSortedIterator();
 
@@ -232,7 +232,7 @@
     long prevSortTime = sorter.getSortTimeNanos();
     assertEquals(prevSortTime, 0);
 
-    sorter.insertRecord(null, 0, 0, 0);
+    sorter.insertRecord(null, 0, 0, 0, false);
     sorter.spill();
     assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
     prevSortTime = sorter.getSortTimeNanos();
@@ -240,7 +240,7 @@
     sorter.spill(); // no sort needed
     assertEquals(sorter.getSortTimeNanos(), prevSortTime);
 
-    sorter.insertRecord(null, 0, 0, 0);
+    sorter.insertRecord(null, 0, 0, 0, false);
     UnsafeSorterIterator iter = sorter.getSortedIterator();
     assertThat(sorter.getSortTimeNanos(), greaterThan(prevSortTime));
   }
@@ -280,7 +280,7 @@
     final UnsafeExternalSorter sorter = newSorter();
     byte[] record = new byte[16];
     while (sorter.getNumberOfAllocatedPages() < 2) {
-      sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0);
+      sorter.insertRecord(record, Platform.BYTE_ARRAY_OFFSET, record.length, 0, false);
     }
     sorter.cleanupResources();
     assertSpillFilesWereCleanedUp();
@@ -340,7 +340,7 @@
     int n = (int) pageSizeBytes / recordSize * 3;
     for (int i = 0; i < n; i++) {
       record[0] = (long) i;
-      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
     }
     assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
     UnsafeExternalSorter.SpillableIterator iter =
@@ -372,7 +372,7 @@
     int n = (int) pageSizeBytes / recordSize * 3;
     for (int i = 0; i < n; i++) {
       record[0] = (long) i;
-      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
     }
     assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
     UnsafeExternalSorter.SpillableIterator iter =
@@ -406,7 +406,7 @@
     int batch = n / 4;
     for (int i = 0; i < n; i++) {
       record[0] = (long) i;
-      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0);
+      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, false);
       if (i % batch == batch - 1) {
         sorter.spill();
       }
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
index 383c5b3b08..bd89085aa9 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java
@@ -120,7 +120,7 @@ public class UnsafeInMemorySorterSuite {
       final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
       final String str = getStringFromDataPage(baseObject, position + 4, recordLength);
       final int partitionId = hashPartitioner.getPartition(str);
-      sorter.insertRecord(address, partitionId);
+      sorter.insertRecord(address, partitionId, false);
       position += 4 + recordLength;
     }
     final UnsafeSorterIterator iter = sorter.getSortedIterator();
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
index 1d26d4a830..2c13806410 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala
@@ -152,7 +152,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
       val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & 0xff)
       referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
       val outOffset = RadixSort.sortKeyPrefixArray(
-        buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+        buf2, 0, N, sortType.startByteIdx, sortType.endByteIdx,
         sortType.descending, sortType.signed)
       val res1 = collectToArray(buf1, 0, N * 2)
       val res2 = collectToArray(buf2, outOffset, N * 2)
@@ -177,7 +177,7 @@ class RadixSortSuite extends SparkFunSuite with Logging {
       val (buf1, buf2) = generateKeyPrefixTestData(N, rand.nextLong & mask)
       referenceKeyPrefixSort(buf1, 0, N, sortType.referenceComparator)
       val outOffset = RadixSort.sortKeyPrefixArray(
-        buf2, N, sortType.startByteIdx, sortType.endByteIdx,
+        buf2, 0, N, sortType.startByteIdx, sortType.endByteIdx,
         sortType.descending, sortType.signed)
       val res1 = collectToArray(buf1, 0, N * 2)
       val res2 = collectToArray(buf2, outOffset, N * 2)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 37fbad47c1..ad76bf5a0a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -51,7 +51,20 @@ public final class UnsafeExternalRowSorter {
   private final UnsafeExternalSorter sorter;
 
   public abstract static class PrefixComputer {
-    abstract long computePrefix(InternalRow row);
+
+    public static class Prefix {
+      /** Key prefix value, or the null prefix value if isNull = true. **/
+      long value;
+
+      /** Whether the key is null. */
+      boolean isNull;
+    }
+
+    /**
+     * Computes prefix for the given row. For efficiency, the returned object may be reused in
+     * further calls to a given PrefixComputer.
+     */
+    abstract Prefix computePrefix(InternalRow row);
   }
 
   public UnsafeExternalRowSorter(
@@ -88,12 +101,13 @@
   }
 
   public void insertRow(UnsafeRow row) throws IOException {
-    final long prefix = prefixComputer.computePrefix(row);
+    final PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row);
     sorter.insertRecord(
       row.getBaseObject(),
       row.getBaseOffset(),
       row.getSizeInBytes(),
-      prefix
+      prefix.value,
+      prefix.isNull
     );
     numRowsInserted++;
     if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 42a8be6b1b..de779ed370 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -64,10 +64,21 @@ case class SortOrder(child: Expression, direction: SortDirection)
 }
 
 /**
- * An expression to generate a 64-bit long prefix used in sorting.
+ * An expression to generate a 64-bit long prefix used in sorting. If the sort must operate over
+ * null keys as well, this.nullValue can be used in place of emitted null prefixes in the sort.
  */
 case class SortPrefix(child: SortOrder) extends UnaryExpression {
 
+  val nullValue = child.child.dataType match {
+    case BooleanType | DateType | TimestampType | _: IntegralType =>
+      Long.MinValue
+    case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+      Long.MinValue
+    case _: DecimalType =>
+      DoublePrefixComparator.computePrefix(Double.NegativeInfinity)
+    case _ => 0L
+  }
+
   override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -75,20 +86,19 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
     val input = childCode.value
     val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName
     val DoublePrefixCmp = classOf[DoublePrefixComparator].getName
-
-    val (nullValue: Long, prefixCode: String) = child.child.dataType match {
+    val prefixCode = child.child.dataType match {
       case BooleanType =>
-        (Long.MinValue, s"$input ? 1L : 0L")
+        s"$input ? 1L : 0L"
       case _: IntegralType =>
-        (Long.MinValue, s"(long) $input")
+        s"(long) $input"
      case DateType | TimestampType =>
-        (Long.MinValue, s"(long) $input")
+        s"(long) $input"
       case FloatType | DoubleType =>
-        (0L, s"$DoublePrefixCmp.computePrefix((double)$input)")
-      case StringType => (0L, s"$input.getPrefix()")
-      case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)")
+        s"$DoublePrefixCmp.computePrefix((double)$input)"
+      case StringType => s"$input.getPrefix()"
+      case BinaryType => s"$BinaryPrefixCmp.computePrefix($input)"
       case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
-        val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
+        if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
           s"$input.toUnscaledLong()"
         } else {
           // reduce the scale to fit in a long
@@ -96,17 +106,15 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
           val s = p - (dt.precision - dt.scale)
           s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
         }
-        (Long.MinValue, prefix)
       case dt: DecimalType =>
-        (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
-          s"$DoublePrefixCmp.computePrefix($input.toDouble())")
-      case _ => (0L, "0L")
+        s"$DoublePrefixCmp.computePrefix($input.toDouble())"
+      case _ => "0L"
     }
 
     ev.copy(code = childCode.code +
       s"""
-         |long ${ev.value} = ${nullValue}L;
-         |boolean ${ev.isNull} = false;
+         |long ${ev.value} = 0L;
+         |boolean ${ev.isNull} = ${childCode.isNull};
          |if (!${childCode.isNull}) {
          |  ${ev.value} = $prefixCode;
         |}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index bb823cd07b..99fe51db68 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -118,9 +118,10 @@
 
         // Compute prefix
         row.pointTo(baseObject, baseOffset, loc.getKeyLength());
-        final long prefix = prefixComputer.computePrefix(row);
+        final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix =
+          prefixComputer.computePrefix(row);
 
-        inMemSorter.insertRecord(address, prefix);
+        inMemSorter.insertRecord(address, prefix.value, prefix.isNull);
       }
 
       sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
@@ -146,10 +147,12 @@
    * sorted runs, and then reallocates memory to hold the new record.
    */
   public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException {
-    final long prefix = prefixComputer.computePrefix(key);
+    final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix =
+      prefixComputer.computePrefix(key);
     sorter.insertKVRecord(
       key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
-      value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
+      value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(),
+      prefix.value, prefix.isNull);
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 66a16ac576..6db7f45cfd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -68,10 +68,16 @@ case class SortExec(
       SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)
 
     // The generator for prefix
-    val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+    val prefixExpr = SortPrefix(boundSortExpression)
+    val prefixProjection = UnsafeProjection.create(Seq(prefixExpr))
     val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
-      override def computePrefix(row: InternalRow): Long = {
-        prefixProjection.apply(row).getLong(0)
+      private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+      override def computePrefix(row: InternalRow):
+          UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+        val prefix = prefixProjection.apply(row)
+        result.isNull = prefix.isNullAt(0)
+        result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0)
+        result
       }
     }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 1a5ff5fcec..940467e74d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -33,6 +33,11 @@ object SortPrefixUtils {
     override def compare(prefix1: Long, prefix2: Long): Int = 0
   }
 
+  /**
+   * Dummy sort prefix result to use for empty rows.
+   */
+  private val emptyPrefix = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+
   def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
     sortOrder.dataType match {
       case StringType =>
@@ -70,10 +75,6 @@
    */
  def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = {
     sortOrder.dataType match {
-      // TODO(ekl) long-type is problematic because it's null prefix representation collides with
-      // the lowest possible long value. Handle this special case outside radix sort.
-      case LongType if sortOrder.nullable =>
-        false
       case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType |
           TimestampType | FloatType | DoubleType =>
         true
@@ -97,16 +98,29 @@
   def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
     if (schema.nonEmpty) {
       val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
-      val prefixProjection = UnsafeProjection.create(
-        SortPrefix(SortOrder(boundReference, Ascending)))
+      val prefixExpr = SortPrefix(SortOrder(boundReference, Ascending))
+      val prefixProjection = UnsafeProjection.create(prefixExpr)
       new UnsafeExternalRowSorter.PrefixComputer {
-        override def computePrefix(row: InternalRow): Long = {
-          prefixProjection.apply(row).getLong(0)
+        private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+        override def computePrefix(row: InternalRow):
+            UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+          val prefix = prefixProjection.apply(row)
+          if (prefix.isNullAt(0)) {
+            result.isNull = true
+            result.value = prefixExpr.nullValue
+          } else {
+            result.isNull = false
+            result.value = prefix.getLong(0)
+          }
+          result
         }
       }
     } else {
       new UnsafeExternalRowSorter.PrefixComputer {
-        override def computePrefix(row: InternalRow): Long = 0
+        override def computePrefix(row: InternalRow):
+            UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+          emptyPrefix
+        }
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
index 97bbab65af..1b9634cfc0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
@@ -347,13 +347,13 @@ case class WindowExec(
               SparkEnv.get.memoryManager.pageSizeBytes,
               false)
             rows.foreach { r =>
-              sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0)
+              sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false)
             }
             rows.clear()
           }
         } else {
           sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
-            nextRow.getSizeInBytes, 0)
+            nextRow.getSizeInBytes, 0, false)
         }
         fetchNextRow()
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 88f78a7a73..d870d91edc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -53,7 +53,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
 
     val partition = split.asInstanceOf[CartesianPartition]
     for (y <- rdd2.iterator(partition.s2, context)) {
-      sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
+      sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false)
     }
 
     // Create an iterator from sorter and wrapper it as Iterator[UnsafeRow]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index c3acf29c2d..ba3fa3732d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -54,6 +54,17 @@ class SortSuite extends SparkPlanTest with SharedSQLContext {
       sortAnswers = false)
   }
 
+  test("sorting all nulls") {
+    checkThatPlansAgree(
+      (1 to 100).map(v => Tuple1(v)).toDF().selectExpr("NULL as a"),
+      (child: SparkPlan) =>
+        GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)),
+      (child: SparkPlan) =>
+        GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)),
+      sortAnswers = false
+    )
+  }
+
   test("sort followed by limit") {
     checkThatPlansAgree(
       (1 to 100).map(v => Tuple1(v)).toDF("a"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
index 9964b7373f..50ae26a3ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
@@ -110,7 +110,7 @@ class SortBenchmark extends BenchmarkBase {
     benchmark.addTimerCase("radix sort key prefix array") { timer =>
       val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong)
       timer.startTiming()
-      RadixSort.sortKeyPrefixArray(buf2, size, 0, 7, false, false)
+      RadixSort.sortKeyPrefixArray(buf2, 0, size, 0, 7, false, false)
       timer.stopTiming()
     }
     benchmark.run()
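
The heart of the change is UnsafeInMemorySorter.insertRecord: when radix sort is in use, records whose key prefix is null are packed into a contiguous run at the front of the pointer/prefix array by swapping whatever currently sits at nullBoundaryPos up to the end, so the radix sort never has to look at them. Below is a minimal, self-contained Java sketch of that partitioning over a plain long[]; the class and method names are invented for illustration, and it omits the radixSortSupport guard and the capacity checks of the real code.

    // Illustrative sketch (not Spark code): records are (pointer, prefix) pairs stored
    // side by side in a long[]. Null-prefixed records live in [0, nullBoundary) and all
    // other records in [nullBoundary, pos), mirroring UnsafeInMemorySorter.insertRecord.
    final class NullPartitioningBuffer {
      private final long[] array;
      private int pos = 0;            // next free slot, counted in longs
      private int nullBoundary = 0;   // end of the null-prefixed region, counted in longs

      NullPartitioningBuffer(int capacityRecords) {
        this.array = new long[capacityRecords * 2];
      }

      void insert(long recordPointer, long keyPrefix, boolean prefixIsNull) {
        if (prefixIsNull) {
          // Swap the record at the boundary to the end to vacate a slot at the front.
          array[pos++] = array[nullBoundary];
          array[pos++] = array[nullBoundary + 1];
          array[nullBoundary++] = recordPointer;
          array[nullBoundary++] = keyPrefix;
        } else {
          array[pos++] = recordPointer;
          array[pos++] = keyPrefix;
        }
      }

      int nullRecords() { return nullBoundary / 2; }
      int nonNullRecords() { return (pos - nullBoundary) / 2; }
    }

Because each insert moves at most one existing record, the partitioning stays O(1) per record and never needs a separate compaction pass.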
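
getSortedIterator then radix-sorts only the non-null slice [nullBoundaryPos, pos) and stitches the two runs back together through UnsafeExternalSorter.ChainedIterator: a null prefix is treated as smaller than any non-null prefix, so the null run is emitted first for ascending sorts and last for descending ones. A simplified, generic Java analogue of that chaining (not the Spark class, which works on UnsafeSorterIterators) could look like this:

    import java.util.Iterator;
    import java.util.LinkedList;
    import java.util.NoSuchElementException;
    import java.util.Queue;

    // Drains a queue of iterators in order, the way getSortedIterator() chains the
    // null run and the radix-sorted non-null run into a single result iterator.
    final class ChainedRunIterator<T> implements Iterator<T> {
      private final Queue<Iterator<T>> runs;
      private Iterator<T> current;

      ChainedRunIterator(Queue<Iterator<T>> runs) {
        this.runs = runs;
        this.current = runs.isEmpty() ? null : runs.remove();
      }

      @Override
      public boolean hasNext() {
        while (current != null && !current.hasNext()) {
          current = runs.isEmpty() ? null : runs.remove();
        }
        return current != null;
      }

      @Override
      public T next() {
        if (!hasNext()) {
          throw new NoSuchElementException();
        }
        return current.next();
      }

      // Null prefixes sort below everything else: nulls first when ascending, last when descending.
      static <T> Iterator<T> inNullOrder(Iterator<T> nullRun, Iterator<T> sortedRun, boolean descending) {
        Queue<Iterator<T>> queue = new LinkedList<>();
        if (descending) {
          queue.add(sortedRun);
          queue.add(nullRun);
        } else {
          queue.add(nullRun);
          queue.add(sortedRun);
        }
        return new ChainedRunIterator<>(queue);
      }
    }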
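
On the RadixSort side, the new startIndex parameter only shifts the window the sort works on: each record is two longs (the pointer, then the 8-byte prefix that is actually sorted on), so the counting passes now scan from array.getBaseOffset() + startIndex * 8 for numRecords * 16 bytes. The small on-heap sketch below reproduces the equivalent index arithmetic for one per-byte counting pass; it is a deliberate simplification (a plain long[] instead of Platform off-heap access, no signed or descending handling) and not the real getKeyPrefixArrayCounts.

    // Histogram of one prefix byte over a (pointer, prefix) sub-range of a long[].
    // startIndex is measured in longs and each record occupies two longs, as in the patch.
    final class KeyPrefixCounts {
      static long[] countsForByte(long[] array, int startIndex, int numRecords, int byteIdx) {
        long[] counts = new long[256];
        int limit = startIndex + numRecords * 2;          // two longs per record
        for (int i = startIndex; i < limit; i += 2) {
          long prefix = array[i + 1];                     // the second long is the sort key
          counts[(int) ((prefix >>> (byteIdx * 8)) & 0xff)]++;
        }
        return counts;
      }
    }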
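
On the SQL side, PrefixComputer.computePrefix now returns a small mutable Prefix (a value plus an isNull flag) instead of a bare long, and SortPrefix.nullValue supplies the per-type sentinel written into that value when the key is null: Long.MinValue for boolean, date, timestamp, integral and small decimal keys, the double prefix of negative infinity for wider decimals, and 0 otherwise. A hedged sketch of the same pattern for a nullable long key, with invented names:

    // NullablePrefix and LongColumnPrefixComputer are made-up names, not Spark classes.
    // They mirror the new PrefixComputer.Prefix contract: one mutable result object is
    // reused across rows, and a per-type sentinel stands in for the value of a null key.
    final class NullablePrefix {
      long value;
      boolean isNull;
    }

    final class LongColumnPrefixComputer {
      private static final long NULL_SENTINEL = Long.MIN_VALUE;   // nullValue for long keys
      private final NullablePrefix result = new NullablePrefix(); // reused, never re-allocated

      /** Returns the same mutable object on every call, as the new Prefix Javadoc allows. */
      NullablePrefix computePrefix(Long key) {
        result.isNull = (key == null);
        result.value = result.isNull ? NULL_SENTINEL : key.longValue();
        return result;
      }
    }

Reusing one mutable result per computer avoids allocating an object for every inserted row, which is why the Javadoc added to computePrefix explicitly allows the returned object to be reused between calls.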