author     Josh Rosen <joshrosen@databricks.com>  2015-07-10 16:44:51 -0700
committer  Reynold Xin <rxin@databricks.com>      2015-07-10 16:44:51 -0700
commit     fb8807c9b04f27467b36fc9d0177ef92dd012670 (patch)
tree       556101568eb707115d6d20013a0c1dcd8090b696 /sql
parent     0772026c2fc88aa85423034006b758f6ff0cc2ed (diff)
[SPARK-7078] [SPARK-7079] Binary processing sort for Spark SQL
This patch adds a cache-friendly external sorter which operates on serialized bytes and uses this sorter to implement a new sort operator for Spark SQL and DataFrames.

### Overview of the new sorter

The new sorter design is inspired by [Alphasort](http://research.microsoft.com/pubs/68249/alphasort.doc) and implements a key-prefix optimization in order to improve the cache friendliness of the sort.

In naive sort implementations, the sorting algorithm operates on an array of record pointers. To compare two records for ordering, the sorter must dereference these pointers, which likely involves random memory access, then compare the objects themselves.

![image](https://cloud.githubusercontent.com/assets/50748/8611390/3b1402ae-2675-11e5-8308-1a10bf347e6e.png)

In a key-prefix sort, the sort operates on an array which stores the record pointer alongside a prefix of the record's key. When comparing two records for ordering, the sorter first compares the stored key prefixes. If the ordering can be determined from the key prefixes (i.e. the prefixes are unequal), then the sort can avoid directly comparing the records, avoiding random memory accesses and full record comparisons. For example, if we're sorting a list of strings then we can store the first 8 bytes of the UTF-8 encoded string as the key prefix and perform unsigned byte-at-a-time comparisons to determine the ordering of strings based on their prefixes, only resorting to full comparisons for strings that share a common prefix. In cases where the sort key fits entirely in the space allotted for the key prefix (e.g. the sort key is an integer), we avoid direct record comparison entirely.

In this patch's implementation of key-prefix sorting, the sorter's internal array stores a 64-bit key prefix and a 64-bit pointer for each record being sorted. The key prefixes are generated by the user when inserting records into the sorter, and the sorter compares them using a user-defined comparison function. The `PrefixComparators` object implements a set of comparators for many common types, including primitive numeric types and UTF-8 strings.

The actual sorting is implemented by `UnsafeInMemorySorter`. Most consumers will not use this directly, but will instead use `UnsafeExternalSorter`, a class which implements a sort that can spill to disk in response to memory pressure. Internally, `UnsafeExternalSorter` creates `UnsafeInMemorySorter`s to perform sorting, uses `UnsafeSortSpillReader/Writer` to spill and read back runs of sorted records, and uses `UnsafeSortSpillMerger` to merge multiple sorted spills into a single sorted iterator. This external sorter integrates with Spark's existing `ShuffleMemoryManager` for controlling spilling.

Many parts of this sorter's design are based on / copied from the more specialized external sort implementation that I designed for the new UnsafeShuffleManager write path; see #5868 for more details on that patch.

### Sorting rows in Spark SQL

For now, `UnsafeExternalSorter` is only used by Spark SQL, which uses it to implement a new sort operator, `UnsafeExternalSort`. This sort operator uses a SQL-specific class called `UnsafeExternalRowSorter` that configures an `UnsafeExternalSorter` to use prefix generators and comparators that operate on rows encoded in the UnsafeRow format that was designed for Project Tungsten.
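To make the key-prefix optimization described in the overview concrete, here is a minimal, self-contained sketch of prefix-first comparison. The names (`KeyPrefixSortSketch`, `PrefixRecord`, `prefixOf`, `comparePrefixFirst`) are hypothetical illustrations, not the Spark classes; the real sorter works on raw memory via `UnsafeInMemorySorter` and the comparators in `PrefixComparators`.

```scala
object KeyPrefixSortSketch {

  // The array entry the sorter actually moves around: an 8-byte key prefix
  // plus a "pointer" to the full record (here just an index into an array).
  final case class PrefixRecord(prefix: Long, recordIndex: Int)

  // Pack the first 8 UTF-8 bytes of a string into a Long, zero-padded, so that
  // unsigned comparison of prefixes agrees with byte-wise string ordering.
  def prefixOf(s: String): Long = {
    val bytes = s.getBytes("UTF-8")
    var prefix = 0L
    var i = 0
    while (i < 8) {
      val b = if (i < bytes.length) bytes(i) & 0xffL else 0L
      prefix = (prefix << 8) | b
      i += 1
    }
    prefix
  }

  // Compare prefixes first; only dereference and compare the full records when
  // the prefixes tie (i.e. the strings share a common 8-byte prefix).
  def comparePrefixFirst(records: Array[String])(a: PrefixRecord, b: PrefixRecord): Int = {
    val byPrefix = java.lang.Long.compareUnsigned(a.prefix, b.prefix)
    if (byPrefix != 0) byPrefix
    else records(a.recordIndex).compareTo(records(b.recordIndex))
  }

  def main(args: Array[String]): Unit = {
    val records = Array("banana", "apple", "apricot", "cherry")
    val entries = records.indices.map(i => PrefixRecord(prefixOf(records(i)), i)).toArray
    val sorted = entries.sortWith((a, b) => comparePrefixFirst(records)(a, b) < 0)
    println(sorted.map(e => records(e.recordIndex)).mkString(", "))
    // prints: apple, apricot, banana, cherry
  }
}
```

The sort moves only the small (prefix, pointer) entries, and the full records are dereferenced only when two prefixes tie, which is what makes the hot comparison loop cache-friendly.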
I used some interesting unit-testing techniques to test this patch's SQL-specific components. `UnsafeExternalSortSuite` uses the SQL random data generators introduced in #7176 to test the UnsafeSort operator with all atomic types, both with and without nullability, and in both ascending and descending sort orders. `PrefixComparatorsSuite` contains a cool use of ScalaCheck + ScalaTest's `GeneratorDrivenPropertyChecks` in order to test UTF8String prefix comparison (a sketch of this style of property check follows the commit list below).

### Misc. additional improvements made in this patch

This patch made several miscellaneous improvements to related code in Spark SQL:

- The logic for selecting physical sort operator implementations, which was partially duplicated in both `Exchange` and `SparkStrategies`, has now been consolidated into a `getSortOperator()` helper function in `SparkStrategies`.
- The `SparkPlanTest` unit testing helper trait has been extended with new methods for comparing the output produced by two different physical plans. This makes it easy to write tests which assert that two physical operator implementations should produce the same output. I also added a method for disabling the implicit sorting of outputs prior to comparing them, a change which is necessary in order to be able to write proper SparkPlan tests for sort operators.

### Tasks deferred to followup patches

While most of this patch's features are reasonably well-tested and complete, there are a number of tasks that are intentionally being deferred to followup patches:

- Add tests which mock the ShuffleMemoryManager to check that memory pressure properly triggers spilling (there are examples of this type of test in #5868).
- Add tests to ensure that spill files are properly cleaned up after errors. I'd like to do this in the context of a patch which introduces more general metrics for ensuring proper cleanup of tasks' temporary files; see https://issues.apache.org/jira/browse/SPARK-8966 for more details.
- Metrics integration: there are some open questions regarding how to track / report spill metrics for non-shuffle operations, so I've deferred most of the IO / shuffle metrics integration for now.
- Performance profiling.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #6444 from JoshRosen/sql-external-sort and squashes the following commits:

6beb467 [Josh Rosen] Remove a bunch of overloaded methods to avoid default args. issue
2bbac9c [Josh Rosen] Merge remote-tracking branch 'origin/master' into sql-external-sort
35dad9f [Josh Rosen] Make sortAnswers = false the default in SparkPlanTest
5135200 [Josh Rosen] Fix spill reading for large rows; add test
2f48777 [Josh Rosen] Add test and fix bug for sorting empty arrays
d1e28bc [Josh Rosen] Merge remote-tracking branch 'origin/master' into sql-external-sort
cd05866 [Josh Rosen] Fix scalastyle
3947fc1 [Josh Rosen] Merge remote-tracking branch 'origin/master' into sql-external-sort
d13ac55 [Josh Rosen] Hacky approach to copying of UnsafeRows for sort followed by limit.
845bea3 [Josh Rosen] Remove unnecessary zeroing of row conversion buffer
c56ec18 [Josh Rosen] Clean up final row copying code.
d31f180 [Josh Rosen] Re-enable NullType sorting test now that SPARK-8868 is fixed
844f4ca [Josh Rosen] Merge remote-tracking branch 'origin/master' into sql-external-sort
293f109 [Josh Rosen] Add missing license header.
f99a612 [Josh Rosen] Fix bugs in string prefix comparison.
9d00afc [Josh Rosen] Clean up prefix comparators for integral types
88aff18 [Josh Rosen] NULL_PREFIX has to be negative infinity for floating point types
613e16f [Josh Rosen] Test with larger data.
1d7ffaa [Josh Rosen] Somewhat hacky fix for descending sorts
08701e7 [Josh Rosen] Fix prefix comparison of null primitives.
b86e684 [Josh Rosen] Set global = true in UnsafeExternalSortSuite.
1c7bad8 [Josh Rosen] Make sorting of answers explicit in SparkPlanTest.checkAnswer().
b81a920 [Josh Rosen] Temporarily enable only the passing sort tests
5d6109d [Josh Rosen] Fix inconsistent handling / encoding of record lengths.
87b6ed9 [Josh Rosen] Fix critical issues in test which led to false negatives.
8d7fbe7 [Josh Rosen] Fixes to multiple spilling-related bugs.
82e21c1 [Josh Rosen] Force spilling in UnsafeExternalSortSuite.
88b72db [Josh Rosen] Test ascending and descending sort orders.
f27be09 [Josh Rosen] Fix tests by binding attributes.
0a79d39 [Josh Rosen] Revert "Undo part of a SparkPlanTest change in #7162 that broke my test."
7c3c864 [Josh Rosen] Undo part of a SparkPlanTest change in #7162 that broke my test.
9969c14 [Josh Rosen] Merge remote-tracking branch 'origin/master' into sql-external-sort
5822e6f [Josh Rosen] Fix test compilation issue
939f824 [Josh Rosen] Remove code gen experiment.
0dfe919 [Josh Rosen] Implement prefix sort for strings (albeit inefficiently).
66a813e [Josh Rosen] Prefix comparators for float and double
b310c88 [Josh Rosen] Integrate prefix comparators for Int and Long (others coming soon)
95058d9 [Josh Rosen] Add missing SortPrefixUtils file
4c37ba6 [Josh Rosen] Add tests for sorting on all primitive types.
6890863 [Josh Rosen] Fix memory leak on empty inputs.
d246e29 [Josh Rosen] Fix consideration of column types when choosing sort implementation.
6b156fb [Josh Rosen] Some WIP work on prefix comparison.
7f875f9 [Josh Rosen] Commit failing test demonstrating bug in handling objects in spills
41b8881 [Josh Rosen] Get UnsafeInMemorySorterSuite to pass (WIP)
90c2b6a [Josh Rosen] Update test name
6d6a1e6 [Josh Rosen] Centralize logic for picking sort operator implementations
9869ec2 [Josh Rosen] Clean up Exchange code a bit
82bb0ec [Josh Rosen] Fix IntelliJ complaint due to negated if condition
1db845a [Josh Rosen] Many more changes to harmonize with shuffle sorter
ebf9eea [Josh Rosen] Harmonization with shuffle's unsafe sorter
206bfa2 [Josh Rosen] Add some missing newlines at the ends of files
26c8931 [Josh Rosen] Back out some Hive changes that aren't needed anymore
62f0bb8 [Josh Rosen] Update to reflect SparkPlanTest changes
21d7d93 [Josh Rosen] Back out of BlockObjectWriter change
7eafecf [Josh Rosen] Port test to SparkPlanTest
d468a88 [Josh Rosen] Update for InternalRow refactoring
269cf86 [Josh Rosen] Back out SMJ operator change; isolate changes to selection of sort op.
1b841ca [Josh Rosen] WIP towards copying
b420a71 [Josh Rosen] Move most of the existing SMJ code into Java.
dfdb93f [Josh Rosen] SparkFunSuite change
73cc761 [Josh Rosen] Fix whitespace
9cc98f5 [Josh Rosen] Move more code to Java; fix bugs in UnsafeRowConverter length type.
c8792de [Josh Rosen] Remove some debug logging
dda6752 [Josh Rosen] Commit some missing code from an old git stash.
58f36d0 [Josh Rosen] Merge in a sketch of a unit test for the new sorter (now failing).
2bd8c9a [Josh Rosen] Import my original tests and get them to pass.
d5d3106 [Josh Rosen] WIP towards external sorter for Spark SQL.
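For illustration, here is a rough sketch of the property-based testing style mentioned above (this is not the actual `PrefixComparatorsSuite`): it checks that an 8-byte UTF-8 key prefix never contradicts a full unsigned byte-at-a-time comparison. It assumes the ScalaTest 2.x-era `FunSuite` and `GeneratorDrivenPropertyChecks` APIs that Spark used at the time, plus the hypothetical `prefixOf` helper from the earlier sketch.

```scala
import org.scalatest.FunSuite
import org.scalatest.prop.GeneratorDrivenPropertyChecks

class StringPrefixPropertySketch extends FunSuite with GeneratorDrivenPropertyChecks {

  // Hypothetical helper from the earlier sketch: first 8 UTF-8 bytes, zero-padded.
  private def prefixOf(s: String): Long = {
    val bytes = s.getBytes("UTF-8")
    (0 until 8).foldLeft(0L) { (p, i) =>
      (p << 8) | (if (i < bytes.length) bytes(i) & 0xffL else 0L)
    }
  }

  // Unsigned byte-at-a-time comparison of the full UTF-8 encodings.
  private def compareUtf8(a: String, b: String): Int = {
    val x = a.getBytes("UTF-8")
    val y = b.getBytes("UTF-8")
    var i = 0
    while (i < math.min(x.length, y.length)) {
      val c = (x(i) & 0xff) - (y(i) & 0xff)
      if (c != 0) return c
      i += 1
    }
    x.length - y.length
  }

  test("prefix comparison never contradicts the full byte comparison") {
    forAll { (a: String, b: String) =>
      val byPrefix = java.lang.Long.compareUnsigned(prefixOf(a), prefixOf(b))
      // Equal prefixes decide nothing (the sorter falls back to a full record
      // comparison); unequal prefixes must agree with the full comparison.
      assert(byPrefix == 0 || Integer.signum(byPrefix) == Integer.signum(compareUtf8(a, b)))
    }
  }
}
```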
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java        7
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java     216
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala            27
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                      13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala               97
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala               22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                73
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala                     11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala                 253
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala       104
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala          21
11 files changed, 721 insertions, 123 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index edb7202245..4b99030d10 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -61,9 +61,10 @@ public final class UnsafeRow extends MutableRow {
/** A pool to hold non-primitive objects */
private ObjectPool pool;
- Object getBaseObject() { return baseObject; }
- long getBaseOffset() { return baseOffset; }
- ObjectPool getPool() { return pool; }
+ public Object getBaseObject() { return baseObject; }
+ public long getBaseOffset() { return baseOffset; }
+ public int getSizeInBytes() { return sizeInBytes; }
+ public ObjectPool getPool() { return pool; }
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
new file mode 100644
index 0000000000..b94601cf6d
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution;
+
+import java.io.IOException;
+
+import scala.collection.Iterator;
+import scala.math.Ordering;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.TaskContext;
+import org.apache.spark.sql.AbstractScalaRowIterator;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.ObjectUnsafeColumnWriter;
+import org.apache.spark.sql.catalyst.expressions.UnsafeColumnWriter;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator;
+import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
+
+final class UnsafeExternalRowSorter {
+
+ /**
+ * If positive, forces records to be spilled to disk at the given frequency (measured in numbers
+ * of records). This is only intended to be used in tests.
+ */
+ private int testSpillFrequency = 0;
+
+ private long numRowsInserted = 0;
+
+ private final StructType schema;
+ private final UnsafeRowConverter rowConverter;
+ private final PrefixComputer prefixComputer;
+ private final UnsafeExternalSorter sorter;
+ private byte[] rowConversionBuffer = new byte[1024 * 8];
+
+ public static abstract class PrefixComputer {
+ abstract long computePrefix(InternalRow row);
+ }
+
+ public UnsafeExternalRowSorter(
+ StructType schema,
+ Ordering<InternalRow> ordering,
+ PrefixComparator prefixComparator,
+ PrefixComputer prefixComputer) throws IOException {
+ this.schema = schema;
+ this.rowConverter = new UnsafeRowConverter(schema);
+ this.prefixComputer = prefixComputer;
+ final SparkEnv sparkEnv = SparkEnv.get();
+ final TaskContext taskContext = TaskContext.get();
+ sorter = new UnsafeExternalSorter(
+ taskContext.taskMemoryManager(),
+ sparkEnv.shuffleMemoryManager(),
+ sparkEnv.blockManager(),
+ taskContext,
+ new RowComparator(ordering, schema.length(), null),
+ prefixComparator,
+ 4096,
+ sparkEnv.conf()
+ );
+ }
+
+ /**
+ * Forces spills to occur every `frequency` records. Only for use in tests.
+ */
+ @VisibleForTesting
+ void setTestSpillFrequency(int frequency) {
+ assert frequency > 0 : "Frequency must be positive";
+ testSpillFrequency = frequency;
+ }
+
+ @VisibleForTesting
+ void insertRow(InternalRow row) throws IOException {
+ final int sizeRequirement = rowConverter.getSizeRequirement(row);
+ if (sizeRequirement > rowConversionBuffer.length) {
+ rowConversionBuffer = new byte[sizeRequirement];
+ }
+ final int bytesWritten = rowConverter.writeRow(
+ row, rowConversionBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, sizeRequirement, null);
+ assert (bytesWritten == sizeRequirement);
+ final long prefix = prefixComputer.computePrefix(row);
+ sorter.insertRecord(
+ rowConversionBuffer,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ sizeRequirement,
+ prefix
+ );
+ numRowsInserted++;
+ if (testSpillFrequency > 0 && (numRowsInserted % testSpillFrequency) == 0) {
+ spill();
+ }
+ }
+
+ @VisibleForTesting
+ void spill() throws IOException {
+ sorter.spill();
+ }
+
+ private void cleanupResources() {
+ sorter.freeMemory();
+ }
+
+ @VisibleForTesting
+ Iterator<InternalRow> sort() throws IOException {
+ try {
+ final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
+ if (!sortedIterator.hasNext()) {
+ // Since we won't ever call next() on an empty iterator, we need to clean up resources
+ // here in order to prevent memory leaks.
+ cleanupResources();
+ }
+ return new AbstractScalaRowIterator() {
+
+ private final int numFields = schema.length();
+ private final UnsafeRow row = new UnsafeRow();
+
+ @Override
+ public boolean hasNext() {
+ return sortedIterator.hasNext();
+ }
+
+ @Override
+ public InternalRow next() {
+ try {
+ sortedIterator.loadNext();
+ row.pointTo(
+ sortedIterator.getBaseObject(),
+ sortedIterator.getBaseOffset(),
+ numFields,
+ sortedIterator.getRecordLength(),
+ null);
+ if (!hasNext()) {
+ row.copy(); // so that we don't have dangling pointers to freed page
+ cleanupResources();
+ }
+ return row;
+ } catch (IOException e) {
+ cleanupResources();
+ // Scala iterators don't declare any checked exceptions, so we need to use this hack
+ // to re-throw the exception:
+ PlatformDependent.throwException(e);
+ }
+ throw new RuntimeException("Exception should have been re-thrown in next()");
+ };
+ };
+ } catch (IOException e) {
+ cleanupResources();
+ throw e;
+ }
+ }
+
+
+ public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
+ while (inputIterator.hasNext()) {
+ insertRow(inputIterator.next());
+ }
+ return sort();
+ }
+
+ /**
+ * Return true if UnsafeExternalRowSorter can sort rows with the given schema, false otherwise.
+ */
+ public static boolean supportsSchema(StructType schema) {
+ // TODO: add spilling note to explain why we do this for now:
+ for (StructField field : schema.fields()) {
+ if (UnsafeColumnWriter.forType(field.dataType()) instanceof ObjectUnsafeColumnWriter) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private static final class RowComparator extends RecordComparator {
+ private final Ordering<InternalRow> ordering;
+ private final int numFields;
+ private final ObjectPool objPool;
+ private final UnsafeRow row1 = new UnsafeRow();
+ private final UnsafeRow row2 = new UnsafeRow();
+
+ public RowComparator(Ordering<InternalRow> ordering, int numFields, ObjectPool objPool) {
+ this.numFields = numFields;
+ this.ordering = ordering;
+ this.objPool = objPool;
+ }
+
+ @Override
+ public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+ row1.pointTo(baseObj1, baseOff1, numFields, -1, objPool);
+ row2.pointTo(baseObj2, baseOff2, numFields, -1, objPool);
+ return ordering.compare(row1, row2);
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
new file mode 100644
index 0000000000..cfefb13e77
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator
+ * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix the
+ * element type to `InternalRow` in order to work around a spurious IntelliJ compiler error.
+ */
+private[spark] abstract class AbstractScalaRowIterator extends Iterator[InternalRow]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 74d9334045..4b783e30d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -289,11 +289,8 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
}
val withSort = if (needSort) {
- if (sqlContext.conf.externalSortEnabled) {
- ExternalSort(rowOrdering, global = false, withShuffle)
- } else {
- Sort(rowOrdering, global = false, withShuffle)
- }
+ sqlContext.planner.BasicOperators.getSortOperator(
+ rowOrdering, global = false, withShuffle)
} else {
withShuffle
}
@@ -321,11 +318,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
case (UnspecifiedDistribution, Seq(), child) =>
child
case (UnspecifiedDistribution, rowOrdering, child) =>
- if (sqlContext.conf.externalSortEnabled) {
- ExternalSort(rowOrdering, global = false, child)
- } else {
- Sort(rowOrdering, global = false, child)
- }
+ sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
case (dist, ordering, _) =>
sys.error(s"Don't know how to ensure $dist with ordering $ordering")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
new file mode 100644
index 0000000000..2dee3542d6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
+
+
+object SortPrefixUtils {
+
+ /**
+ * A dummy prefix comparator which always claims that prefixes are equal. This is used in cases
+ * where we don't know how to generate or compare prefixes for a SortOrder.
+ */
+ private object NoOpPrefixComparator extends PrefixComparator {
+ override def compare(prefix1: Long, prefix2: Long): Int = 0
+ }
+
+ def getPrefixComparator(sortOrder: SortOrder): PrefixComparator = {
+ sortOrder.dataType match {
+ case StringType => PrefixComparators.STRING
+ case BooleanType | ByteType | ShortType | IntegerType | LongType => PrefixComparators.INTEGRAL
+ case FloatType => PrefixComparators.FLOAT
+ case DoubleType => PrefixComparators.DOUBLE
+ case _ => NoOpPrefixComparator
+ }
+ }
+
+ def getPrefixComputer(sortOrder: SortOrder): InternalRow => Long = {
+ sortOrder.dataType match {
+ case StringType => (row: InternalRow) => {
+ PrefixComparators.STRING.computePrefix(sortOrder.child.eval(row).asInstanceOf[UTF8String])
+ }
+ case BooleanType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else if (sortOrder.child.eval(row).asInstanceOf[Boolean]) 1
+ else 0
+ }
+ case ByteType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else sortOrder.child.eval(row).asInstanceOf[Byte]
+ }
+ case ShortType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else sortOrder.child.eval(row).asInstanceOf[Short]
+ }
+ case IntegerType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else sortOrder.child.eval(row).asInstanceOf[Int]
+ }
+ case LongType =>
+ (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.INTEGRAL.NULL_PREFIX
+ else sortOrder.child.eval(row).asInstanceOf[Long]
+ }
+ case FloatType => (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.FLOAT.NULL_PREFIX
+ else PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float])
+ }
+ case DoubleType => (row: InternalRow) => {
+ val exprVal = sortOrder.child.eval(row)
+ if (exprVal == null) PrefixComparators.DOUBLE.NULL_PREFIX
+ else PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double])
+ }
+ case _ => (row: InternalRow) => 0L
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 59b9b553a7..ce25af58b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -302,6 +302,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object BasicOperators extends Strategy {
def numPartitions: Int = self.numPartitions
+ /**
+ * Picks an appropriate sort operator.
+ *
+ * @param global when true performs a global sort of all partitions by shuffling the data first
+ * if necessary.
+ */
+ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = {
+ if (sqlContext.conf.unsafeEnabled && UnsafeExternalSort.supportsSchema(child.schema)) {
+ execution.UnsafeExternalSort(sortExprs, global, child)
+ } else if (sqlContext.conf.externalSortEnabled) {
+ execution.ExternalSort(sortExprs, global, child)
+ } else {
+ execution.Sort(sortExprs, global, child)
+ }
+ }
+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case r: RunnableCommand => ExecutedCommand(r) :: Nil
@@ -313,11 +329,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.SortPartitions(sortExprs, child) =>
// This sort only sorts tuples within a partition. Its requiredDistribution will be
// an UnspecifiedDistribution.
- execution.Sort(sortExprs, global = false, planLater(child)) :: Nil
- case logical.Sort(sortExprs, global, child) if sqlContext.conf.externalSortEnabled =>
- execution.ExternalSort(sortExprs, global, planLater(child)):: Nil
+ getSortOperator(sortExprs, global = false, planLater(child)) :: Nil
case logical.Sort(sortExprs, global, child) =>
- execution.Sort(sortExprs, global, planLater(child)):: Nil
+ getSortOperator(sortExprs, global, planLater(child)):: Nil
case logical.Project(projectList, child) =>
execution.Project(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index de14e6ad79..4c063c299b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.types.StructType
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -27,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
import org.apache.spark.util.{CompletionIterator, MutablePair}
import org.apache.spark.{HashPartitioner, SparkEnv}
@@ -248,6 +250,77 @@ case class ExternalSort(
/**
* :: DeveloperApi ::
+ * Optimized version of [[ExternalSort]] that operates on binary data (implemented as part of
+ * Project Tungsten).
+ *
+ * @param global when true performs a global sort of all partitions by shuffling the data first
+ * if necessary.
+ * @param testSpillFrequency Method for configuring periodic spilling in unit tests. If set, will
+ * spill every `frequency` records.
+ */
+@DeveloperApi
+case class UnsafeExternalSort(
+ sortOrder: Seq[SortOrder],
+ global: Boolean,
+ child: SparkPlan,
+ testSpillFrequency: Int = 0)
+ extends UnaryNode {
+
+ private[this] val schema: StructType = child.schema
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
+ assert(codegenEnabled, "UnsafeExternalSort requires code generation to be enabled")
+ def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+ val ordering = newOrdering(sortOrder, child.output)
+ val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output)
+ // Hack until we generate separate comparator implementations for ascending vs. descending
+ // (or choose to codegen them):
+ val prefixComparator = {
+ val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+ if (sortOrder.head.direction == Descending) {
+ new PrefixComparator {
+ override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2)
+ }
+ } else {
+ comp
+ }
+ }
+ val prefixComputer = {
+ val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression)
+ new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = prefixComputer(row)
+ }
+ }
+ val sorter = new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer)
+ if (testSpillFrequency > 0) {
+ sorter.setTestSpillFrequency(testSpillFrequency)
+ }
+ sorter.sort(iterator)
+ }
+ child.execute().mapPartitions(doSort, preservesPartitioning = true)
+ }
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputOrdering: Seq[SortOrder] = sortOrder
+}
+
+@DeveloperApi
+object UnsafeExternalSort {
+ /**
+ * Return true if UnsafeExternalSort can sort rows with the given schema, false otherwise.
+ */
+ def supportsSchema(schema: StructType): Boolean = {
+ UnsafeExternalRowSorter.supportsSchema(schema)
+ }
+}
+
+
+/**
+ * :: DeveloperApi ::
* Return a new RDD that has exactly `numPartitions` partitions.
*/
@DeveloperApi
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index a1e3ca11b1..a2c10fdaf6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
class SortSuite extends SparkPlanTest {
@@ -33,12 +34,14 @@ class SortSuite extends SparkPlanTest {
checkAnswer(
input.toDF("a", "b", "c"),
- ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
- input.sorted)
+ ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan),
+ input.sortBy(t => (t._1, t._2)).map(Row.fromTuple),
+ sortAnswers = false)
checkAnswer(
input.toDF("a", "b", "c"),
- ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
- input.sortBy(t => (t._2, t._1)))
+ ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan),
+ input.sortBy(t => (t._2, t._1)).map(Row.fromTuple),
+ sortAnswers = false)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 108b1122f7..6a8f394545 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -17,18 +17,15 @@
package org.apache.spark.sql.execution
-import scala.language.implicitConversions
-import scala.reflect.runtime.universe.TypeTag
-import scala.util.control.NonFatal
-
import org.apache.spark.SparkFunSuite
-
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util._
-
import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrameHolder, Row, DataFrame}
+import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
/**
* Base class for writing tests for individual physical operators. For an example of how this
@@ -49,12 +46,19 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
protected def checkAnswer(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
- expectedAnswer: Seq[Row]): Unit = {
- checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer)
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ doCheckAnswer(
+ input :: Nil,
+ (plans: Seq[SparkPlan]) => planFunction(plans.head),
+ expectedAnswer,
+ sortAnswers)
}
/**
@@ -64,86 +68,131 @@ class SparkPlanTest extends SparkFunSuite {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
- protected def checkAnswer(
+ protected def checkAnswer2(
left: DataFrame,
right: DataFrame,
planFunction: (SparkPlan, SparkPlan) => SparkPlan,
- expectedAnswer: Seq[Row]): Unit = {
- checkAnswer(left :: right :: Nil,
- (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer)
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ doCheckAnswer(
+ left :: right :: Nil,
+ (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)),
+ expectedAnswer,
+ sortAnswers)
}
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param input the input data to be used.
- * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
- * the physical operator that's being tested.
+ * @param planFunction a function which accepts a sequence of input SparkPlans and uses them to
+ * instantiate the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
- protected def checkAnswer(
+ protected def doCheckAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
- expectedAnswer: Seq[Row]): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match {
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
}
/**
- * Runs the plan and makes sure the answer matches the expected result.
+ * Runs the plan and makes sure the answer matches the result produced by a reference plan.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+ * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to
+ * instantiate a reference implementation of the physical operator
+ * that's being tested. The result of executing this plan will be
+ * treated as the source-of-truth for the test.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
- protected def checkAnswer[A <: Product : TypeTag](
+ protected def checkThatPlansAgree(
input: DataFrame,
planFunction: SparkPlan => SparkPlan,
- expectedAnswer: Seq[A]): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(input, planFunction, expectedRows)
+ expectedPlanFunction: SparkPlan => SparkPlan,
+ sortAnswers: Boolean = true): Unit = {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
}
+}
- /**
- * Runs the plan and makes sure the answer matches the expected result.
- * @param left the left input data to be used.
- * @param right the right input data to be used.
- * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
- * the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
- */
- protected def checkAnswer[A <: Product : TypeTag](
- left: DataFrame,
- right: DataFrame,
- planFunction: (SparkPlan, SparkPlan) => SparkPlan,
- expectedAnswer: Seq[A]): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(left, right, planFunction, expectedRows)
- }
+/**
+ * Helper methods for writing tests of individual physical operators.
+ */
+object SparkPlanTest {
/**
- * Runs the plan and makes sure the answer matches the expected result.
+ * Runs the plan and makes sure the answer matches the result produced by a reference plan.
* @param input the input data to be used.
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+ * @param expectedPlanFunction a function which accepts the input SparkPlan and uses it to
+ * instantiate a reference implementation of the physical operator
+ * that's being tested. The result of executing this plan will be
+ * treated as the source-of-truth for the test.
*/
- protected def checkAnswer[A <: Product : TypeTag](
- input: Seq[DataFrame],
- planFunction: Seq[SparkPlan] => SparkPlan,
- expectedAnswer: Seq[A]): Unit = {
- val expectedRows = expectedAnswer.map(Row.fromTuple)
- checkAnswer(input, planFunction, expectedRows)
- }
+ def checkAnswer(
+ input: DataFrame,
+ planFunction: SparkPlan => SparkPlan,
+ expectedPlanFunction: SparkPlan => SparkPlan,
+ sortAnswers: Boolean): Option[String] = {
-}
+ val outputPlan = planFunction(input.queryExecution.sparkPlan)
+ val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
-/**
- * Helper methods for writing tests of individual physical operators.
- */
-object SparkPlanTest {
+ val expectedAnswer: Seq[Row] = try {
+ executePlan(expectedOutputPlan)
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing Spark plan to calculate expected answer:
+ | $expectedOutputPlan
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ val actualAnswer: Seq[Row] = try {
+ executePlan(outputPlan)
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing Spark plan:
+ | $outputPlan
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+ s"""
+ | Results do not match.
+ | Actual result Spark plan:
+ | $outputPlan
+ | Expected result Spark plan:
+ | $expectedOutputPlan
+ | $errorMessage
+ """.stripMargin
+ }
+ }
/**
* Runs the plan and makes sure the answer matches the expected result.
@@ -151,28 +200,45 @@ object SparkPlanTest {
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
* the physical operator that's being tested.
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
*/
def checkAnswer(
input: Seq[DataFrame],
planFunction: Seq[SparkPlan] => SparkPlan,
- expectedAnswer: Seq[Row]): Option[String] = {
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
- // A very simple resolver to make writing tests easier. In contrast to the real resolver
- // this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = TestSQLContext.prepareForExecution.execute(
- outputPlan transform {
- case plan: SparkPlan =>
- val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
- plan.transformExpressions {
- case UnresolvedAttribute(Seq(u)) =>
- inputMap.getOrElse(u,
- sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
- }
- }
- )
+ val sparkAnswer: Seq[Row] = try {
+ executePlan(outputPlan)
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing Spark plan:
+ | $outputPlan
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+ compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+ s"""
+ | Results do not match for Spark plan:
+ | $outputPlan
+ | $errorMessage
+ """.stripMargin
+ }
+ }
+
+ private def compareAnswers(
+ sparkAnswer: Seq[Row],
+ expectedAnswer: Seq[Row],
+ sort: Boolean): Option[String] = {
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -187,40 +253,43 @@ object SparkPlanTest {
case o => o
})
}
- converted.sortBy(_.toString())
- }
-
- val sparkAnswer: Seq[Row] = try {
- resolvedPlan.executeCollect().toSeq
- } catch {
- case NonFatal(e) =>
- val errorMessage =
- s"""
- | Exception thrown while executing Spark plan:
- | $outputPlan
- | == Exception ==
- | $e
- | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin
- return Some(errorMessage)
+ if (sort) {
+ converted.sortBy(_.toString())
+ } else {
+ converted
+ }
}
-
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
val errorMessage =
s"""
- | Results do not match for Spark plan:
- | $outputPlan
| == Results ==
| ${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ s"== Expected Answer - ${expectedAnswer.size} ==" +:
prepareAnswer(expectedAnswer).map(_.toString()),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ s"== Actual Answer - ${sparkAnswer.size} ==" +:
prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
""".stripMargin
- return Some(errorMessage)
+ Some(errorMessage)
+ } else {
+ None
}
+ }
- None
+ private def executePlan(outputPlan: SparkPlan): Seq[Row] = {
+ // A very simple resolver to make writing tests easier. In contrast to the real resolver
+ // this is always case sensitive and does not try to handle scoping or complex type resolution.
+ val resolvedPlan = TestSQLContext.prepareForExecution.execute(
+ outputPlan transform {
+ case plan: SparkPlan =>
+ val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
+ plan.transformExpressions {
+ case UnresolvedAttribute(Seq(u)) =>
+ inputMap.getOrElse(u,
+ sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+ }
+ }
+ )
+ resolvedPlan.executeCollect().toSeq
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
new file mode 100644
index 0000000000..4f4c1f2856
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.types._
+
+class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
+
+ override def beforeAll(): Unit = {
+ TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
+ }
+
+ override def afterAll(): Unit = {
+ TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+ }
+
+ ignore("sort followed by limit should not leak memory") {
+ // TODO: this test is going to fail until we implement a proper iterator interface
+ // with a close() method.
+ TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF("a"),
+ (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
+ (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
+ sortAnswers = false
+ )
+ }
+
+ test("sort followed by limit") {
+ TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false")
+ try {
+ checkThatPlansAgree(
+ (1 to 100).map(v => Tuple1(v)).toDF("a"),
+ (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)),
+ (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
+ sortAnswers = false
+ )
+ } finally {
+ TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "true")
+
+ }
+ }
+
+ test("sorting does not crash for large inputs") {
+ val sortOrder = 'a.asc :: Nil
+ val stringLength = 1024 * 1024 * 2
+ checkThatPlansAgree(
+ Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
+ UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
+ Sort(sortOrder, global = true, _: SparkPlan),
+ sortAnswers = false
+ )
+ }
+
+ // Test sorting on different data types
+ for (
+ dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)
+ if !dataType.isInstanceOf[DecimalType]; // We don't have an unsafe representation for decimals
+ nullable <- Seq(true, false);
+ sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
+ randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
+ ) {
+ test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
+ val inputData = Seq.fill(1000)(randomDataGenerator()).filter {
+ case d: Double => !d.isNaN
+ case f: Float => !java.lang.Float.isNaN(f)
+ case x => true
+ }
+ val inputDf = TestSQLContext.createDataFrame(
+ TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+ StructType(StructField("a", dataType, nullable = true) :: Nil)
+ )
+ assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
+ checkThatPlansAgree(
+ inputDf,
+ UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 23),
+ Sort(sortOrder, global = true, _: SparkPlan),
+ sortAnswers = false
+ )
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 5707d2fb30..2c27da596b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.joins
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan}
import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
@@ -41,23 +42,23 @@ class OuterJoinSuite extends SparkPlanTest {
val condition = Some(LessThan('b, 'd))
test("shuffled hash outer join") {
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
Seq(
(1, 2.0, null, null),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null)
- ))
+ ).map(Row.fromTuple))
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
Seq(
(2, 1.0, 2, 3.0),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
- ))
+ ).map(Row.fromTuple))
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
Seq(
(1, 2.0, null, null),
@@ -65,24 +66,24 @@ class OuterJoinSuite extends SparkPlanTest {
(3, 3.0, null, null),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
- ))
+ ).map(Row.fromTuple))
}
test("broadcast hash outer join") {
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
Seq(
(1, 2.0, null, null),
(2, 1.0, 2, 3.0),
(3, 3.0, null, null)
- ))
+ ).map(Row.fromTuple))
- checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
Seq(
(2, 1.0, 2, 3.0),
(null, null, 3, 2.0),
(null, null, 4, 1.0)
- ))
+ ).map(Row.fromTuple))
}
}