From e7a0976e991f75a7bda99509e2b040daab965ae6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 30 Jul 2015 17:17:27 -0700 Subject: [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort. Author: Reynold Xin Closes #7803 from rxin/SPARK-9458 and squashes the following commits: 5b032dc [Reynold Xin] Fix string. b670dbb [Reynold Xin] [SPARK-9458][SPARK-9469][SQL] Code generate prefix computation in sorting & moves unsafe conversion out of TungstenSort. --- .../collection/unsafe/sort/PrefixComparators.java | 49 +++++++------ .../unsafe/sort/PrefixComparatorsSuite.scala | 22 ++---- .../sql/execution/UnsafeExternalRowSorter.java | 27 ++++--- .../spark/sql/catalyst/expressions/SortOrder.scala | 44 +++++++++++- .../spark/sql/execution/SortPrefixUtils.scala | 64 +++-------------- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../spark/sql/execution/joins/HashedRelation.scala | 4 +- .../org/apache/spark/sql/execution/sort.scala | 64 ++++++++--------- .../sql/execution/RowFormatConvertersSuite.scala | 11 ++- .../spark/sql/execution/TungstenSortSuite.scala | 83 ++++++++++++++++++++++ .../sql/execution/UnsafeExternalSortSuite.scala | 83 ---------------------- 11 files changed, 216 insertions(+), 239 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 600aff7d15..4d7e5b3dfb 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -28,9 +28,11 @@ public class PrefixComparators { private PrefixComparators() {} public static final StringPrefixComparator STRING = new StringPrefixComparator(); - public static final IntegralPrefixComparator INTEGRAL = new IntegralPrefixComparator(); - public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final StringPrefixComparatorDesc STRING_DESC = new StringPrefixComparatorDesc(); + public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final LongPrefixComparatorDesc LONG_DESC = new LongPrefixComparatorDesc(); public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); + public static final DoublePrefixComparatorDesc DOUBLE_DESC = new DoublePrefixComparatorDesc(); public static final class StringPrefixComparator extends PrefixComparator { @Override @@ -38,50 +40,55 @@ public class PrefixComparators { return UnsignedLongs.compare(aPrefix, bPrefix); } - public long computePrefix(UTF8String value) { + public static long computePrefix(UTF8String value) { return value == null ? 0L : value.getPrefix(); } } - /** - * Prefix comparator for all integral types (boolean, byte, short, int, long). - */ - public static final class IntegralPrefixComparator extends PrefixComparator { + public static final class StringPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long bPrefix, long aPrefix) { + return UnsignedLongs.compare(aPrefix, bPrefix); + } + } + + public static final class LongPrefixComparator extends PrefixComparator { @Override public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } + } - public final long NULL_PREFIX = Long.MIN_VALUE; + public static final class LongPrefixComparatorDesc extends PrefixComparator { + @Override + public int compare(long b, long a) { + return (a < b) ? -1 : (a > b) ? 1 : 0; + } } - public static final class FloatPrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparator extends PrefixComparator { @Override public int compare(long aPrefix, long bPrefix) { - float a = Float.intBitsToFloat((int) aPrefix); - float b = Float.intBitsToFloat((int) bPrefix); - return Utils.nanSafeCompareFloats(a, b); + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(float value) { - return Float.floatToIntBits(value) & 0xffffffffL; + public static long computePrefix(double value) { + return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Float.NEGATIVE_INFINITY); } - public static final class DoublePrefixComparator extends PrefixComparator { + public static final class DoublePrefixComparatorDesc extends PrefixComparator { @Override - public int compare(long aPrefix, long bPrefix) { + public int compare(long bPrefix, long aPrefix) { double a = Double.longBitsToDouble(aPrefix); double b = Double.longBitsToDouble(bPrefix); return Utils.nanSafeCompareDoubles(a, b); } - public long computePrefix(double value) { + public static long computePrefix(double value) { return Double.doubleToLongBits(value); } - - public final long NULL_PREFIX = computePrefix(Double.NEGATIVE_INFINITY); } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index cf53a8ad21..26a2e96eda 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -29,8 +29,8 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { def testPrefixComparison(s1: String, s2: String): Unit = { val utf8string1 = UTF8String.fromString(s1) val utf8string2 = UTF8String.fromString(s2) - val s1Prefix = PrefixComparators.STRING.computePrefix(utf8string1) - val s2Prefix = PrefixComparators.STRING.computePrefix(utf8string2) + val s1Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string1) + val s2Prefix = PrefixComparators.StringPrefixComparator.computePrefix(utf8string2) val prefixComparisonResult = PrefixComparators.STRING.compare(s1Prefix, s2Prefix) val cmp = UnsignedBytes.lexicographicalComparator().compare( @@ -55,27 +55,15 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } - test("float prefix comparator handles NaN properly") { - val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) - val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) - assert(nan1.isNaN) - assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) - val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) - assert(nan1Prefix === nan2Prefix) - val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) - assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) - } - test("double prefix comparator handles NaNs properly") { val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) assert(nan1.isNaN) assert(nan2.isNaN) - val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1) - val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2) + val nan1Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DoublePrefixComparator.computePrefix(nan2) assert(nan1Prefix === nan2Prefix) - val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue) + val doubleMaxPrefix = PrefixComparators.DoublePrefixComparator.computePrefix(Double.MaxValue) assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 4c3f2c6557..68c49feae9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -48,7 +48,6 @@ final class UnsafeExternalRowSorter { private long numRowsInserted = 0; private final StructType schema; - private final UnsafeProjection unsafeProjection; private final PrefixComputer prefixComputer; private final UnsafeExternalSorter sorter; @@ -62,7 +61,6 @@ final class UnsafeExternalRowSorter { PrefixComparator prefixComparator, PrefixComputer prefixComputer) throws IOException { this.schema = schema; - this.unsafeProjection = UnsafeProjection.create(schema); this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); final TaskContext taskContext = TaskContext.get(); @@ -88,13 +86,12 @@ final class UnsafeExternalRowSorter { } @VisibleForTesting - void insertRow(InternalRow row) throws IOException { - UnsafeRow unsafeRow = unsafeProjection.apply(row); + void insertRow(UnsafeRow row) throws IOException { final long prefix = prefixComputer.computePrefix(row); sorter.insertRecord( - unsafeRow.getBaseObject(), - unsafeRow.getBaseOffset(), - unsafeRow.getSizeInBytes(), + row.getBaseObject(), + row.getBaseOffset(), + row.getSizeInBytes(), prefix ); numRowsInserted++; @@ -113,7 +110,7 @@ final class UnsafeExternalRowSorter { } @VisibleForTesting - Iterator sort() throws IOException { + Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -121,7 +118,7 @@ final class UnsafeExternalRowSorter { // here in order to prevent memory leaks. cleanupResources(); } - return new AbstractScalaRowIterator() { + return new AbstractScalaRowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(); @@ -132,7 +129,7 @@ final class UnsafeExternalRowSorter { } @Override - public InternalRow next() { + public UnsafeRow next() { try { sortedIterator.loadNext(); row.pointTo( @@ -164,11 +161,11 @@ final class UnsafeExternalRowSorter { } - public Iterator sort(Iterator inputIterator) throws IOException { - while (inputIterator.hasNext()) { - insertRow(inputIterator.next()); - } - return sort(); + public Iterator sort(Iterator inputIterator) throws IOException { + while (inputIterator.hasNext()) { + insertRow(inputIterator.next()); + } + return sort(); } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3f436c0eb8..9fe877f10f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator abstract sealed class SortDirection case object Ascending extends SortDirection @@ -37,4 +40,43 @@ case class SortOrder(child: Expression, direction: SortDirection) override def nullable: Boolean = child.nullable override def toString: String = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" + + def isAscending: Boolean = direction == Ascending +} + +/** + * An expression to generate a 64-bit long prefix used in sorting. + */ +case class SortPrefix(child: SortOrder) extends UnaryExpression { + + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val childCode = child.child.gen(ctx) + val input = childCode.primitive + val DoublePrefixCmp = classOf[DoublePrefixComparator].getName + + val (nullValue: Long, prefixCode: String) = child.child.dataType match { + case BooleanType => + (Long.MinValue, s"$input ? 1L : 0L") + case _: IntegralType => + (Long.MinValue, s"(long) $input") + case FloatType | DoubleType => + (DoublePrefixComparator.computePrefix(Double.NegativeInfinity), + s"$DoublePrefixCmp.computePrefix((double)$input)") + case StringType => (0L, s"$input.getPrefix()") + case _ => (0L, "0L") + } + + childCode.code + + s""" + |long ${ev.primitive} = ${nullValue}L; + |boolean ${ev.isNull} = false; + |if (!${childCode.isNull}) { + | ${ev.primitive} = $prefixCode; + |} + """.stripMargin + } + + override def dataType: DataType = LongType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index 2dee3542d6..a2145b185c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -18,10 +18,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} @@ -37,61 +35,15 @@ object SortPrefixUtils { def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = { sortOrder.dataType match { - case StringType => PrefixComparators.STRING - case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL - case FloatType => PrefixComparators.FLOAT - case DoubleType => PrefixComparators.DOUBLE + case StringType if sortOrder.isAscending => PrefixComparators.STRING + case StringType if !sortOrder.isAscending => PrefixComparators.STRING_DESC + case BooleanType | ByteType | ShortType | IntegerType | LongType if sortOrder.isAscending => + PrefixComparators.LONG + case BooleanType | ByteType | ShortType | IntegerType | LongType if !sortOrder.isAscending => + PrefixComparators.LONG_DESC + case FloatType | DoubleType if sortOrder.isAscending => PrefixComparators.DOUBLE + case FloatType | DoubleType if !sortOrder.isAscending => PrefixComparators.DOUBLE_DESC case _ => NoOpPrefixComparator } } - - def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = { - sortOrder.dataType match { - case StringType => (row: InternalRow) => { - PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String]) - } - case BooleanType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1 - else 0 - } - case ByteType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Byte] - } - case ShortType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Short] - } - case IntegerType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Int] - } - case LongType => - (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX - else sortOrder.child.eval(row).asInstanceOf[Long] - } - case FloatType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX - else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float]) - } - case DoubleType => (row: InternalRow) => { - val exprVal = sortOrder.child.eval(row) - if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX - else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double]) - } - case _ => (row: InternalRow) => 0L - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 52a9b02d37..03d24a88d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -341,8 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { if (sqlContext.conf.unsafeEnabled && sqlContext.conf.codegenEnabled && - UnsafeExternalSort.supportsSchema(child.schema)) { - execution.UnsafeExternalSort(sortExprs, global, child) + TungstenSort.supportsSchema(child.schema)) { + execution.TungstenSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 26dbc911e9..f88a45f48a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -229,7 +229,7 @@ private[joins] final class UnsafeHashedRelation( // write all the values as single byte array var totalSize = 0L var i = 0 - while (i < values.size) { + while (i < values.length) { totalSize += values(i).getSizeInBytes + 4 + 4 i += 1 } @@ -240,7 +240,7 @@ private[joins] final class UnsafeHashedRelation( out.writeInt(totalSize.toInt) out.write(key.getBytes) i = 0 - while (i < values.size) { + while (i < values.length) { // [num of fields] [num of bytes] [row bytes] // write the integer in native order, so they can be read by UNSAFE.getInt() if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala index f82208868c..6d903ab23c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala @@ -17,16 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions.{Descending, BindReferences, Attribute, SortOrder} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, OrderedDistribution, Distribution} import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.util.collection.unsafe.sort.PrefixComparator //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines various sort operators. @@ -97,59 +95,53 @@ case class ExternalSort( * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will * spill every `frequency` records. */ -case class UnsafeExternalSort( +case class TungstenSort( sortOrder: Seq[SortOrder], global: Boolean, child: SparkPlan, testSpillFrequency: Int = 0) extends UnaryNode { - private[this] val schema: StructType = child.schema + override def outputsUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = false + + override def output: Seq[Attribute] = child.output + + override def outputOrdering: Seq[SortOrder] = sortOrder override def requiredChildDistribution: Seq[Distribution] = if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") { - assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled") - def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { - val ordering = newOrdering(sortOrder, child.output) - val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) - // Hack until we generate separate comparator implementations for ascending vs. descending - // (or choose to codegen them): - val prefixComparator = { - val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) - if (sortOrder.head.direction == Descending) { - new PrefixComparator { - override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) - } - } else { - comp - } - } - val prefixComputer = { - val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) - new UnsafeExternalRowSorter.PrefixComputer { - override def computePrefix(row: InternalRow): Long = prefixComputer(row) + protected override def doExecute(): RDD[InternalRow] = { + val schema = child.schema + val childOutput = child.output + child.execute().mapPartitions({ iter => + val ordering = newOrdering(sortOrder, childOutput) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + // The generator for prefix + val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression))) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + override def computePrefix(row: InternalRow): Long = { + prefixProjection.apply(row).getLong(0) } } + val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer) if (testSpillFrequency > 0) { sorter.setTestSpillFrequency(testSpillFrequency) } - sorter.sort(iterator) - } - child.execute().mapPartitions(doSort, preservesPartitioning = true) + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + }, preservesPartitioning = true) } - override def output: Seq[Attribute] = child.output - - override def outputOrdering: Seq[SortOrder] = sortOrder - - override def outputsUnsafeRows: Boolean = true } -@DeveloperApi -object UnsafeExternalSort { +object TungstenSort { /** * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala index 7b75f75591..707cd9c6d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.IsNull +import org.apache.spark.sql.catalyst.expressions.{Literal, IsNull} import org.apache.spark.sql.test.TestSQLContext class RowFormatConvertersSuite extends SparkPlanTest { @@ -31,7 +30,7 @@ class RowFormatConvertersSuite extends SparkPlanTest { private val outputsSafe = ExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = UnsafeExternalSort(Nil, false, PhysicalRDD(Seq.empty, null)) + private val outputsUnsafe = TungstenSort(Nil, false, PhysicalRDD(Seq.empty, null)) assert(outputsUnsafe.outputsUnsafeRows) test("planner should insert unsafe->safe conversions when required") { @@ -41,14 +40,14 @@ class RowFormatConvertersSuite extends SparkPlanTest { } test("filter can process unsafe rows") { - val plan = Filter(IsNull(null), outputsUnsafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) + assert(getConverters(preparedPlan).size === 1) assert(preparedPlan.outputsUnsafeRows) } test("filter can process safe rows") { - val plan = Filter(IsNull(null), outputsSafe) + val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) val preparedPlan = TestSQLContext.prepareForExecution.execute(plan) assert(getConverters(preparedPlan).isEmpty) assert(!preparedPlan.outputsUnsafeRows) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala new file mode 100644 index 0000000000..450963547c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ + +class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll { + + override def beforeAll(): Unit = { + TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) + } + + test("sort followed by limit") { + checkThatPlansAgree( + (1 to 100).map(v => Tuple1(v)).toDF("a"), + (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)), + (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), + sortAnswers = false + ) + } + + test("sorting does not crash for large inputs") { + val sortOrder = 'a.asc :: Nil + val stringLength = 1024 * 1024 * 2 + checkThatPlansAgree( + Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), + TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), + Sort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + + // Test sorting on different data types + for ( + dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType) + if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals + nullable <- Seq(true, false); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); + randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) + ) { + test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { + val inputData = Seq.fill(1000)(randomDataGenerator()) + val inputDf = TestSQLContext.createDataFrame( + TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), + StructType(StructField("a", dataType, nullable = true) :: Nil) + ) + assert(TungstenSort.supportsSchema(inputDf.schema)) + checkThatPlansAgree( + inputDf, + plan => ConvertToSafe( + TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), + Sort(sortOrder, global = true, _: SparkPlan), + sortAnswers = false + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala deleted file mode 100644 index 138636b0c6..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import scala.util.Random - -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.types._ - -class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { - - override def beforeAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) - } - - override def afterAll(): Unit = { - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) - } - - test("sort followed by limit") { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), - sortAnswers = false - ) - } - - test("sorting does not crash for large inputs") { - val sortOrder = 'a.asc :: Nil - val stringLength = 1024 * 1024 * 2 - checkThatPlansAgree( - Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), - UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1), - Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false - ) - } - - // Test sorting on different data types - for ( - dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType) - if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals - nullable <- Seq(true, false); - sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); - randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) - ) { - test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(1000)(randomDataGenerator()) - val inputDf = TestSQLContext.createDataFrame( - TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), - StructType(StructField("a", dataType, nullable = true) :: Nil) - ) - assert(UnsafeExternalSort.supportsSchema(inputDf.schema)) - checkThatPlansAgree( - inputDf, - plan => ConvertToSafe( - UnsafeExternalSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)), - Sort(sortOrder, global = true, _: SparkPlan), - sortAnswers = false - ) - } - } -} -- cgit v1.2.3