aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-07-30 17:17:27 -0700
committerReynold Xin <rxin@databricks.com>2015-07-30 17:17:27 -0700
commite7a0976e991f75a7bda99509e2b040daab965ae6 (patch)
tree8a8197424593977086fca74b073a96bd52f5a89d /sql/catalyst
parentdf32669514afc0223ecdeca30fbfbe0b40baef3a (diff)
downloadspark-e7a0976e991f75a7bda99509e2b040daab965ae6.tar.gz
spark-e7a0976e991f75a7bda99509e2b040daab965ae6.tar.bz2
spark-e7a0976e991f75a7bda99509e2b040daab965ae6.zip
[SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort.
Author: Reynold Xin <rxin@databricks.com> Closes #7803 from rxin/SPARK-9458 and squashes the following commits: 5b032dc [Reynold Xin] Fix string. b670dbb [Reynold Xin] [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java27
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala44
2 files changed, 55 insertions, 16 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 4c3f2c6557..68c49feae9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -48,7 +48,6 @@ final class UnsafeExternalRowSorter {
private long numRowsInserted = 0;
private final StructType schema;
- private final UnsafeProjection unsafeProjection;
private final PrefixComputer prefixComputer;
private final UnsafeExternalSorter sorter;
@@ -62,7 +61,6 @@ final class UnsafeExternalRowSorter {
PrefixComparator prefixComparator,
PrefixComputer prefixComputer) throws IOException {
this.schema = schema;
- this.unsafeProjection = UnsafeProjection.create(schema);
this.prefixComputer = prefixComputer;
final SparkEnv sparkEnv = SparkEnv.get();
final TaskContext taskContext = TaskContext.get();
@@ -88,13 +86,12 @@ final class UnsafeExternalRowSorter {
}
@VisibleForTesting
- void insertRow(InternalRow row) throws IOException {
- UnsafeRow unsafeRow = unsafeProjection.apply(row);
+ void insertRow(UnsafeRow row) throws IOException {
final long prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
- unsafeRow.getBaseObject(),
- unsafeRow.getBaseOffset(),
- unsafeRow.getSizeInBytes(),
+ row.getBaseObject(),
+ row.getBaseOffset(),
+ row.getSizeInBytes(),
prefix
);
numRowsInserted++;
@@ -113,7 +110,7 @@ final class UnsafeExternalRowSorter {
}
@VisibleForTesting
- Iterator<InternalRow> sort() throws IOException {
+ Iterator<UnsafeRow> sort() throws IOException {
try {
final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
if (!sortedIterator.hasNext()) {
@@ -121,7 +118,7 @@ final class UnsafeExternalRowSorter {
// here in order to prevent memory leaks.
cleanupResources();
}
- return new AbstractScalaRowIterator() {
+ return new AbstractScalaRowIterator<UnsafeRow>() {
private final int numFields = schema.length();
private UnsafeRow row = new UnsafeRow();
@@ -132,7 +129,7 @@ final class UnsafeExternalRowSorter {
}
@Override
- public InternalRow next() {
+ public UnsafeRow next() {
try {
sortedIterator.loadNext();
row.pointTo(
@@ -164,11 +161,11 @@ final class UnsafeExternalRowSorter {
}
- public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
- while (inputIterator.hasNext()) {
- insertRow(inputIterator.next());
- }
- return sort();
+ public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
+ while (inputIterator.hasNext()) {
+ insertRow(inputIterator.next());
+ }
+ return sort();
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
index 3f436c0eb8..9fe877f10f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator
abstract sealed class SortDirection
case object Ascending extends SortDirection
@@ -37,4 +40,43 @@ case class SortOrder(child: Expression, direction: SortDirection)
override def nullable: Boolean = child.nullable
override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}"
+
+ def isAscending: Boolean = direction == Ascending
+}
+
+/**
+ * An expression to generate a 64-bit long prefix used in sorting.
+ */
+case class SortPrefix(child: SortOrder) extends UnaryExpression {
+
+ override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val childCode = child.child.gen(ctx)
+ val input = childCode.primitive
+ val DoublePrefixCmp = classOf[DoublePrefixComparator].getName
+
+ val (nullValue: Long, prefixCode: String) = child.child.dataType match {
+ case BooleanType =>
+ (Long.MinValue, s"$input ? 1L : 0L")
+ case _: IntegralType =>
+ (Long.MinValue, s"(long) $input")
+ case FloatType | DoubleType =>
+ (DoublePrefixComparator.computePrefix(Double.NegativeInfinity),
+ s"$DoublePrefixCmp.computePrefix((double)$input)")
+ case StringType => (0L, s"$input.getPrefix()")
+ case _ => (0L, "0L")
+ }
+
+ childCode.code +
+ s"""
+ |long ${ev.primitive} = ${nullValue}L;
+ |boolean ${ev.isNull} = false;
+ |if (!${childCode.isNull}) {
+ | ${ev.primitive} = $prefixCode;
+ |}
+ """.stripMargin
+ }
+
+ override def dataType: DataType = LongType
}