author:    Eric Liang <ekl@databricks.com>  2016-06-11 15:42:58 -0700
committer: Reynold Xin <rxin@databricks.com>  2016-06-11 15:42:58 -0700
commit:    c06c58bbbb2de0c22cfc70c486d23a94c3079ba4
tree:      2d7b99a05f88c5e90ad5b18898447defb53fbb20
parent:    75705e8dbb51ac91ffc7012fa67f072494c13832
[SPARK-14851][CORE] Support radix sort with nullable longs
## What changes were proposed in this pull request?

This adds support for radix sort of nullable long fields. When a sort field is null and radix sort is enabled, we keep nulls in a separate region of the sort buffer so that radix sort does not need to deal with them. This also has performance benefits when sorting smaller integer types, since the current representation of nulls in two's complement (Long.MIN_VALUE) otherwise forces a full-width radix sort.

This strategy for nulls does mean the sort is no longer stable.

cc davies

## How was this patch tested?

Existing randomized sort tests for correctness. I also tested some TPCDS queries and there does not seem to be any significant regression for non-null sorts. Some test queries (best of 5 runs each):

Before change:

    scala> val start = System.nanoTime; spark.range(5000000).selectExpr("if(id > 5, cast(hash(id) as long), NULL) as h").coalesce(1).orderBy("h").collect(); (System.nanoTime - start) / 1e6
    start: Long = 3190437233227987
    res3: Double = 4716.471091

After change:

    scala> val start = System.nanoTime; spark.range(5000000).selectExpr("if(id > 5, cast(hash(id) as long), NULL) as h").coalesce(1).orderBy("h").collect(); (System.nanoTime - start) / 1e6
    start: Long = 3190367870952791
    res4: Double = 2981.143045

Author: Eric Liang <ekl@databricks.com>

Closes #13161 from ericl/sc-2998.
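To make the buffer layout easier to picture, here is a minimal, self-contained sketch of the null-partitioning idea; the `Record` type is hypothetical and `sortBy` stands in for the real radix sort over key-prefix pairs in `UnsafeInMemorySorter`:

```scala
// Minimal sketch of the null-partitioning idea, not the actual
// UnsafeInMemorySorter internals. Null-keyed records are kept apart so the
// radix pass only ever sees real prefix values, which also means a null key
// no longer needs a sentinel encoding such as Long.MinValue.
case class Record(prefix: Long, isNull: Boolean, payload: String)

def sortWithNullRegion(records: Array[Record], nullsFirst: Boolean): Array[Record] = {
  val (nulls, nonNulls) = records.partition(_.isNull)
  val sorted = nonNulls.sortBy(_.prefix) // stand-in for the radix sort
  if (nullsFirst) nulls ++ sorted else sorted ++ nulls
}
```

This sketch partitions stably for clarity; the real in-place variant moves records between the two regions, which is where stability is lost, as the commit message notes.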
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java  | 20
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala   | 40
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java       | 11
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala                   | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala            | 32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala                 |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala                  | 11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala    |  2
9 files changed, 95 insertions(+), 39 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 37fbad47c1..ad76bf5a0a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -51,7 +51,20 @@ public final class UnsafeExternalRowSorter {
private final UnsafeExternalSorter sorter;
public abstract static class PrefixComputer {
- abstract long computePrefix(InternalRow row);
+
+ public static class Prefix {
+ /** Key prefix value, or the null prefix value if isNull = true. **/
+ long value;
+
+ /** Whether the key is null. */
+ boolean isNull;
+ }
+
+ /**
+ * Computes prefix for the given row. For efficiency, the returned object may be reused in
+ * further calls to a given PrefixComputer.
+ */
+ abstract Prefix computePrefix(InternalRow row);
}
public UnsafeExternalRowSorter(
@@ -88,12 +101,13 @@ public final class UnsafeExternalRowSorter {
}
public void insertRow(UnsafeRow row) throws IOException {
- final long prefix = prefixComputer.computePrefix(row);
+ final PrefixComputer.Prefix prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
row.getBaseObject(),
row.getBaseOffset(),
row.getSizeInBytes(),
- prefix
+ prefix.value,
+ prefix.isNull
);
numRowsInserted++;
if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
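One contract worth noting in the new interface: the returned `Prefix` may be reused across calls, so callers must copy out `value` and `isNull` before computing the next prefix, which is exactly what `insertRow` above does. A sketch of that contract, using a hypothetical stand-in trait since the real `PrefixComputer` is package-private:

```scala
// Hypothetical stand-in for the PrefixComputer contract above: the same
// mutable Prefix instance may be handed back on every call.
final class Prefix {
  var value: Long = 0L
  var isNull: Boolean = false
}

trait PrefixComputerLike[Row] {
  def computePrefix(row: Row): Prefix
}

def collectPrefixes[Row](
    computer: PrefixComputerLike[Row],
    rows: Iterator[Row]): List[(Long, Boolean)] =
  rows.map { row =>
    val p = computer.computePrefix(row)
    (p.value, p.isNull) // copy the primitives before the next call mutates p
  }.toList
```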
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 42a8be6b1b..de779ed370 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -64,10 +64,21 @@ case class SortOrder(child: Expression, direction: SortDirection)
}
/**
- * An expression to generate a 64-bit long prefix used in sorting.
+ * An expression to generate a 64-bit long prefix used in sorting. If the sort must operate over
+ * null keys as well, this.nullValue can be used in place of emitted null prefixes in the sort.
*/
case class SortPrefix(child: SortOrder) extends UnaryExpression {
+ val nullValue = child.child.dataType match {
+ case BooleanType | DateType | TimestampType | _: IntegralType =>
+ Long.MinValue
+ case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
+ Long.MinValue
+ case _: DecimalType =>
+ DoublePrefixComparator.computePrefix(Double.NegativeInfinity)
+ case _ => 0L
+ }
+
override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -75,20 +86,19 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
val input = childCode.value
val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName
val DoublePrefixCmp = classOf[DoublePrefixComparator].getName
-
- val (nullValue: Long, prefixCode: String) = child.child.dataType match {
+ val prefixCode = child.child.dataType match {
case BooleanType =>
- (Long.MinValue, s"$input ? 1L : 0L")
+ s"$input ? 1L : 0L"
case _: IntegralType =>
- (Long.MinValue, s"(long) $input")
+ s"(long) $input"
case DateType | TimestampType =>
- (Long.MinValue, s"(long) $input")
+ s"(long) $input"
case FloatType | DoubleType =>
- (0L, s"$DoublePrefixCmp.computePrefix((double)$input)")
- case StringType => (0L, s"$input.getPrefix()")
- case BinaryType => (0L, s"$BinaryPrefixCmp.computePrefix($input)")
+ s"$DoublePrefixCmp.computePrefix((double)$input)"
+ case StringType => s"$input.getPrefix()"
+ case BinaryType => s"$BinaryPrefixCmp.computePrefix($input)"
case dt: DecimalType if dt.precision - dt.scale <= Decimal.MAX_LONG_DIGITS =>
- val prefix = if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
+ if (dt.precision <= Decimal.MAX_LONG_DIGITS) {
s"$input.toUnscaledLong()"
} else {
// reduce the scale to fit in a long
@@ -96,17 +106,15 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression {
val s = p - (dt.precision - dt.scale)
s"$input.changePrecision($p, $s) ? $input.toUnscaledLong() : ${Long.MinValue}L"
}
- (Long.MinValue, prefix)
case dt: DecimalType =>
- (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
- s"$DoublePrefixCmp.computePrefix($input.toDouble())")
- case _ => (0L, "0L")
+ s"$DoublePrefixCmp.computePrefix($input.toDouble())"
+ case _ => "0L"
}
ev.copy(code = childCode.code +
s"""
- |long ${ev.value} = ${nullValue}L;
- |boolean ${ev.isNull} = false;
+ |long ${ev.value} = 0L;
+ |boolean ${ev.isNull} = ${childCode.isNull};
|if (!${childCode.isNull}) {
| ${ev.value} = $prefixCode;
|}
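The `nullValue` hoisted into the class body above is what downstream callers (see `SortExec` and `SortPrefixUtils` below) substitute for null keys. A sketch of the same selection under simplified stand-ins for Spark's `DataType` hierarchy; `doublePrefix` approximates `DoublePrefixComparator.computePrefix`, and 18 matches `Decimal.MAX_LONG_DIGITS`:

```scala
// Approximation of DoublePrefixComparator.computePrefix: an
// order-preserving mapping from doubles onto signed longs.
def doublePrefix(d: Double): Long = {
  val bits = java.lang.Double.doubleToLongBits(d)
  bits ^ ((bits >> 63) | Long.MinValue)
}

// Simplified stand-ins for the DataType hierarchy; the match mirrors
// SortPrefix.nullValue above.
sealed trait DT
case object BooleanDT extends DT
case object IntegralDT extends DT // also covers date and timestamp
final case class DecimalDT(precision: Int, scale: Int) extends DT
case object OtherDT extends DT // string, binary, float/double all fall to 0L

val MaxLongDigits = 18 // Decimal.MAX_LONG_DIGITS

def nullPrefix(dt: DT): Long = dt match {
  case BooleanDT | IntegralDT => Long.MinValue
  case DecimalDT(p, s) if p - s <= MaxLongDigits => Long.MinValue
  case DecimalDT(_, _) => doublePrefix(Double.NegativeInfinity)
  case OtherDT => 0L
}
```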
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index bb823cd07b..99fe51db68 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -118,9 +118,10 @@ public final class UnsafeKVExternalSorter {
// Compute prefix
row.pointTo(baseObject, baseOffset, loc.getKeyLength());
- final long prefix = prefixComputer.computePrefix(row);
+ final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix =
+ prefixComputer.computePrefix(row);
- inMemSorter.insertRecord(address, prefix);
+ inMemSorter.insertRecord(address, prefix.value, prefix.isNull);
}
sorter = UnsafeExternalSorter.createWithExistingInMemorySorter(
@@ -146,10 +147,12 @@ public final class UnsafeKVExternalSorter {
* sorted runs, and then reallocates memory to hold the new record.
*/
public void insertKV(UnsafeRow key, UnsafeRow value) throws IOException {
- final long prefix = prefixComputer.computePrefix(key);
+ final UnsafeExternalRowSorter.PrefixComputer.Prefix prefix =
+ prefixComputer.computePrefix(key);
sorter.insertKVRecord(
key.getBaseObject(), key.getBaseOffset(), key.getSizeInBytes(),
- value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(), prefix);
+ value.getBaseObject(), value.getBaseOffset(), value.getSizeInBytes(),
+ prefix.value, prefix.isNull);
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index 66a16ac576..6db7f45cfd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -68,10 +68,16 @@ case class SortExec(
SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)
// The generator for prefix
- val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+ val prefixExpr = SortPrefix(boundSortExpression)
+ val prefixProjection = UnsafeProjection.create(Seq(prefixExpr))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = {
- prefixProjection.apply(row).getLong(0)
+ private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+ override def computePrefix(row: InternalRow):
+ UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+ val prefix = prefixProjection.apply(row)
+ result.isNull = prefix.isNullAt(0)
+ result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0)
+ result
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 1a5ff5fcec..940467e74d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -33,6 +33,11 @@ object SortPrefixUtils {
override def compare(prefix1: Long, prefix2: Long): Int = 0
}
+ /**
+ * Dummy sort prefix result to use for empty rows.
+ */
+ private val emptyPrefix = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+
def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
sortOrder.dataType match {
case StringType =>
@@ -70,10 +75,6 @@ object SortPrefixUtils {
*/
def canSortFullyWithPrefix(sortOrder: SortOrder): Boolean = {
sortOrder.dataType match {
- // TODO(ekl) long-type is problematic because it's null prefix representation collides with
- // the lowest possible long value. Handle this special case outside radix sort.
- case LongType if sortOrder.nullable =>
- false
case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType |
TimestampType | FloatType | DoubleType =>
true
@@ -97,16 +98,29 @@ object SortPrefixUtils {
def createPrefixGenerator(schema: StructType): UnsafeExternalRowSorter.PrefixComputer = {
if (schema.nonEmpty) {
val boundReference = BoundReference(0, schema.head.dataType, nullable = true)
- val prefixProjection = UnsafeProjection.create(
- SortPrefix(SortOrder(boundReference, Ascending)))
+ val prefixExpr = SortPrefix(SortOrder(boundReference, Ascending))
+ val prefixProjection = UnsafeProjection.create(prefixExpr)
new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = {
- prefixProjection.apply(row).getLong(0)
+ private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
+ override def computePrefix(row: InternalRow):
+ UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+ val prefix = prefixProjection.apply(row)
+ if (prefix.isNullAt(0)) {
+ result.isNull = true
+ result.value = prefixExpr.nullValue
+ } else {
+ result.isNull = false
+ result.value = prefix.getLong(0)
+ }
+ result
}
}
} else {
new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = 0
+ override def computePrefix(row: InternalRow):
+ UnsafeExternalRowSorter.PrefixComputer.Prefix = {
+ emptyPrefix
+ }
}
}
}
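The `emptyPrefix` singleton above leans on the same reuse convention: when the schema is empty there is no sort key, so every row can safely share one zeroed `Prefix` instead of allocating per call. A tiny sketch of the pattern, again with a hypothetical `Prefix` class:

```scala
final class Prefix {
  var value: Long = 0L
  var isNull: Boolean = false
}

object EmptyPrefixComputer {
  // Allocated once and never mutated afterwards, so sharing the same
  // instance across all rows is safe.
  private val emptyPrefix = new Prefix
  def computePrefix(row: Any): Prefix = emptyPrefix
}
```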
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
index 97bbab65af..1b9634cfc0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
@@ -347,13 +347,13 @@ case class WindowExec(
SparkEnv.get.memoryManager.pageSizeBytes,
false)
rows.foreach { r =>
- sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0)
+ sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0, false)
}
rows.clear()
}
} else {
sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
- nextRow.getSizeInBytes, 0)
+ nextRow.getSizeInBytes, 0, false)
}
fetchNextRow()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 88f78a7a73..d870d91edc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -53,7 +53,7 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
val partition = split.asInstanceOf[CartesianPartition]
for (y <- rdd2.iterator(partition.s2, context)) {
- sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
+ sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0, false)
}
// Create an iterator from the sorter and wrap it as Iterator[UnsafeRow]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index c3acf29c2d..ba3fa3732d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -54,6 +54,17 @@ class SortSuite extends SparkPlanTest with SharedSQLContext {
sortAnswers = false)
}
+ test("sorting all nulls") {
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF().selectExpr("NULL as a"),
+ (child: SparkPlan) =>
+ GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)),
+ (child: SparkPlan) =>
+ GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)),
+ sortAnswers = false
+ )
+ }
+
test("sort followed by limit") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
index 9964b7373f..50ae26a3ff 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SortBenchmark.scala
@@ -110,7 +110,7 @@ class SortBenchmark extends BenchmarkBase {
benchmark.addTimerCase("radix sort key prefix array") { timer =>
val (_, buf2) = generateKeyPrefixTestData(size, rand.nextLong)
timer.startTiming()
- RadixSort.sortKeyPrefixArray(buf2, size, 0, 7, false, false)
+ RadixSort.sortKeyPrefixArray(buf2, 0, size, 0, 7, false, false)
timer.stopTiming()
}
benchmark.run()
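The extra argument here reflects that `sortKeyPrefixArray` now takes a starting position as well as a record count, which is what lets the sorter exclude the null region from the radix pass. A generic sketch of the sub-range idea, with `java.util.Arrays.sort` standing in for the radix sort:

```scala
// With null-keyed records parked at one end of the buffer, only the
// sub-range [start, start + n) needs to be sorted.
def sortSubRange(prefixes: Array[Long], start: Int, n: Int): Unit =
  java.util.Arrays.sort(prefixes, start, start + n)

// e.g. with numNulls null records gathered at the front of buf:
//   sortSubRange(buf, numNulls, buf.length - numNulls)
```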